7 #ifndef TREELITE_DATA_H_ 8 #define TREELITE_DATA_H_ 12 #include <type_traits> 17 enum class DMatrixType : uint8_t {
// Abstract query interface. Every accessor below is a pure virtual const
// member (`= 0`), so a concrete subclass has to implement all five.
// NOTE(review): the enclosing class header is not visible in this excerpt.
//
// Row count; presumably reports the `num_row` value ("number of rows").
24 virtual size_t GetNumRow()
const = 0;
// Column count; presumably reports the `num_col` value ("number of columns,
// i.e. # of features used").
25 virtual size_t GetNumCol()
const = 0;
// Element count. NOTE(review): whether this means rows * cols or only the
// explicitly stored entries is not shown here -- confirm per subclass.
26 virtual size_t GetNumElem()
const = 0;
// Storage-layout tag of the matrix, expressed as a DMatrixType enumerator.
27 virtual DMatrixType GetType()
const = 0;
// Runtime tag (TypeInfo) for the element type held by the matrix.
28 virtual TypeInfo GetElementType()
const = 0;
// Factory: builds a dense matrix from an owning vector of values and returns
// it as std::unique_ptr<DenseDMatrix>.
//   data          - feature values; taken by value, so callers may std::move
//                   their buffer in (whether the implementation then moves or
//                   copies it is not visible here)
//   missing_value - value representing a missing entry (usually NaN)
//   num_row       - number of rows
//   num_col       - number of columns (i.e. # of features used)
// ElementType is deducible from `data` / `missing_value`.
// NOTE(review): presumably data.size() must equal num_row * num_col -- confirm
// against the definition.
37 template<
typename ElementType>
38 static std::unique_ptr<DenseDMatrix> Create(
39 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col);
// Factory: builds a dense matrix from untyped raw pointers.
// Both `data` and `missing_value` are `const void*`, so ElementType cannot be
// deduced from the arguments; callers must name it explicitly, e.g.
// Create<float>(ptr, &missing, rows, cols).
//   data          - pointer to the feature values (presumably num_row * num_col
//                   elements of ElementType -- confirm)
//   missing_value - pointer to a single value marking missing entries
//                   (presumably read as ElementType -- confirm)
//   num_row       - number of rows
//   num_col       - number of columns
// NOTE(review): whether the input buffer is copied or merely referenced is not
// visible in this declaration.
40 template<
typename ElementType>
41 static std::unique_ptr<DenseDMatrix> Create(
42 const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
// Factory: non-template overload where the element type is chosen at run time
// through the `type` tag instead of a template argument.
//   type          - TypeInfo tag naming the element type behind the pointers
//   data          - untyped pointer to the feature values
//   missing_value - untyped pointer to the value marking missing entries
//   num_row       - number of rows
//   num_col       - number of columns
// NOTE(review): presumably dispatches to the templated overload matching
// `type`; the definition is not part of this excerpt.
43 static std::unique_ptr<DenseDMatrix> Create(
44 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
// Accessors re-declared with `override = 0`: they stay pure virtual at this
// level, so the enclosing class (presumably DenseDMatrix, given the Create()
// return type above) is still abstract and leaves them to a further subclass.
45 size_t GetNumRow()
const override = 0;
46 size_t GetNumCol()
const override = 0;
47 size_t GetNumElem()
const override = 0;
48 DMatrixType GetType()
const override = 0;
// The only accessor declared non-pure at this level; its definition lives
// outside this excerpt.
49 TypeInfo GetElementType()
const override;
52 template<
typename ElementType>
// Feature values, stored as a contiguous vector of ElementType.
// NOTE(review): memory order (row-major vs. column-major) is not visible here.
56 std::vector<ElementType>
data;
65 DenseDMatrixImpl(std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
// Concrete overrides of the accessor interface: declared without `= 0`, so
// this class supplies the implementations (defined outside this excerpt).
// Presumably these report num_row, num_col, the element count and the layout
// tag for the dense case -- confirm against the definitions.
73 size_t GetNumRow() const override;
74 size_t GetNumCol() const override;
75 size_t GetNumElem() const override;
76 DMatrixType GetType() const override;
// Const member template operating on row `row_id` through caller-supplied
// storage `out`, whose element type OutputType is deduced from the pointer.
// NOTE(review): presumably writes the feature values of that row into `out`;
// the required length of `out` and the handling of missing values are not
// visible in this declaration -- confirm against the definition.
78 template <typename OutputType>
79 void FillRow(
size_t row_id, OutputType* out) const;
// Counterpart to FillRow with an identical signature (row index plus
// caller-supplied buffer).
// NOTE(review): presumably resets the entries of `out` that FillRow wrote for
// `row_id`; the value they are reset to is not visible here -- confirm.
80 template <typename OutputType>
81 void ClearRow(
size_t row_id, OutputType* out) const;
// Factory: builds a CSR (compressed sparse row) matrix from owning vectors and
// returns it as std::unique_ptr<CSRDMatrix>.
//   data    - feature values
//   col_ind - feature indices; col_ind[i] is the feature index associated
//             with data[i]
//   row_ptr - pointer to row headers; length is num_row + 1
//   num_row - number of rows
//   num_col - number of columns (i.e. # of features used)
// All three vectors are taken by value, so callers may std::move them in.
// Unlike the dense factory, no missing_value parameter is accepted.
90 template<
typename ElementType>
91 static std::unique_ptr<CSRDMatrix> Create(
92 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
93 size_t num_row,
size_t num_col);
94 template<
typename ElementType>
95 static std::unique_ptr<CSRDMatrix> Create(
96 const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
size_t num_row,
// Factory: non-template CSR overload where the element type of `data` is
// selected at run time via the `type` tag.
//   type    - TypeInfo tag naming the element type behind `data`
//   data    - untyped pointer to the feature values
//   col_ind - feature index of each entry in `data`
//   row_ptr - row header offsets (presumably num_row + 1 entries -- confirm)
//   num_row - number of rows
//   num_col - number of columns
// NOTE(review): array lengths cannot be checked from raw pointers, and whether
// the inputs are copied or referenced is not visible in this declaration.
98 static std::unique_ptr<CSRDMatrix> Create(
99 TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
100 size_t num_row,
size_t num_col);
// Accessors re-declared with `override = 0`: still pure virtual here, so the
// enclosing class (presumably CSRDMatrix, given the Create() return type
// above) remains abstract and defers them to a further subclass.
101 size_t GetNumRow()
const override = 0;
102 size_t GetNumCol()
const override = 0;
103 size_t GetNumElem()
const override = 0;
104 DMatrixType GetType()
const override = 0;
// The only accessor declared non-pure at this level; its definition lives
// outside this excerpt.
105 TypeInfo GetElementType()
const override;
108 template<
typename ElementType>
// Constructor: receives the three CSR arrays plus the matrix dimensions.
//   data    - feature values
//   col_ind - feature indices; col_ind[i] goes with data[i]
//   row_ptr - pointer to row headers; length is num_row + 1
//   num_row - number of rows
//   num_col - number of columns (i.e. # of features used)
// The vectors are by-value parameters; presumably they are moved into the
// members of the same names -- the definition is not part of this excerpt.
123 CSRDMatrixImpl(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
124 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col);
// Concrete overrides of the accessor interface: declared without `= 0`, so
// this class supplies the implementations (defined outside this excerpt).
// NOTE(review): for the sparse case GetNumElem presumably counts only the
// stored entries rather than num_row * num_col -- confirm.
130 size_t GetNumRow() const override;
131 size_t GetNumCol() const override;
132 size_t GetNumElem() const override;
133 DMatrixType GetType() const override;
// Const member template operating on row `row_id` through caller-supplied
// storage `out`, whose element type OutputType is deduced from the pointer.
// NOTE(review): presumably scatters the stored entries of that row into `out`
// at their col_ind positions; the required length of `out` is not visible in
// this declaration -- confirm against the definition.
135 template <typename OutputType>
136 void FillRow(
size_t row_id, OutputType* out) const;
// Counterpart to FillRow with an identical signature (row index plus
// caller-supplied buffer).
// NOTE(review): presumably resets only the positions FillRow touched for
// `row_id`; the value they are reset to is not visible here -- confirm.
137 template <typename OutputType>
138 void ClearRow(
size_t row_id, OutputType* out) const;
145 #endif // TREELITE_DATA_H_ size_t num_col
number of columns (i.e. # of features used)
ElementType missing_value
value representing the missing value (usually NaN)
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
size_t num_row
number of rows
size_t num_row
number of rows
size_t num_col
number of columns (i.e. # of features used)