7 #ifndef TREELITE_DATA_H_ 8 #define TREELITE_DATA_H_ 10 #include <dmlc/data.h> 13 #include <type_traits> 18 enum class DMatrixType : uint8_t {
// Pure virtual accessor (`= 0`): takes no arguments, is const, and returns a size_t.
// Presumably reports the number of rows in the matrix (name-based; confirm in implementers).
25 virtual size_t GetNumRow()
const = 0;
// Pure virtual accessor (`= 0`): takes no arguments, is const, and returns a size_t.
// Presumably reports the number of columns, i.e. features (name-based; confirm in implementers).
26 virtual size_t GetNumCol()
const = 0;
// Pure virtual accessor (`= 0`): takes no arguments, is const, and returns a size_t.
// Presumably reports how many elements the matrix holds; whether that counts every cell or
// only stored entries is not visible here -- TODO confirm per implementation.
27 virtual size_t GetNumElem()
const = 0;
// Pure virtual accessor (`= 0`) returning a DMatrixType, the uint8_t-backed enum class
// declared near the top of this header. Presumably identifies the concrete matrix layout.
28 virtual DMatrixType GetType()
const = 0;
// Pure virtual accessor (`= 0`) returning a TypeInfo value.
// Presumably the type tag of the values stored in the matrix (name-based; confirm).
29 virtual TypeInfo GetElementType()
const = 0;
// Static factory, templated on ElementType, returning an owning std::unique_ptr<DenseDMatrix>.
// Parameters (meanings taken from this file's field descriptions for the same names):
//   data          - feature values; taken by value, so callers may std::move a vector in
//   missing_value - value representing a missing entry (usually NaN)
//   num_row       - number of rows
//   num_col       - number of columns (i.e. number of features used)
// NOTE(review): the expected element ordering of `data` is not visible here -- confirm.
38 template<
typename ElementType>
39 static std::unique_ptr<DenseDMatrix> Create(
40 std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
size_t num_col);
// Static factory, templated on ElementType, that accepts untyped (const void*) inputs.
// ElementType does not appear in any parameter type, so it cannot be deduced and must be
// spelled out explicitly at the call site.
// Parameters:
//   data          - untyped pointer to the feature values; presumably an array of
//                   ElementType covering num_row x num_col entries -- TODO confirm
//   missing_value - untyped pointer to the value representing a missing entry;
//                   presumably points at a single ElementType -- TODO confirm
//   num_row       - number of rows
//   num_col       - number of columns (i.e. number of features used)
// Returns an owning std::unique_ptr<DenseDMatrix>.
41 template<
typename ElementType>
42 static std::unique_ptr<DenseDMatrix> Create(
43 const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
// Non-template static factory: the element type arrives as a run-time TypeInfo value
// instead of a template argument.
// Parameters:
//   type          - TypeInfo tag; presumably describes what `data` and `missing_value`
//                   point to -- TODO confirm
//   data          - untyped pointer to the feature values
//   missing_value - untyped pointer to the value representing a missing entry
//   num_row       - number of rows
//   num_col       - number of columns (i.e. number of features used)
// Returns an owning std::unique_ptr<DenseDMatrix>.
44 static std::unique_ptr<DenseDMatrix> Create(
45 TypeInfo type,
const void* data,
const void* missing_value,
size_t num_row,
size_t num_col);
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
46 size_t GetNumRow()
const override = 0;
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
47 size_t GetNumCol()
const override = 0;
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
48 size_t GetNumElem()
const override = 0;
// Re-declares the inherited DMatrixType accessor with `override`, still pure (`= 0`);
// a further-derived class must supply the definition.
49 DMatrixType GetType()
const override = 0;
// Unlike the neighbouring accessors this override is NOT pure (no `= 0`), so a definition
// is expected at this class level; that definition is not part of this header excerpt.
50 TypeInfo GetElementType()
const override;
53 template<
typename ElementType>
// Feature values, stored as a vector of the enclosing template's ElementType.
57 std::vector<ElementType>
data;
66 DenseDMatrixImpl(std::vector<ElementType> data, ElementType missing_value,
size_t num_row,
// Concrete (non-pure) override; the definition lives outside this excerpt.
// Presumably returns the stored num_row (number of rows) -- confirm in the definition.
74 size_t GetNumRow() const override;
// Concrete (non-pure) override; the definition lives outside this excerpt.
// Presumably returns the stored num_col (number of columns) -- confirm in the definition.
75 size_t GetNumCol() const override;
// Concrete (non-pure) override; the definition lives outside this excerpt.
// NOTE(review): how the element count is derived for this layout is not shown -- confirm.
76 size_t GetNumElem() const override;
// Concrete (non-pure) override returning a DMatrixType; the definition and the enumerator
// it yields are outside this excerpt.
77 DMatrixType GetType() const override;
// Const member template parameterised on the destination element type OutputType.
// Parameters:
//   row_id - index of the row to read
//   out    - caller-provided destination buffer of OutputType
// Presumably copies the values of row `row_id` into `out`; the required length of `out`
// and the handling of missing values are not visible here -- TODO confirm.
79 template <typename OutputType>
80 void FillRow(
size_t row_id, OutputType* out) const;
// Const member template with the same signature shape as FillRow.
// Parameters:
//   row_id - index of the row concerned
//   out    - caller-provided buffer of OutputType
// Presumably resets the entries of `out` that a FillRow call for `row_id` would have
// written; the value used for resetting is not visible here -- TODO confirm.
81 template <typename OutputType>
82 void ClearRow(
size_t row_id, OutputType* out) const;
// Static factory, templated on ElementType, returning an owning std::unique_ptr<CSRDMatrix>.
// All three vectors are taken by value, so callers may std::move them in.
// Parameters (meanings taken from this file's field descriptions for the same names):
//   data    - feature values
//   col_ind - feature indices; col_ind[i] is the feature index associated with data[i]
//   row_ptr - pointer to row headers; length is num_row + 1
//   num_row - number of rows
//   num_col - number of columns (i.e. number of features used)
91 template<
typename ElementType>
92 static std::unique_ptr<CSRDMatrix> Create(
93 std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
94 size_t num_row,
size_t num_col);
95 template<
typename ElementType>
96 static std::unique_ptr<CSRDMatrix> Create(
97 const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
size_t num_row,
// Non-template static factory: the element type arrives as a run-time TypeInfo value and
// the arrays are passed as raw pointers.
// Parameters:
//   type    - TypeInfo tag; presumably describes what `data` points to -- TODO confirm
//   data    - untyped pointer to the feature values
//   col_ind - feature indices; col_ind[i] is the feature index associated with data[i]
//   row_ptr - pointer to row headers; num_row + 1 entries per this file's field description
//   num_row - number of rows
//   num_col - number of columns (i.e. number of features used)
// Returns an owning std::unique_ptr<CSRDMatrix>.
// NOTE(review): whether the pointed-to arrays are copied or merely referenced is not
// visible in this declaration -- confirm before freeing caller buffers.
99 static std::unique_ptr<CSRDMatrix> Create(
100 TypeInfo type,
const void* data,
const uint32_t* col_ind,
const size_t* row_ptr,
101 size_t num_row,
size_t num_col);
// Static factory taking C strings plus two ints and returning an owning
// std::unique_ptr<CSRDMatrix>. Presumably loads the matrix from a file -- the accepted
// values of each argument are not visible here, so the notes below are name-based guesses:
//   filename  - presumably the path of the input file
//   format    - presumably the name of the file format to parse
//   data_type - presumably the element type to load, given as a string
//   nthread   - presumably the number of threads used while loading
//   verbose   - presumably a flag/level controlling diagnostic output
// TODO confirm all of the above against the definition.
102 static std::unique_ptr<CSRDMatrix> Create(
103 const char* filename,
const char* format,
const char* data_type,
int nthread,
int verbose);
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
104 size_t GetNumRow()
const override = 0;
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
105 size_t GetNumCol()
const override = 0;
// Re-declares the inherited accessor with `override` while keeping it pure (`= 0`),
// so the class declaring this stays abstract and a further-derived class must define it.
106 size_t GetNumElem()
const override = 0;
// Re-declares the inherited DMatrixType accessor with `override`, still pure (`= 0`);
// a further-derived class must supply the definition.
107 DMatrixType GetType()
const override = 0;
// Unlike the neighbouring accessors this override is NOT pure (no `= 0`), so a definition
// is expected at this class level; that definition is not part of this header excerpt.
108 TypeInfo GetElementType()
const override;
111 template<
typename ElementType>
// Constructor taking the three CSR arrays by value plus the matrix dimensions.
// Parameters (meanings taken from this file's field descriptions for the same names):
//   data    - feature values
//   col_ind - feature indices; col_ind[i] is the feature index associated with data[i]
//   row_ptr - pointer to row headers; length is num_row + 1
//   num_row - number of rows
//   num_col - number of columns (i.e. number of features used)
// NOTE(review): whether the arguments are validated (e.g. row_ptr length) is not visible
// in this declaration -- confirm in the definition.
126 CSRDMatrixImpl(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
127 std::vector<size_t> row_ptr,
size_t num_row,
size_t num_col);
// Concrete (non-pure) override; the definition lives outside this excerpt.
// Presumably returns the stored num_row (number of rows) -- confirm in the definition.
133 size_t GetNumRow() const override;
// Concrete (non-pure) override; the definition lives outside this excerpt.
// Presumably returns the stored num_col (number of columns) -- confirm in the definition.
134 size_t GetNumCol() const override;
// Concrete (non-pure) override; the definition lives outside this excerpt.
// NOTE(review): for a sparse layout this may count only stored entries rather than
// num_row * num_col -- not visible here, confirm in the definition.
135 size_t GetNumElem() const override;
// Concrete (non-pure) override returning a DMatrixType; the definition and the enumerator
// it yields are outside this excerpt.
136 DMatrixType GetType() const override;
// Const member template parameterised on the destination element type OutputType.
// Parameters:
//   row_id - index of the row to read
//   out    - caller-provided destination buffer of OutputType
// Presumably scatters the stored entries of row `row_id` into `out` at their feature
// indices; the required length of `out` is not visible here -- TODO confirm.
138 template <typename OutputType>
139 void FillRow(
size_t row_id, OutputType* out) const;
// Const member template with the same signature shape as FillRow.
// Parameters:
//   row_id - index of the row concerned
//   out    - caller-provided buffer of OutputType
// Presumably resets the entries of `out` that a FillRow call for `row_id` would have
// written; the value used for resetting is not visible here -- TODO confirm.
140 template <typename OutputType>
141 void ClearRow(
size_t row_id, OutputType* out) const;
148 #endif // TREELITE_DATA_H_
size_t num_col
number of columns (i.e. # of features used)
ElementType missing_value
value representing the missing value (usually NaN)
std::vector< size_t > row_ptr
pointer to row headers; length is [num_row] + 1.
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
std::vector< ElementType > data
feature values
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< ElementType > data
feature values
Defines TypeInfo class and utilities.
size_t num_row
number of rows
size_t num_row
number of rows
size_t num_col
number of columns (i.e. # of features used)