Treelite
data.cc
Go to the documentation of this file.
1 
8 #include <treelite/data.h>
9 #include <treelite/omp.h>
10 #include <memory>
11 #include <limits>
12 #include <cstdint>
13 
14 namespace {
15 
16 template <typename ElementType, typename DMLCParserDType>
17 inline static std::unique_ptr<treelite::CSRDMatrix> CreateFromParserImpl(
18  const char* filename, const char* format, int nthread, int verbose) {
19  std::unique_ptr<dmlc::Parser<uint32_t, DMLCParserDType>> parser(
20  dmlc::Parser<uint32_t, DMLCParserDType>::Create(filename, 0, 1, format));
21 
22  const int max_thread = omp_get_max_threads();
23  nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread);
24 
25  std::vector<ElementType> data;
26  std::vector<uint32_t> col_ind;
27  std::vector<size_t> row_ptr;
28  row_ptr.resize(1, 0);
29  size_t num_row = 0;
30  size_t num_col = 0;
31  size_t num_elem = 0;
32 
33  std::vector<size_t> max_col_ind(nthread, 0);
34  parser->BeforeFirst();
35  while (parser->Next()) {
36  const dmlc::RowBlock<uint32_t, DMLCParserDType>& batch = parser->Value();
37  num_row += batch.size;
38  num_elem += batch.offset[batch.size];
39  const size_t top = data.size();
40  data.resize(top + batch.offset[batch.size] - batch.offset[0]);
41  col_ind.resize(top + batch.offset[batch.size] - batch.offset[0]);
42  CHECK_LT(static_cast<int64_t>(batch.offset[batch.size]),
43  std::numeric_limits<int64_t>::max());
44  #pragma omp parallel for schedule(static) num_threads(nthread)
45  for (int64_t i = static_cast<int64_t>(batch.offset[0]);
46  i < static_cast<int64_t>(batch.offset[batch.size]); ++i) {
47  const int tid = omp_get_thread_num();
48  const uint32_t index = batch.index[i];
49  const ElementType fvalue
50  = ((batch.value == nullptr) ? static_cast<ElementType>(1)
51  : static_cast<ElementType>(batch.value[i]));
52  const size_t offset = top + i - batch.offset[0];
53  data[offset] = fvalue;
54  col_ind[offset] = index;
55  max_col_ind[tid] = std::max(max_col_ind[tid], static_cast<size_t>(index));
56  }
57  const size_t rtop = row_ptr.size();
58  row_ptr.resize(rtop + batch.size);
59  CHECK_LT(static_cast<int64_t>(batch.size), std::numeric_limits<int64_t>::max());
60  #pragma omp parallel for schedule(static) num_threads(nthread)
61  for (int64_t i = 0; i < static_cast<int64_t>(batch.size); ++i) {
62  row_ptr[rtop + i] = row_ptr[rtop - 1] + batch.offset[i + 1] - batch.offset[0];
63  }
64  if (verbose > 0) {
65  LOG(INFO) << num_row << " rows read into memory";
66  }
67  }
68  num_col = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1;
69  return treelite::CSRDMatrix::Create(std::move(data), std::move(col_ind), std::move(row_ptr),
70  num_row, num_col);
71 }
72 
73 std::unique_ptr<treelite::CSRDMatrix>
74 CreateFromParser(
75  const char* filename, const char* format, treelite::TypeInfo dtype, int nthread, int verbose) {
76  switch (dtype) {
77  case treelite::TypeInfo::kFloat32:
78  return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
79  case treelite::TypeInfo::kFloat64:
80  return CreateFromParserImpl<double, float>(filename, format, nthread, verbose);
81  case treelite::TypeInfo::kUInt32:
82  return CreateFromParserImpl<uint32_t, int64_t>(filename, format, nthread, verbose);
83  default:
84  LOG(FATAL) << "Unrecognized TypeInfo: " << treelite::TypeInfoToString(dtype);
85  }
86  return CreateFromParserImpl<float, float>(filename, format, nthread, verbose);
87  // avoid missing value warning
88 }
89 
90 } // anonymous namespace
91 
92 namespace treelite {
93 
94 template<typename ElementType>
95 std::unique_ptr<DenseDMatrix>
96 DenseDMatrix::Create(
97  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col) {
98  std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
99  std::move(data), missing_value, num_row, num_col);
100  matrix->element_type_ = TypeToInfo<ElementType>();
101  return matrix;
102 }
103 
104 template<typename ElementType>
105 std::unique_ptr<DenseDMatrix>
106 DenseDMatrix::Create(const void* data, const void* missing_value, size_t num_row, size_t num_col) {
107  auto* data_ptr = static_cast<const ElementType*>(data);
108  const size_t num_elem = num_row * num_col;
109  return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
110  *static_cast<const ElementType*>(missing_value), num_row, num_col);
111 }
112 
113 std::unique_ptr<DenseDMatrix>
114 DenseDMatrix::Create(
115  TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col) {
116  CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
117  switch (type) {
118  case TypeInfo::kFloat32:
119  return Create<float>(data, missing_value, num_row, num_col);
120  case TypeInfo::kFloat64:
121  return Create<double>(data, missing_value, num_row, num_col);
122  case TypeInfo::kInvalid:
123  case TypeInfo::kUInt32:
124  default:
125  LOG(FATAL) << "Invalid type for DenseDMatrix: " << TypeInfoToString(type);
126  }
127  return std::unique_ptr<DenseDMatrix>(nullptr);
128 }
129 
130 TypeInfo
131 DenseDMatrix::GetElementType() const {
132  return element_type_;
133 }
134 
135 template<typename ElementType>
136 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
137  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col)
138  : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
139  num_col(num_col) {}
140 
141 template<typename ElementType>
142 size_t
143 DenseDMatrixImpl<ElementType>::GetNumRow() const {
144  return num_row;
145 }
146 
147 template<typename ElementType>
148 size_t
149 DenseDMatrixImpl<ElementType>::GetNumCol() const {
150  return num_col;
151 }
152 
153 template<typename ElementType>
154 size_t
155 DenseDMatrixImpl<ElementType>::GetNumElem() const {
156  return num_row * num_col;
157 }
158 
159 template<typename ElementType>
160 DMatrixType
161 DenseDMatrixImpl<ElementType>::GetType() const {
162  return DMatrixType::kDense;
163 }
164 
165 template<typename ElementType>
166 std::unique_ptr<CSRDMatrix>
167 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
168  std::vector<size_t> row_ptr, size_t num_row, size_t num_col) {
169  std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
170  std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
171  matrix->element_type_ = TypeToInfo<ElementType>();
172  return matrix;
173 }
174 
175 template<typename ElementType>
176 std::unique_ptr<CSRDMatrix>
177 CSRDMatrix::Create(const void* data, const uint32_t* col_ind,
178  const size_t* row_ptr, size_t num_row, size_t num_col) {
179  auto* data_ptr = static_cast<const ElementType*>(data);
180  const size_t num_elem = row_ptr[num_row];
181  return CSRDMatrix::Create(
182  std::vector<ElementType>(data_ptr, data_ptr + num_elem),
183  std::vector<uint32_t>(col_ind, col_ind + num_elem),
184  std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
185  num_row,
186  num_col);
187 }
188 
189 std::unique_ptr<CSRDMatrix>
190 CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr,
191  size_t num_row, size_t num_col) {
192  CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
193  switch (type) {
194  case TypeInfo::kFloat32:
195  return Create<float>(data, col_ind, row_ptr, num_row, num_col);
196  case TypeInfo::kFloat64:
197  return Create<double>(data, col_ind, row_ptr, num_row, num_col);
198  case TypeInfo::kInvalid:
199  case TypeInfo::kUInt32:
200  default:
201  LOG(FATAL) << "Invalid type for CSRDMatrix: " << TypeInfoToString(type);
202  }
203  return std::unique_ptr<CSRDMatrix>(nullptr);
204 }
205 
206 std::unique_ptr<CSRDMatrix>
207 CSRDMatrix::Create(
208  const char* filename, const char* format, const char* data_type, int nthread, int verbose) {
209  TypeInfo dtype = (data_type ? GetTypeInfoByName(data_type) : TypeInfo::kFloat32);
210  return CreateFromParser(filename, format, dtype, nthread, verbose);
211 }
212 
213 TypeInfo
214 CSRDMatrix::GetElementType() const {
215  return element_type_;
216 }
217 
218 template <typename ElementType>
219 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
220  std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
221  size_t num_row, size_t num_col)
222  : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
223  num_row(num_row), num_col(num_col)
224 {}
225 
226 template <typename ElementType>
227 size_t
228 CSRDMatrixImpl<ElementType>::GetNumRow() const {
229  return num_row;
230 }
231 
232 template <typename ElementType>
233 size_t
234 CSRDMatrixImpl<ElementType>::GetNumCol() const {
235  return num_col;
236 }
237 
238 template <typename ElementType>
239 size_t
240 CSRDMatrixImpl<ElementType>::GetNumElem() const {
241  return row_ptr.at(num_row);
242 }
243 
244 template <typename ElementType>
245 DMatrixType
246 CSRDMatrixImpl<ElementType>::GetType() const {
247  return DMatrixType::kSparseCSR;
248 }
249 
250 template class DenseDMatrixImpl<float>;
251 template class DenseDMatrixImpl<double>;
252 template class CSRDMatrixImpl<float>;
253 template class CSRDMatrixImpl<double>;
254 
255 } // namespace treelite
Input data structure of Treelite.
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
TypeInfo GetTypeInfoByName(const std::string &str)
conversion table from string to TypeInfo, defined in tables.cc
Definition: typeinfo.cc:16
Compatibility wrapper for systems that don't support OpenMP