Treelite
data.cc
Go to the documentation of this file.
1 
8 #include <treelite/logging.h>
9 #include <treelite/data.h>
10 #include <memory>
11 #include <limits>
12 #include <cstdint>
13 
14 namespace treelite {
15 
16 template<typename ElementType>
17 std::unique_ptr<DenseDMatrix>
18 DenseDMatrix::Create(
19  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col) {
20  std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
21  std::move(data), missing_value, num_row, num_col);
22  matrix->element_type_ = TypeToInfo<ElementType>();
23  return matrix;
24 }
25 
26 template<typename ElementType>
27 std::unique_ptr<DenseDMatrix>
28 DenseDMatrix::Create(const void* data, const void* missing_value, size_t num_row, size_t num_col) {
29  auto* data_ptr = static_cast<const ElementType*>(data);
30  const size_t num_elem = num_row * num_col;
31  return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
32  *static_cast<const ElementType*>(missing_value), num_row, num_col);
33 }
34 
35 std::unique_ptr<DenseDMatrix>
36 DenseDMatrix::Create(
37  TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col) {
38  TREELITE_CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
39  switch (type) {
40  case TypeInfo::kFloat32:
41  return Create<float>(data, missing_value, num_row, num_col);
42  case TypeInfo::kFloat64:
43  return Create<double>(data, missing_value, num_row, num_col);
44  case TypeInfo::kInvalid:
45  case TypeInfo::kUInt32:
46  default:
47  TREELITE_LOG(FATAL) << "Invalid type for DenseDMatrix: " << TypeInfoToString(type);
48  }
49  return std::unique_ptr<DenseDMatrix>(nullptr);
50 }
51 
53 DenseDMatrix::GetElementType() const {
54  return element_type_;
55 }
56 
57 template<typename ElementType>
58 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
59  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col)
60  : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
61  num_col(num_col) {}
62 
63 template<typename ElementType>
64 size_t
65 DenseDMatrixImpl<ElementType>::GetNumRow() const {
66  return num_row;
67 }
68 
69 template<typename ElementType>
70 size_t
71 DenseDMatrixImpl<ElementType>::GetNumCol() const {
72  return num_col;
73 }
74 
75 template<typename ElementType>
76 size_t
77 DenseDMatrixImpl<ElementType>::GetNumElem() const {
78  return num_row * num_col;
79 }
80 
81 template<typename ElementType>
82 DMatrixType
83 DenseDMatrixImpl<ElementType>::GetType() const {
84  return DMatrixType::kDense;
85 }
86 
87 template <typename ElementType>
88 template <typename OutputType>
89 void
90 DenseDMatrixImpl<ElementType>::FillRow(size_t row_id, OutputType* out) const {
91  size_t out_idx = 0;
92  size_t in_idx = row_id * num_col;
93  while (out_idx < num_col) {
94  out[out_idx] = static_cast<OutputType>(data[in_idx]);
95  ++out_idx;
96  ++in_idx;
97  }
98 }
99 
100 template <typename ElementType>
101 template <typename OutputType>
102 void
103 DenseDMatrixImpl<ElementType>::ClearRow(size_t row_id, OutputType* out) const {
104  for (size_t i = 0; i < num_col; ++i) {
105  out[i] = std::numeric_limits<OutputType>::quiet_NaN();
106  }
107 }
108 
109 template<typename ElementType>
110 std::unique_ptr<CSRDMatrix>
111 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
112  std::vector<size_t> row_ptr, size_t num_row, size_t num_col) {
113  std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
114  std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
115  matrix->element_type_ = TypeToInfo<ElementType>();
116  return matrix;
117 }
118 
119 template<typename ElementType>
120 std::unique_ptr<CSRDMatrix>
121 CSRDMatrix::Create(const void* data, const uint32_t* col_ind,
122  const size_t* row_ptr, size_t num_row, size_t num_col) {
123  auto* data_ptr = static_cast<const ElementType*>(data);
124  const size_t num_elem = row_ptr[num_row];
125  return CSRDMatrix::Create(
126  std::vector<ElementType>(data_ptr, data_ptr + num_elem),
127  std::vector<uint32_t>(col_ind, col_ind + num_elem),
128  std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
129  num_row,
130  num_col);
131 }
132 
133 std::unique_ptr<CSRDMatrix>
134 CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr,
135  size_t num_row, size_t num_col) {
136  TREELITE_CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid";
137  switch (type) {
138  case TypeInfo::kFloat32:
139  return Create<float>(data, col_ind, row_ptr, num_row, num_col);
140  case TypeInfo::kFloat64:
141  return Create<double>(data, col_ind, row_ptr, num_row, num_col);
142  case TypeInfo::kInvalid:
143  case TypeInfo::kUInt32:
144  default:
145  TREELITE_LOG(FATAL) << "Invalid type for CSRDMatrix: " << TypeInfoToString(type);
146  }
147  return std::unique_ptr<CSRDMatrix>(nullptr);
148 }
149 
150 TypeInfo
151 CSRDMatrix::GetElementType() const {
152  return element_type_;
153 }
154 
155 template <typename ElementType>
156 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
157  std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
158  size_t num_row, size_t num_col)
159  : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
160  num_row(num_row), num_col(num_col)
161 {}
162 
163 template <typename ElementType>
164 size_t
165 CSRDMatrixImpl<ElementType>::GetNumRow() const {
166  return num_row;
167 }
168 
169 template <typename ElementType>
170 size_t
171 CSRDMatrixImpl<ElementType>::GetNumCol() const {
172  return num_col;
173 }
174 
175 template <typename ElementType>
176 size_t
177 CSRDMatrixImpl<ElementType>::GetNumElem() const {
178  return row_ptr.at(num_row);
179 }
180 
181 template <typename ElementType>
182 DMatrixType
183 CSRDMatrixImpl<ElementType>::GetType() const {
184  return DMatrixType::kSparseCSR;
185 }
186 
187 template <typename ElementType>
188 template <typename OutputType>
189 void
190 CSRDMatrixImpl<ElementType>::FillRow(size_t row_id, OutputType* out) const {
191  for (size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
192  out[col_ind[i]] = static_cast<OutputType>(data[i]);
193  }
194 }
195 
196 template <typename ElementType>
197 template <typename OutputType>
198 void
199 CSRDMatrixImpl<ElementType>::ClearRow(size_t row_id, OutputType* out) const {
200  for (size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
201  out[col_ind[i]] = std::numeric_limits<OutputType>::quiet_NaN();
202  }
203 }
204 
205 template class DenseDMatrixImpl<float>;
206 template class DenseDMatrixImpl<double>;
207 template class CSRDMatrixImpl<float>;
208 template class CSRDMatrixImpl<double>;
209 
210 template void CSRDMatrixImpl<float>::FillRow<float>(size_t, float*) const;
211 template void CSRDMatrixImpl<float>::FillRow<double>(size_t, double*) const;
212 template void CSRDMatrixImpl<float>::ClearRow<float>(size_t, float*) const;
213 template void CSRDMatrixImpl<float>::ClearRow<double>(size_t, double*) const;
214 template void CSRDMatrixImpl<double>::FillRow<float>(size_t, float*) const;
215 template void CSRDMatrixImpl<double>::FillRow<double>(size_t, double*) const;
216 template void CSRDMatrixImpl<double>::ClearRow<float>(size_t, float*) const;
217 template void CSRDMatrixImpl<double>::ClearRow<double>(size_t, double*) const;
218 template void DenseDMatrixImpl<float>::FillRow<float>(size_t, float*) const;
219 template void DenseDMatrixImpl<float>::FillRow<double>(size_t, double*) const;
220 template void DenseDMatrixImpl<float>::ClearRow<float>(size_t, float*) const;
221 template void DenseDMatrixImpl<float>::ClearRow<double>(size_t, double*) const;
222 template void DenseDMatrixImpl<double>::FillRow<float>(size_t, float*) const;
223 template void DenseDMatrixImpl<double>::FillRow<double>(size_t, double*) const;
224 template void DenseDMatrixImpl<double>::ClearRow<float>(size_t, float*) const;
225 template void DenseDMatrixImpl<double>::ClearRow<double>(size_t, double*) const;
226 
227 } // namespace treelite
Input data structure of Treelite.
logging facility for Treelite
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39