7 #ifndef TREELITE_DATA_H_ 8 #define TREELITE_DATA_H_ 10 #include <dmlc/data.h> 13 #include <type_traits> 18 enum class DMatrixType : uint8_t {
25 virtual size_t GetNumRow()
const = 0;
26 virtual size_t GetNumCol()
const = 0;
27 virtual size_t GetNumElem()
const = 0;
28 virtual DMatrixType GetType()
const = 0;
29 virtual TypeInfo GetElementType()
const = 0;
38 template<
typename ElementType>
39 static std::unique_ptr<DenseDMatrix> Create(
40 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col);
41 template<
typename ElementType>
42 static std::unique_ptr<DenseDMatrix> Create(
43 const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
44 static std::unique_ptr<DenseDMatrix> Create(
45 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
46 size_t GetNumRow()
const override = 0;
47 size_t GetNumCol()
const override = 0;
48 size_t GetNumElem()
const override = 0;
49 DMatrixType GetType()
const override = 0;
50 TypeInfo GetElementType()
const override;
53 template<
typename ElementType>
57 std::vector<ElementType>
data;
66 DenseDMatrixImpl(std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
74 size_t GetNumRow() const override;
75 size_t GetNumCol() const override;
76 size_t GetNumElem() const override;
77 DMatrixType GetType() const override;
86 template<
typename ElementType>
87 static std::unique_ptr<CSRDMatrix> Create(
88 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
89 size_t num_row,
size_t num_col);
90 template<
typename ElementType>
91 static std::unique_ptr<CSRDMatrix> Create(
92 const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
size_t num_row,
94 static std::unique_ptr<CSRDMatrix> Create(
95 TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
96 size_t num_row,
size_t num_col);
97 static std::unique_ptr<CSRDMatrix> Create(
98 const char* filename,
const char* format,
const char* data_type,
int nthread,
int verbose);
99 size_t GetNumRow()
const override = 0;
100 size_t GetNumCol()
const override = 0;
101 size_t GetNumElem()
const override = 0;
102 DMatrixType GetType()
const override = 0;
103 TypeInfo GetElementType()
const override;
106 template<
typename ElementType>
121 CSRDMatrixImpl(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
122 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col);
128 size_t GetNumRow() const override;
129 size_t GetNumCol() const override;
130 size_t GetNumElem() const override;
131 DMatrixType GetType() const override;
138 #endif // TREELITE_DATA_H_ size_t num_col
number of columns (i.e. the number of features used)
ElementType missing_value
sentinel value that marks a missing entry (usually NaN)
std::vector< size_t > row_ptr
row pointers: row_ptr[i] is the offset in data/col_ind where row i begins; length is num_row + 1.
std::vector< uint32_t > col_ind
feature indices: col_ind[i] is the feature (column) index of the value data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
size_t num_row
number of rows
size_t num_row
number of rows
size_t num_col
number of columns (i.e. the number of features used)