Treelite
data.cc
Go to the documentation of this file.
1 
#include <treelite/data.h>
#include <treelite/omp.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>
13 
14 namespace {
15 
16 template <typename ElementType, typename DMLCParserDType>
17 inline static std::unique_ptr<treelite::CSRDMatrix> CreateFromParserImpl(
18  const char* filename, const char* format, int nthread, int verbose) {
19  std::unique_ptr<dmlc::Parser<uint32_t, DMLCParserDType>> parser(
20  dmlc::Parser<uint32_t, DMLCParserDType>::Create(filename, 0, 1, format));
21 
22  const int max_thread = omp_get_max_threads();
23  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
24 
25  std::vector<ElementType> data;
26  std::vector<uint32_t> col_ind;
27  std::vector<size_t> row_ptr;
28  row_ptr.resize(1, 0);
29  size_t num_row = 0;
30  size_t num_col = 0;
31  size_t num_elem = 0;
32 
33  std::vector<size_t> max_col_ind(nthread, 0);
34  parser->BeforeFirst();
35  while (parser->Next()) {
36  const dmlc::RowBlock<uint32_t, DMLCParserDType>& batch = parser->Value();
37  num_row += batch.size;
38  num_elem += batch.offset[batch.size];
39  const size_t top = data.size();
40  data.resize(top + batch.offset[batch.size] - batch.offset[0]);
41  col_ind.resize(top + batch.offset[batch.size] - batch.offset[0]);
42  CHECK_LT(static_cast<int64_t>(batch.offset[batch.size]),
43  std::numeric_limits<int64_t>::max());
44  #pragma omp parallel for schedule(static) num_threads(nthread)
45  for (int64_t i = static_cast<int64_t>(batch.offset[0]);
46  i < static_cast<int64_t>(batch.offset[batch.size]); ++i) {
47  const int tid = omp_get_thread_num();
48  const uint32_t index = batch.index[i];
49  const ElementType fvalue
50  = ((batch.value == nullptr) ? static_cast<ElementType>(1)
51  : static_cast<ElementType>(batch.value[i]));
52  const size_t offset = top + i - batch.offset[0];
53  data[offset] = fvalue;
54  col_ind[offset] = index;
55  max_col_ind[tid] = std::max(max_col_ind[tid], static_cast<size_t>(index));
56  }
57  const size_t rtop = row_ptr.size();
58  row_ptr.resize(rtop + batch.size);
59  CHECK_LT(static_cast<int64_t>(batch.size), std::numeric_limits<int64_t>::max());
60  #pragma omp parallel for schedule(static) num_threads(nthread)
61  for (int64_t i = 0; i < static_cast<int64_t>(batch.size); ++i) {
62  row_ptr[rtop + i] = row_ptr[rtop - 1] + batch.offset[i + 1] - batch.offset[0];
63  }
64  if (verbose > 0) {
65  LOG(INFO) << num_row << " rows read into memory";
66  }
67  }
68  num_col = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1;
69  return treelite::CSRDMatrix::Create(std::move(data), std::move(col_ind), std::move(row_ptr),
70  num_row, num_col);
71 }
72 
73 std::unique_ptr<treelite::CSRDMatrix>
74 CreateFromParser(
75  const char* filename, const char* format, treelite::TypeInfo dtype, int nthread, int verbose) {
76  switch (dtype) {
77  case treelite::TypeInfo::kFloat32:
78  return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
79  case treelite::TypeInfo::kFloat64:
80  return CreateFromParserImpl<double, float>(filename, format, nthread, verbose);
81  case treelite::TypeInfo::kUInt32:
82  return CreateFromParserImpl<uint32_t, int64_t>(filename, format, nthread, verbose);
83  default:
84  LOG(FATAL) << "Unrecognized TypeInfo: " << treelite::TypeInfoToString(dtype);
85  }
86  return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
87  // avoid missing value warning
88 }
89 
90 } // anonymous namespace
91 
92 namespace treelite {
93 
94 template<typename ElementType>
95 std::unique_ptr<DenseDMatrix>
96 DenseDMatrix::Create(
97  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col) {
98  std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
99  std::move(data), missing_value, num_row, num_col);
100  matrix->element_type_ = TypeToInfo<ElementType>();
101  return matrix;
102 }
103 
104 template<typename ElementType>
105 std::unique_ptr<DenseDMatrix>
106 DenseDMatrix::Create(const void* data, const void* missing_value, size_t num_row, size_t num_col) {
107  auto* data_ptr = static_cast<const ElementType*>(data);
108  const size_t num_elem = num_row * num_col;
109  return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
110  *static_cast<const ElementType*>(missing_value), num_row, num_col);
111 }
112 
113 std::unique_ptr<DenseDMatrix>
114 DenseDMatrix::Create(
115  TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col) {
116  CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
117  switch (type) {
118  case TypeInfo::kFloat32:
119  return Create<float>(data, missing_value, num_row, num_col);
120  case TypeInfo::kFloat64:
121  return Create<double>(data, missing_value, num_row, num_col);
122  case TypeInfo::kInvalid:
123  case TypeInfo::kUInt32:
124  default:
125  LOG(FATAL) << "Invalid type for DenseDMatrix: " << TypeInfoToString(type);
126  }
127  return std::unique_ptr<DenseDMatrix>(nullptr);
128 }
129 
130 TypeInfo
131 DenseDMatrix::GetElementType() const {
132  return element_type_;
133 }
134 
135 template<typename ElementType>
136 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
137  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col)
138  : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
139  num_col(num_col) {}
140 
141 template<typename ElementType>
142 size_t
143 DenseDMatrixImpl<ElementType>::GetNumRow() const {
144  return num_row;
145 }
146 
147 template<typename ElementType>
148 size_t
149 DenseDMatrixImpl<ElementType>::GetNumCol() const {
150  return num_col;
151 }
152 
153 template<typename ElementType>
154 size_t
155 DenseDMatrixImpl<ElementType>::GetNumElem() const {
156  return num_row * num_col;
157 }
158 
159 template<typename ElementType>
160 DMatrixType
161 DenseDMatrixImpl<ElementType>::GetType() const {
162  return DMatrixType::kDense;
163 }
164 
165 template <typename ElementType>
166 template <typename OutputType>
167 void
168 DenseDMatrixImpl<ElementType>::FillRow(size_t row_id, OutputType* out) const {
169  size_t out_idx = 0;
170  size_t in_idx = row_id * num_col;
171  while (out_idx < num_col) {
172  out[out_idx] = static_cast<OutputType>(data[in_idx]);
173  ++out_idx;
174  ++in_idx;
175  }
176 }
177 
178 template <typename ElementType>
179 template <typename OutputType>
180 void
181 DenseDMatrixImpl<ElementType>::ClearRow(size_t row_id, OutputType* out) const {
182  for (size_t i = 0; i < num_col; ++i) {
183  out[i] = std::numeric_limits<OutputType>::quiet_NaN();
184  }
185 }
186 
187 template<typename ElementType>
188 std::unique_ptr<CSRDMatrix>
189 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
190  std::vector<size_t> row_ptr, size_t num_row, size_t num_col) {
191  std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
192  std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
193  matrix->element_type_ = TypeToInfo<ElementType>();
194  return matrix;
195 }
196 
197 template<typename ElementType>
198 std::unique_ptr<CSRDMatrix>
199 CSRDMatrix::Create(const void* data, const uint32_t* col_ind,
200  const size_t* row_ptr, size_t num_row, size_t num_col) {
201  auto* data_ptr = static_cast<const ElementType*>(data);
202  const size_t num_elem = row_ptr[num_row];
203  return CSRDMatrix::Create(
204  std::vector<ElementType>(data_ptr, data_ptr + num_elem),
205  std::vector<uint32_t>(col_ind, col_ind + num_elem),
206  std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
207  num_row,
208  num_col);
209 }
210 
211 std::unique_ptr<CSRDMatrix>
212 CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr,
213  size_t num_row, size_t num_col) {
214  CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
215  switch (type) {
216  case TypeInfo::kFloat32:
217  return Create<float>(data, col_ind, row_ptr, num_row, num_col);
218  case TypeInfo::kFloat64:
219  return Create<double>(data, col_ind, row_ptr, num_row, num_col);
220  case TypeInfo::kInvalid:
221  case TypeInfo::kUInt32:
222  default:
223  LOG(FATAL) << "Invalid type for CSRDMatrix: " << TypeInfoToString(type);
224  }
225  return std::unique_ptr<CSRDMatrix>(nullptr);
226 }
227 
228 std::unique_ptr<CSRDMatrix>
229 CSRDMatrix::Create(
230  const char* filename, const char* format, const char* data_type, int nthread, int verbose) {
231  TypeInfo dtype = (data_type ? GetTypeInfoByName(data_type) : TypeInfo::kFloat32);
232  return CreateFromParser(filename, format, dtype, nthread, verbose);
233 }
234 
235 TypeInfo
236 CSRDMatrix::GetElementType() const {
237  return element_type_;
238 }
239 
240 template <typename ElementType>
241 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
242  std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
243  size_t num_row, size_t num_col)
244  : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
245  num_row(num_row), num_col(num_col)
246 {}
247 
248 template <typename ElementType>
249 size_t
250 CSRDMatrixImpl<ElementType>::GetNumRow() const {
251  return num_row;
252 }
253 
254 template <typename ElementType>
255 size_t
256 CSRDMatrixImpl<ElementType>::GetNumCol() const {
257  return num_col;
258 }
259 
260 template <typename ElementType>
261 size_t
262 CSRDMatrixImpl<ElementType>::GetNumElem() const {
263  return row_ptr.at(num_row);
264 }
265 
266 template <typename ElementType>
267 DMatrixType
268 CSRDMatrixImpl<ElementType>::GetType() const {
269  return DMatrixType::kSparseCSR;
270 }
271 
272 template <typename ElementType>
273 template <typename OutputType>
274 void
275 CSRDMatrixImpl<ElementType>::FillRow(size_t row_id, OutputType* out) const {
276  for (size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
277  out[col_ind[i]] = static_cast<OutputType>(data[i]);
278  }
279 }
280 
281 template <typename ElementType>
282 template <typename OutputType>
283 void
284 CSRDMatrixImpl<ElementType>::ClearRow(size_t row_id, OutputType* out) const {
285  for (size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
286  out[col_ind[i]] = std::numeric_limits<OutputType>::quiet_NaN();
287  }
288 }
289 
290 template class DenseDMatrixImpl<float>;
291 template class DenseDMatrixImpl<double>;
292 template class CSRDMatrixImpl<float>;
293 template class CSRDMatrixImpl<double>;
294 
295 template void CSRDMatrixImpl<float>::FillRow<float>(size_t, float*) const;
296 template void CSRDMatrixImpl<float>::FillRow<double>(size_t, double*) const;
297 template void CSRDMatrixImpl<float>::ClearRow<float>(size_t, float*) const;
298 template void CSRDMatrixImpl<float>::ClearRow<double>(size_t, double*) const;
299 template void CSRDMatrixImpl<double>::FillRow<float>(size_t, float*) const;
300 template void CSRDMatrixImpl<double>::FillRow<double>(size_t, double*) const;
301 template void CSRDMatrixImpl<double>::ClearRow<float>(size_t, float*) const;
302 template void CSRDMatrixImpl<double>::ClearRow<double>(size_t, double*) const;
303 template void DenseDMatrixImpl<float>::FillRow<float>(size_t, float*) const;
304 template void DenseDMatrixImpl<float>::FillRow<double>(size_t, double*) const;
305 template void DenseDMatrixImpl<float>::ClearRow<float>(size_t, float*) const;
306 template void DenseDMatrixImpl<float>::ClearRow<double>(size_t, double*) const;
307 template void DenseDMatrixImpl<double>::FillRow<float>(size_t, float*) const;
308 template void DenseDMatrixImpl<double>::FillRow<double>(size_t, double*) const;
309 template void DenseDMatrixImpl<double>::ClearRow<float>(size_t, float*) const;
310 template void DenseDMatrixImpl<double>::ClearRow<double>(size_t, double*) const;
311 
312 } // namespace treelite
Input data structure of Treelite.
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
TypeInfo GetTypeInfoByName(const std::string &str)
Look up the TypeInfo corresponding to a type name; the string-to-TypeInfo conversion table is defined in tables.cc
Definition: typeinfo.cc:16
Compatibility wrapper for systems that don't support OpenMP