16 template <
typename ElementType,
typename DMLCParserDType>
17 inline static std::unique_ptr<treelite::CSRDMatrix> CreateFromParserImpl(
18 const char* filename,
const char* format,
int nthread,
int verbose) {
19 std::unique_ptr<dmlc::Parser<uint32_t, DMLCParserDType>> parser(
20 dmlc::Parser<uint32_t, DMLCParserDType>::Create(filename, 0, 1, format));
22 const int max_thread = omp_get_max_threads();
23 nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
25 std::vector<ElementType> data;
26 std::vector<uint32_t> col_ind;
27 std::vector<size_t> row_ptr;
33 std::vector<size_t> max_col_ind(nthread, 0);
34 parser->BeforeFirst();
35 while (parser->Next()) {
36 const dmlc::RowBlock<uint32_t, DMLCParserDType>& batch = parser->Value();
37 num_row += batch.size;
38 num_elem += batch.offset[batch.size];
39 const size_t top = data.size();
40 data.resize(top + batch.offset[batch.size] - batch.offset[0]);
41 col_ind.resize(top + batch.offset[batch.size] - batch.offset[0]);
42 CHECK_LT(static_cast<int64_t>(batch.offset[batch.size]),
43 std::numeric_limits<int64_t>::max());
44 #pragma omp parallel for schedule(static) num_threads(nthread) 45 for (int64_t i = static_cast<int64_t>(batch.offset[0]);
46 i < static_cast<int64_t>(batch.offset[batch.size]); ++i) {
47 const int tid = omp_get_thread_num();
48 const uint32_t index = batch.index[i];
49 const ElementType fvalue
50 = ((batch.value ==
nullptr) ? static_cast<ElementType>(1)
51 :
static_cast<ElementType
>(batch.value[i]));
52 const size_t offset = top + i - batch.offset[0];
53 data[offset] = fvalue;
54 col_ind[offset] = index;
55 max_col_ind[tid] = std::max(max_col_ind[tid], static_cast<size_t>(index));
57 const size_t rtop = row_ptr.size();
58 row_ptr.resize(rtop + batch.size);
59 CHECK_LT(static_cast<int64_t>(batch.size), std::numeric_limits<int64_t>::max());
60 #pragma omp parallel for schedule(static) num_threads(nthread) 61 for (int64_t i = 0; i < static_cast<int64_t>(batch.size); ++i) {
62 row_ptr[rtop + i] = row_ptr[rtop - 1] + batch.offset[i + 1] - batch.offset[0];
65 LOG(INFO) << num_row <<
" rows read into memory";
68 num_col = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1;
69 return treelite::CSRDMatrix::Create(std::move(data), std::move(col_ind), std::move(row_ptr),
73 std::unique_ptr<treelite::CSRDMatrix>
75 const char* filename,
const char* format,
treelite::TypeInfo dtype,
int nthread,
int verbose) {
77 case treelite::TypeInfo::kFloat32:
78 return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
79 case treelite::TypeInfo::kFloat64:
80 return CreateFromParserImpl<double, float>(filename, format, nthread, verbose);
81 case treelite::TypeInfo::kUInt32:
82 return CreateFromParserImpl<uint32_t, int64_t>(filename, format, nthread, verbose);
86 return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
94 template<
typename ElementType>
95 std::unique_ptr<DenseDMatrix>
97 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col) {
98 std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
99 std::move(data), missing_value, num_row, num_col);
100 matrix->element_type_ = TypeToInfo<ElementType>();
104 template<
typename ElementType>
105 std::unique_ptr<DenseDMatrix>
106 DenseDMatrix::Create(
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
107 auto* data_ptr =
static_cast<const ElementType*
>(data);
108 const size_t num_elem = num_row * num_col;
109 return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
110 *static_cast<const ElementType*>(missing_value), num_row, num_col);
113 std::unique_ptr<DenseDMatrix>
114 DenseDMatrix::Create(
115 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
116 CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
118 case TypeInfo::kFloat32:
119 return Create<float>(data, missing_value, num_row, num_col);
120 case TypeInfo::kFloat64:
121 return Create<double>(data, missing_value, num_row, num_col);
122 case TypeInfo::kInvalid:
123 case TypeInfo::kUInt32:
127 return std::unique_ptr<DenseDMatrix>(
nullptr);
131 DenseDMatrix::GetElementType()
const {
132 return element_type_;
135 template<
typename ElementType>
136 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
137 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col)
138 : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
141 template<
typename ElementType>
143 DenseDMatrixImpl<ElementType>::GetNumRow()
const {
147 template<
typename ElementType>
149 DenseDMatrixImpl<ElementType>::GetNumCol()
const {
153 template<
typename ElementType>
155 DenseDMatrixImpl<ElementType>::GetNumElem()
const {
156 return num_row * num_col;
159 template<
typename ElementType>
161 DenseDMatrixImpl<ElementType>::GetType()
const {
162 return DMatrixType::kDense;
165 template <
typename ElementType>
166 template <
typename OutputType>
168 DenseDMatrixImpl<ElementType>::FillRow(
size_t row_id, OutputType* out)
const {
170 size_t in_idx = row_id * num_col;
171 while (out_idx < num_col) {
172 out[out_idx] =
static_cast<OutputType
>(data[in_idx]);
178 template <
typename ElementType>
179 template <
typename OutputType>
181 DenseDMatrixImpl<ElementType>::ClearRow(
size_t row_id, OutputType* out)
const {
182 for (
size_t i = 0; i < num_col; ++i) {
183 out[i] = std::numeric_limits<OutputType>::quiet_NaN();
187 template<
typename ElementType>
188 std::unique_ptr<CSRDMatrix>
189 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
190 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col) {
191 std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
192 std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
193 matrix->element_type_ = TypeToInfo<ElementType>();
197 template<
typename ElementType>
198 std::unique_ptr<CSRDMatrix>
199 CSRDMatrix::Create(
const void* data,
const uint32_t* col_ind,
200 const size_t* row_ptr,
size_t num_row,
size_t num_col) {
201 auto* data_ptr =
static_cast<const ElementType*
>(data);
202 const size_t num_elem = row_ptr[num_row];
203 return CSRDMatrix::Create(
204 std::vector<ElementType>(data_ptr, data_ptr + num_elem),
205 std::vector<uint32_t>(col_ind, col_ind + num_elem),
206 std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
211 std::unique_ptr<CSRDMatrix>
212 CSRDMatrix::Create(TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
213 size_t num_row,
size_t num_col) {
214 CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
216 case TypeInfo::kFloat32:
217 return Create<float>(data, col_ind, row_ptr, num_row, num_col);
218 case TypeInfo::kFloat64:
219 return Create<double>(data, col_ind, row_ptr, num_row, num_col);
220 case TypeInfo::kInvalid:
221 case TypeInfo::kUInt32:
225 return std::unique_ptr<CSRDMatrix>(
nullptr);
// Load a sparse CSR matrix from a file by delegating to CreateFromParser.
// \param filename path of the file to read (forwarded unchanged)
// \param format parser format string (forwarded unchanged)
// \param data_type name of the element type. NOTE(review): the line that turns it into the
//        `dtype` passed below is missing from this listing; presumably a lookup by name
//        (e.g. GetTypeInfoByName). Confirm against the original file.
// \param nthread number of threads (forwarded unchanged)
// \param verbose verbosity level (forwarded unchanged)
// NOTE(review): the qualified function name and the closing brace are also missing here.
228 std::unique_ptr<CSRDMatrix>
230 const char* filename,
const char* format,
const char* data_type,
int nthread,
int verbose) {
232 return CreateFromParser(filename, format, dtype, nthread, verbose);
236 CSRDMatrix::GetElementType()
const {
237 return element_type_;
240 template <
typename ElementType>
241 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
242 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
243 size_t num_row,
size_t num_col)
244 : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
245 num_row(num_row), num_col(num_col)
248 template <
typename ElementType>
250 CSRDMatrixImpl<ElementType>::GetNumRow()
const {
254 template <
typename ElementType>
256 CSRDMatrixImpl<ElementType>::GetNumCol()
const {
260 template <
typename ElementType>
262 CSRDMatrixImpl<ElementType>::GetNumElem()
const {
263 return row_ptr.at(num_row);
266 template <
typename ElementType>
268 CSRDMatrixImpl<ElementType>::GetType()
const {
269 return DMatrixType::kSparseCSR;
272 template <
typename ElementType>
273 template <
typename OutputType>
275 CSRDMatrixImpl<ElementType>::FillRow(
size_t row_id, OutputType* out)
const {
276 for (
size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
277 out[col_ind[i]] =
static_cast<OutputType
>(data[i]);
281 template <
typename ElementType>
282 template <
typename OutputType>
284 CSRDMatrixImpl<ElementType>::ClearRow(
size_t row_id, OutputType* out)
const {
285 for (
size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
286 out[col_ind[i]] = std::numeric_limits<OutputType>::quiet_NaN();
290 template class DenseDMatrixImpl<float>;
291 template class DenseDMatrixImpl<double>;
292 template class CSRDMatrixImpl<float>;
293 template class CSRDMatrixImpl<double>;
295 template void CSRDMatrixImpl<float>::FillRow<
float>(size_t,
float*)
const;
296 template void CSRDMatrixImpl<float>::FillRow<
double>(size_t,
double*)
const;
297 template void CSRDMatrixImpl<float>::ClearRow<
float>(size_t,
float*)
const;
298 template void CSRDMatrixImpl<float>::ClearRow<
double>(size_t,
double*)
const;
299 template void CSRDMatrixImpl<double>::FillRow<
float>(size_t,
float*)
const;
300 template void CSRDMatrixImpl<double>::FillRow<
double>(size_t,
double*)
const;
301 template void CSRDMatrixImpl<double>::ClearRow<
float>(size_t,
float*)
const;
302 template void CSRDMatrixImpl<double>::ClearRow<
double>(size_t,
double*)
const;
303 template void DenseDMatrixImpl<float>::FillRow<
float>(size_t,
float*)
const;
304 template void DenseDMatrixImpl<float>::FillRow<
double>(size_t,
double*)
const;
305 template void DenseDMatrixImpl<float>::ClearRow<
float>(size_t,
float*)
const;
306 template void DenseDMatrixImpl<float>::ClearRow<
double>(size_t,
double*)
const;
307 template void DenseDMatrixImpl<double>::FillRow<
float>(size_t,
float*)
const;
308 template void DenseDMatrixImpl<double>::FillRow<
double>(size_t,
double*)
const;
309 template void DenseDMatrixImpl<double>::ClearRow<
float>(size_t,
float*)
const;
310 template void DenseDMatrixImpl<double>::ClearRow<
double>(size_t,
double*)
const;
Input data structure of Treelite.
TypeInfo
Types used by thresholds and leaf outputs.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
compatibility wrapper for systems that don't support OpenMP