// CreateFromParserImpl: read `filename` through a dmlc::Parser for `format`
// and assemble every parsed row into a treelite::CSRDMatrix.
//
// Template parameters:
//   ElementType     - value type stored in the returned CSRDMatrix.
//   DMLCParserDType - value type produced by the dmlc::Parser; each parsed
//                     value is converted with static_cast<ElementType>.
// Parameters:
//   filename - input location handed to dmlc::Parser::Create.
//   format   - parser format string handed to dmlc::Parser::Create.
//   nthread  - OpenMP thread count: 0 selects omp_get_max_threads(); any
//              other value is capped at omp_get_max_threads().
//              NOTE(review): a negative value is not rejected and would reach
//              the std::vector size argument of max_col_ind below.
//   verbose  - NOTE(review): not referenced by any line visible in this
//              listing; presumably it gates the LOG(INFO) progress message —
//              confirm against the full source.
// Returns: the matrix built by treelite::CSRDMatrix::Create from the
//          accumulated data / col_ind / row_ptr arrays.
//
// NOTE(review): this listing omits several source lines of this function
// (the declarations of num_row, num_elem and num_col, any initial contents
// of row_ptr, the closing braces, and the tail of the final return), so the
// comments below describe only the statements that are visible.
16 template <
typename ElementType,
typename DMLCParserDType>
17 inline static std::unique_ptr<treelite::CSRDMatrix> CreateFromParserImpl(
18 const char* filename,
const char* format,
int nthread,
int verbose) {
// The (0, 1) arguments presumably request partition 0 of 1, i.e. the whole
// input — confirm against the dmlc::Parser::Create documentation.
19 std::unique_ptr<dmlc::Parser<uint32_t, DMLCParserDType>> parser(
20 dmlc::Parser<uint32_t, DMLCParserDType>::Create(filename, 0, 1, format));
22 const int max_thread = omp_get_max_threads();
23 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
// Output arrays of the CSR matrix, grown batch by batch below.
25 std::vector<ElementType> data;
26 std::vector<uint32_t> col_ind;
27 std::vector<size_t> row_ptr;
// One running maximum of column indices per thread, indexed by
// omp_get_thread_num() inside the parallel loop, so each thread updates only
// its own slot; the slots are combined by the std::max_element call that
// computes num_col.
33 std::vector<size_t> max_col_ind(nthread, 0);
34 parser->BeforeFirst();
// Consume the input one dmlc::RowBlock at a time.
35 while (parser->Next()) {
36 const dmlc::RowBlock<uint32_t, DMLCParserDType>& batch = parser->Value();
37 num_row += batch.size;
// NOTE(review): num_elem adds offset[batch.size], whereas the resizes below
// use offset[batch.size] - offset[0]; the two differ whenever a batch's
// offset[0] is non-zero. num_elem is not read by any visible line.
38 num_elem += batch.offset[batch.size];
// Grow data / col_ind by this batch's entry count; `top` is the position
// where this batch's entries begin.
39 const size_t top = data.size();
40 data.resize(top + batch.offset[batch.size] - batch.offset[0]);
41 col_ind.resize(top + batch.offset[batch.size] - batch.offset[0]);
// Guard for the int64_t loop index used below.
// NOTE(review): the offset is converted to int64_t *before* comparing, so a
// size_t value above INT64_MAX becomes negative (implementation-defined
// before C++20, modular in practice) and passes the check. Comparing in
// size_t against static_cast<size_t>(std::numeric_limits<int64_t>::max())
// would make the guard effective. The same applies to the batch.size check
// further down.
42 CHECK_LT(static_cast<int64_t>(batch.offset[batch.size]),
43 std::numeric_limits<int64_t>::max());
// Copy this batch's values and column indices in parallel. When the batch
// carries no value array (batch.value == nullptr), every stored value is 1.
44 #pragma omp parallel for schedule(static) num_threads(nthread) 45 for (int64_t i = static_cast<int64_t>(batch.offset[0]);
46 i < static_cast<int64_t>(batch.offset[batch.size]); ++i) {
47 const int tid = omp_get_thread_num();
48 const uint32_t index = batch.index[i];
49 const ElementType fvalue
50 = ((batch.value ==
nullptr) ? static_cast<ElementType>(1)
51 :
static_cast<ElementType
>(batch.value[i]));
52 const size_t offset = top + i - batch.offset[0];
53 data[offset] = fvalue;
54 col_ind[offset] = index;
55 max_col_ind[tid] = std::max(max_col_ind[tid], static_cast<size_t>(index));
// Append one row_ptr entry per row of the batch: the last existing entry
// plus the row's end offset relative to offset[0].
// NOTE(review): row_ptr[rtop - 1] requires row_ptr to be non-empty before
// the first batch. row_ptr is declared empty above, and any seeding (e.g.
// with a leading 0) is on a line not present in this listing — confirm.
57 const size_t rtop = row_ptr.size();
58 row_ptr.resize(rtop + batch.size);
59 CHECK_LT(static_cast<int64_t>(batch.size), std::numeric_limits<int64_t>::max());
60 #pragma omp parallel for schedule(static) num_threads(nthread) 61 for (int64_t i = 0; i < static_cast<int64_t>(batch.size); ++i) {
62 row_ptr[rtop + i] = row_ptr[rtop - 1] + batch.offset[i + 1] - batch.offset[0];
// Progress message. Whether it sits inside the batch loop, and whether it
// is gated by `verbose`, cannot be determined from this listing.
65 LOG(INFO) << num_row <<
" rows read into memory";
// Column count = largest column index seen by any thread, plus one.
// NOTE(review): this evaluates to 1 even when no entries were read, because
// every max_col_ind slot starts at 0.
68 num_col = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1;
// The remaining arguments of this call are not present in this listing.
69 return treelite::CSRDMatrix::Create(std::move(data), std::move(col_ind), std::move(row_ptr),
// Pick the CreateFromParserImpl instantiation that matches `dtype` and
// forward filename / format / nthread / verbose to it unchanged.
//
// NOTE(review): this listing omits the declarator line carrying this
// function's name, the opening of the switch on `dtype`, and the lines
// between the kUInt32 case and the final return. The name is presumably
// CreateFromParser, the identifier called by CSRDMatrix::Create(filename,
// ...) — confirm.
73 std::unique_ptr<treelite::CSRDMatrix>
75 const char* filename,
const char* format,
treelite::TypeInfo dtype,
int nthread,
int verbose) {
// float storage, float parser values.
77 case treelite::TypeInfo::kFloat32:
78 return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
// double storage, but the parser still yields float values, so each value is
// parsed at float precision and only then widened to double.
// NOTE(review): confirm this precision loss is intended (possibly because
// no double-valued dmlc parser is available).
79 case treelite::TypeInfo::kFloat64:
80 return CreateFromParserImpl<double, float>(filename, format, nthread, verbose);
// uint32_t storage, int64_t parser values; each value is narrowed by the
// static_cast<ElementType> in CreateFromParserImpl, so negative or
// > UINT32_MAX inputs would wrap rather than be rejected.
81 case treelite::TypeInfo::kUInt32:
82 return CreateFromParserImpl<uint32_t, int64_t>(filename, format, nthread, verbose);
// Fallback with float storage for any dtype not handled above; the
// handling on the omitted lines before it is not visible here.
86 return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
// Typed dense factory: move `data` into a new DenseDMatrixImpl<ElementType>
// together with missing_value (presumably the sentinel marking absent
// entries), num_row and num_col, then record the element type via
// TypeToInfo<ElementType>().
//
// NOTE(review): the declarator line (presumably DenseDMatrix::Create, the
// overload called by the raw-pointer Create<ElementType>) and the trailing
// return / closing lines are not present in this listing.
94 template<
typename ElementType>
95 std::unique_ptr<DenseDMatrix>
97 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col) {
// `data` is taken by value and moved in, so callers passing an rvalue avoid
// a copy.
98 std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
99 std::move(data), missing_value, num_row, num_col);
// Set here rather than in DenseDMatrixImpl's constructor initializer list.
100 matrix->element_type_ = TypeToInfo<ElementType>();
// Raw-pointer dense factory: copy num_row * num_col elements of type
// ElementType out of `data` into a new std::vector and forward to the
// vector-taking DenseDMatrix::Create overload.
//
// Preconditions (from the code below):
//   - `data` points to at least num_row * num_col ElementType values;
//   - `missing_value` is non-null and points to one ElementType (it is
//     dereferenced unconditionally).
// NOTE(review): num_row * num_col is not checked for size_t overflow.
// NOTE(review): the closing brace of this function is not present in this
// listing.
104 template<
typename ElementType>
105 std::unique_ptr<DenseDMatrix>
106 DenseDMatrix::Create(
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
107 auto* data_ptr =
static_cast<const ElementType*
>(data);
108 const size_t num_elem = num_row * num_col;
109 return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
110 *static_cast<const ElementType*>(missing_value), num_row, num_col);
// Runtime-typed dense factory: dispatch on `type` to Create<float> or
// Create<double>.
// kInvalid is rejected up front by CHECK. kInvalid and kUInt32 share a case
// label whose body is on lines not present in this listing, as is the
// opening of the switch on `type`.
// The final statement returns a null pointer. NOTE(review): presumably that
// is reached only if the omitted branch does not terminate (e.g. it emits a
// fatal log) — confirm.
113 std::unique_ptr<DenseDMatrix>
114 DenseDMatrix::Create(
115 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
116 CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
118 case TypeInfo::kFloat32:
119 return Create<float>(data, missing_value, num_row, num_col);
120 case TypeInfo::kFloat64:
121 return Create<double>(data, missing_value, num_row, num_col);
// Dense matrices support only float and double element types here.
122 case TypeInfo::kInvalid:
123 case TypeInfo::kUInt32:
127 return std::unique_ptr<DenseDMatrix>(
nullptr);
// Return the element type recorded in element_type_ by the typed
// DenseDMatrix::Create factory (set from TypeToInfo<ElementType>()).
// NOTE(review): the return-type line and the closing brace are not present
// in this listing.
131 DenseDMatrix::GetElementType()
const {
132 return element_type_;
// Constructor: take ownership of `data` by move and store missing_value and
// num_row. The tail of the initializer list (presumably num_col) and the
// constructor body are on lines not present in this listing.
135 template<
typename ElementType>
136 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
137 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col)
138 : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
// Number of rows. The return type and body are not present in this listing.
141 template<
typename ElementType>
143 DenseDMatrixImpl<ElementType>::GetNumRow()
const {
// Number of columns. The return type and body are not present in this
// listing.
147 template<
typename ElementType>
149 DenseDMatrixImpl<ElementType>::GetNumCol()
const {
// Number of elements: every cell of the dense matrix is counted, i.e.
// num_row * num_col (no overflow check).
153 template<
typename ElementType>
155 DenseDMatrixImpl<ElementType>::GetNumElem()
const {
156 return num_row * num_col;
// Matrix kind tag for dense storage.
159 template<
typename ElementType>
161 DenseDMatrixImpl<ElementType>::GetType()
const {
162 return DMatrixType::kDense;
// Typed CSR factory: move data / col_ind / row_ptr into a new
// CSRDMatrixImpl<ElementType> with the given num_row and num_col, then
// record the element type via TypeToInfo<ElementType>().
// No consistency check between the arrays and num_row is performed in the
// visible lines (e.g. row_ptr.size() == num_row + 1).
// NOTE(review): the trailing return / closing lines are not present in this
// listing.
165 template<
typename ElementType>
166 std::unique_ptr<CSRDMatrix>
167 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
168 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col) {
169 std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
170 std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
// Set here rather than in CSRDMatrixImpl's constructor initializer list.
171 matrix->element_type_ = TypeToInfo<ElementType>();
// Raw-pointer CSR factory: copy the three CSR arrays into std::vectors and
// forward to the vector-taking CSRDMatrix::Create overload.
//
// Preconditions (from the code below):
//   - `row_ptr` has at least num_row + 1 entries;
//   - `data` (as ElementType) and `col_ind` each have at least
//     row_ptr[num_row] entries, counted from index 0.
// NOTE(review): the remaining call arguments and the closing brace are not
// present in this listing.
175 template<
typename ElementType>
176 std::unique_ptr<CSRDMatrix>
177 CSRDMatrix::Create(
const void* data,
const uint32_t* col_ind,
178 const size_t* row_ptr,
size_t num_row,
size_t num_col) {
179 auto* data_ptr =
static_cast<const ElementType*
>(data);
// Total stored entries = end offset of the last row.
180 const size_t num_elem = row_ptr[num_row];
181 return CSRDMatrix::Create(
182 std::vector<ElementType>(data_ptr, data_ptr + num_elem),
183 std::vector<uint32_t>(col_ind, col_ind + num_elem),
184 std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
// Runtime-typed CSR factory: dispatch on `type` to Create<float> or
// Create<double>.
// kInvalid is rejected up front by CHECK. kInvalid and kUInt32 share a case
// label whose body is on lines not present in this listing, as is the
// opening of the switch on `type`.
// The final statement returns a null pointer. NOTE(review): presumably that
// is reached only if the omitted branch does not terminate — confirm.
// NOTE(review): this path does not accept kUInt32, while the file-loading
// path (CreateFromParser) does build a uint32_t CSR matrix for kUInt32;
// confirm the asymmetry is intended.
189 std::unique_ptr<CSRDMatrix>
190 CSRDMatrix::Create(TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
191 size_t num_row,
size_t num_col) {
192 CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
194 case TypeInfo::kFloat32:
195 return Create<float>(data, col_ind, row_ptr, num_row, num_col);
196 case TypeInfo::kFloat64:
197 return Create<double>(data, col_ind, row_ptr, num_row, num_col);
198 case TypeInfo::kInvalid:
199 case TypeInfo::kUInt32:
203 return std::unique_ptr<CSRDMatrix>(
nullptr);
// File-loading CSR factory: forward filename / format / nthread / verbose
// to CreateFromParser together with `dtype`.
// NOTE(review): `dtype` is defined on a line not present in this listing,
// presumably derived from `data_type` via GetTypeInfoByName — confirm. The
// declarator line and the closing brace are also missing from this listing.
206 std::unique_ptr<CSRDMatrix>
208 const char* filename,
const char* format,
const char* data_type,
int nthread,
int verbose) {
210 return CreateFromParser(filename, format, dtype, nthread, verbose);
// Return the element type recorded in element_type_ by the typed
// CSRDMatrix::Create factory (set from TypeToInfo<ElementType>()).
// NOTE(review): the return-type line and the closing brace are not present
// in this listing.
214 CSRDMatrix::GetElementType()
const {
215 return element_type_;
// Constructor: take ownership of the three CSR arrays by move and store
// num_row / num_col. The constructor body is not present in this listing.
218 template <
typename ElementType>
219 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
220 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
221 size_t num_row,
size_t num_col)
222 : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
223 num_row(num_row), num_col(num_col)
// Number of rows. The return type and body are not present in this listing.
226 template <
typename ElementType>
228 CSRDMatrixImpl<ElementType>::GetNumRow()
const {
// Number of columns. The return type and body are not present in this
// listing.
232 template <
typename ElementType>
234 CSRDMatrixImpl<ElementType>::GetNumCol()
const {
// Number of stored (non-missing) entries = end offset of the last row.
// Uses the bounds-checked .at(), so a row_ptr with num_row or fewer entries
// throws std::out_of_range instead of reading past the end.
238 template <
typename ElementType>
240 CSRDMatrixImpl<ElementType>::GetNumElem()
const {
241 return row_ptr.at(num_row);
// Matrix kind tag for compressed-sparse-row storage.
244 template <
typename ElementType>
246 CSRDMatrixImpl<ElementType>::GetType()
const {
247 return DMatrixType::kSparseCSR;
250 template class DenseDMatrixImpl<float>;
251 template class DenseDMatrixImpl<double>;
252 template class CSRDMatrixImpl<float>;
253 template class CSRDMatrixImpl<double>;
Input data structure of Treelite.
TypeInfo
Types used by thresholds and leaf outputs.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
TypeInfo GetTypeInfoByName(const std::string &str)
Conversion table from string to TypeInfo, defined in tables.cc.
Compatibility wrapper for systems that don't support OpenMP.