16 template<
typename ElementType>
17 std::unique_ptr<DenseDMatrix>
19 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col) {
20 std::unique_ptr<DenseDMatrix> matrix = std::make_unique<DenseDMatrixImpl<ElementType>>(
21 std::move(data), missing_value, num_row, num_col);
22 matrix->element_type_ = TypeToInfo<ElementType>();
26 template<
typename ElementType>
27 std::unique_ptr<DenseDMatrix>
28 DenseDMatrix::Create(
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
29 auto* data_ptr =
static_cast<const ElementType*
>(data);
30 const size_t num_elem = num_row * num_col;
31 return DenseDMatrix::Create(std::vector<ElementType>(data_ptr, data_ptr + num_elem),
32 *static_cast<const ElementType*>(missing_value), num_row, num_col);
35 std::unique_ptr<DenseDMatrix>
37 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col) {
38 TREELITE_CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
40 case TypeInfo::kFloat32:
41 return Create<float>(data, missing_value, num_row, num_col);
42 case TypeInfo::kFloat64:
43 return Create<double>(data, missing_value, num_row, num_col);
44 case TypeInfo::kInvalid:
45 case TypeInfo::kUInt32:
47 TREELITE_LOG(FATAL) <<
"Invalid type for DenseDMatrix: " <<
TypeInfoToString(type);
49 return std::unique_ptr<DenseDMatrix>(
nullptr);
// DenseDMatrix::GetElementType: const accessor for the matrix element type.
// NOTE(review): this listing drops the return-type line and the body. Both
// DenseDMatrix::Create(std::vector...) and CSRDMatrix::Create(std::vector...) assign
// element_type_ = TypeToInfo<ElementType>(), and CSRDMatrix::GetElementType returns
// element_type_. Presumably this one does the same — confirm against the real source.
53 DenseDMatrix::GetElementType()
const {
// DenseDMatrixImpl constructor: takes ownership of `data` by move and stores missing_value
// and num_row in the same-named members.
// NOTE(review): the mem-initializer list is cut off in this listing after num_row(num_row),
// so the remaining initializer(s) and the body are not visible. Presumably num_col(num_col)
// follows — confirm. No validation of data.size() against num_row * num_col is visible here.
57 template<
typename ElementType>
58 DenseDMatrixImpl<ElementType>::DenseDMatrixImpl(
59 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col)
60 : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row),
// DenseDMatrixImpl::GetNumRow: row-count accessor.
// NOTE(review): the return-type line and body are missing from this listing. It presumably
// returns the num_row member set by the constructor — confirm.
63 template<
typename ElementType>
65 DenseDMatrixImpl<ElementType>::GetNumRow()
const {
// DenseDMatrixImpl::GetNumCol: column-count accessor.
// NOTE(review): the return-type line and body are missing from this listing. It presumably
// returns the num_col member (the row stride used by FillRow) — confirm.
69 template<
typename ElementType>
71 DenseDMatrixImpl<ElementType>::GetNumCol()
const {
// DenseDMatrixImpl::GetNumElem: number of stored elements, computed as num_row * num_col.
// The product is not checked for overflow.
// NOTE(review): the return-type line and closing brace are not shown in this listing.
75 template<
typename ElementType>
77 DenseDMatrixImpl<ElementType>::GetNumElem()
const {
78 return num_row * num_col;
// DenseDMatrixImpl::GetType: storage-layout tag; always DMatrixType::kDense for this class.
// NOTE(review): the return-type line and closing brace are not shown in this listing.
81 template<
typename ElementType>
83 DenseDMatrixImpl<ElementType>::GetType()
const {
84 return DMatrixType::kDense;
// DenseDMatrixImpl::FillRow: copies row `row_id` of the row-major buffer into `out`.
// Reading starts at data[row_id * num_col] and writes out[out_idx] while out_idx < num_col.
// Each value is converted with static_cast<OutputType>. The return type is void, as shown
// by the explicit instantiations at the end of the file.
// NOTE(review): this listing omits the return-type line, the out_idx declaration and the
// loop's increment/closing lines. Presumably out_idx starts at 0 and both indices advance
// together — confirm. The visible lines copy raw values and never consult missing_value.
87 template <
typename ElementType>
88 template <
typename OutputType>
90 DenseDMatrixImpl<ElementType>::FillRow(
size_t row_id, OutputType* out)
const {
92 size_t in_idx = row_id * num_col;
93 while (out_idx < num_col) {
94 out[out_idx] =
static_cast<OutputType
>(data[in_idx]);
100 template <
typename ElementType>
101 template <
typename OutputType>
103 DenseDMatrixImpl<ElementType>::ClearRow(
size_t row_id, OutputType* out)
const {
104 for (
size_t i = 0; i < num_col; ++i) {
105 out[i] = std::numeric_limits<OutputType>::quiet_NaN();
109 template<
typename ElementType>
110 std::unique_ptr<CSRDMatrix>
111 CSRDMatrix::Create(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
112 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col) {
113 std::unique_ptr<CSRDMatrix> matrix = std::make_unique<CSRDMatrixImpl<ElementType>>(
114 std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col);
115 matrix->element_type_ = TypeToInfo<ElementType>();
119 template<
typename ElementType>
120 std::unique_ptr<CSRDMatrix>
121 CSRDMatrix::Create(
const void* data,
const uint32_t* col_ind,
122 const size_t* row_ptr,
size_t num_row,
size_t num_col) {
123 auto* data_ptr =
static_cast<const ElementType*
>(data);
124 const size_t num_elem = row_ptr[num_row];
125 return CSRDMatrix::Create(
126 std::vector<ElementType>(data_ptr, data_ptr + num_elem),
127 std::vector<uint32_t>(col_ind, col_ind + num_elem),
128 std::vector<size_t>(row_ptr, row_ptr + num_row + 1),
133 std::unique_ptr<CSRDMatrix>
134 CSRDMatrix::Create(TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
135 size_t num_row,
size_t num_col) {
136 TREELITE_CHECK(type != TypeInfo::kInvalid) <<
"ElementType cannot be invalid";
138 case TypeInfo::kFloat32:
139 return Create<float>(data, col_ind, row_ptr, num_row, num_col);
140 case TypeInfo::kFloat64:
141 return Create<double>(data, col_ind, row_ptr, num_row, num_col);
142 case TypeInfo::kInvalid:
143 case TypeInfo::kUInt32:
145 TREELITE_LOG(FATAL) <<
"Invalid type for CSRDMatrix: " <<
TypeInfoToString(type);
147 return std::unique_ptr<CSRDMatrix>(
nullptr);
// CSRDMatrix::GetElementType: returns element_type_.
// CSRDMatrix::Create(std::vector...) sets that member via TypeToInfo<ElementType>().
// NOTE(review): the return-type line and closing brace are not shown in this listing.
151 CSRDMatrix::GetElementType()
const {
152 return element_type_;
// CSRDMatrixImpl constructor: takes ownership of the CSR arrays (data, col_ind, row_ptr) by
// move and records the shape (num_row, num_col).
// No validation of the arrays is visible in these lines.
// NOTE(review): the constructor body line after the initializer list is not shown in this
// listing.
155 template <
typename ElementType>
156 CSRDMatrixImpl<ElementType>::CSRDMatrixImpl(
157 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
158 size_t num_row,
size_t num_col)
159 : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)),
160 num_row(num_row), num_col(num_col)
// CSRDMatrixImpl::GetNumRow: row-count accessor.
// NOTE(review): the return-type line and body are missing from this listing. It presumably
// returns the num_row member set by the constructor — confirm.
163 template <
typename ElementType>
165 CSRDMatrixImpl<ElementType>::GetNumRow()
const {
// CSRDMatrixImpl::GetNumCol: column-count accessor.
// NOTE(review): the return-type line and body are missing from this listing. It presumably
// returns the num_col member set by the constructor — confirm.
169 template <
typename ElementType>
171 CSRDMatrixImpl<ElementType>::GetNumCol()
const {
// CSRDMatrixImpl::GetNumElem: number of stored entries, read from the final row offset
// row_ptr[num_row].
// .at() throws std::out_of_range if row_ptr has fewer than num_row + 1 entries.
// NOTE(review): the return-type line and closing brace are not shown in this listing.
175 template <
typename ElementType>
177 CSRDMatrixImpl<ElementType>::GetNumElem()
const {
178 return row_ptr.at(num_row);
// CSRDMatrixImpl::GetType: storage-layout tag; always DMatrixType::kSparseCSR for this class.
// NOTE(review): the return-type line and closing brace are not shown in this listing.
181 template <
typename ElementType>
183 CSRDMatrixImpl<ElementType>::GetType()
const {
184 return DMatrixType::kSparseCSR;
187 template <
typename ElementType>
188 template <
typename OutputType>
190 CSRDMatrixImpl<ElementType>::FillRow(
size_t row_id, OutputType* out)
const {
191 for (
size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
192 out[col_ind[i]] =
static_cast<OutputType
>(data[i]);
196 template <
typename ElementType>
197 template <
typename OutputType>
199 CSRDMatrixImpl<ElementType>::ClearRow(
size_t row_id, OutputType* out)
const {
200 for (
size_t i = row_ptr[row_id]; i < row_ptr[row_id + 1]; ++i) {
201 out[col_ind[i]] = std::numeric_limits<OutputType>::quiet_NaN();
205 template class DenseDMatrixImpl<float>;
206 template class DenseDMatrixImpl<double>;
207 template class CSRDMatrixImpl<float>;
208 template class CSRDMatrixImpl<double>;
210 template void CSRDMatrixImpl<float>::FillRow<
float>(size_t,
float*)
const;
211 template void CSRDMatrixImpl<float>::FillRow<
double>(size_t,
double*)
const;
212 template void CSRDMatrixImpl<float>::ClearRow<
float>(size_t,
float*)
const;
213 template void CSRDMatrixImpl<float>::ClearRow<
double>(size_t,
double*)
const;
214 template void CSRDMatrixImpl<double>::FillRow<
float>(size_t,
float*)
const;
215 template void CSRDMatrixImpl<double>::FillRow<
double>(size_t,
double*)
const;
216 template void CSRDMatrixImpl<double>::ClearRow<
float>(size_t,
float*)
const;
217 template void CSRDMatrixImpl<double>::ClearRow<
double>(size_t,
double*)
const;
218 template void DenseDMatrixImpl<float>::FillRow<
float>(size_t,
float*)
const;
219 template void DenseDMatrixImpl<float>::FillRow<
double>(size_t,
double*)
const;
220 template void DenseDMatrixImpl<float>::ClearRow<
float>(size_t,
float*)
const;
221 template void DenseDMatrixImpl<float>::ClearRow<
double>(size_t,
double*)
const;
222 template void DenseDMatrixImpl<double>::FillRow<
float>(size_t,
float*)
const;
223 template void DenseDMatrixImpl<double>::FillRow<
double>(size_t,
double*)
const;
224 template void DenseDMatrixImpl<double>::ClearRow<
float>(size_t,
float*)
const;
225 template void DenseDMatrixImpl<double>::ClearRow<
double>(size_t,
double*)
const;
Input data structure of Treelite.
Logging facility for Treelite.
TypeInfo
Types used by thresholds and leaf outputs.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.