Treelite
data.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_DATA_H_
8 #define TREELITE_DATA_H_
9 
10 #include <treelite/typeinfo.h>
11 #include <vector>
12 #include <type_traits>
13 #include <memory>
14 
15 namespace treelite {
16 
/*!
 * \brief Kind of storage backing a DMatrix (returned by DMatrix::GetType()).
 *
 * The underlying type is fixed to a single byte (uint8_t).
 */
enum class DMatrixType : uint8_t {
  kDense = 0,      /*!< dense storage (cf. DenseDMatrix) */
  kSparseCSR = 1,  /*!< sparse storage in compressed sparse row form (cf. CSRDMatrix) */
};
21 
22 class DMatrix {
23  public:
24  virtual size_t GetNumRow() const = 0;
25  virtual size_t GetNumCol() const = 0;
26  virtual size_t GetNumElem() const = 0;
27  virtual DMatrixType GetType() const = 0;
28  virtual TypeInfo GetElementType() const = 0;
29  DMatrix() = default;
30  virtual ~DMatrix() = default;
31 };
32 
33 class DenseDMatrix : public DMatrix {
34  private:
35  TypeInfo element_type_;
36  public:
37  template<typename ElementType>
38  static std::unique_ptr<DenseDMatrix> Create(
39  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col);
40  template<typename ElementType>
41  static std::unique_ptr<DenseDMatrix> Create(
42  const void* data, const void* missing_value, size_t num_row, size_t num_col);
43  static std::unique_ptr<DenseDMatrix> Create(
44  TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col);
45  size_t GetNumRow() const override = 0;
46  size_t GetNumCol() const override = 0;
47  size_t GetNumElem() const override = 0;
48  DMatrixType GetType() const override = 0;
49  TypeInfo GetElementType() const override;
50 };
51 
52 template<typename ElementType>
54  public:
56  std::vector<ElementType> data;
58  ElementType missing_value;
60  size_t num_row;
62  size_t num_col;
63 
64  DenseDMatrixImpl() = delete;
65  DenseDMatrixImpl(std::vector<ElementType> data, ElementType missing_value, size_t num_row,
66  size_t num_col);
67  ~DenseDMatrixImpl() = default;
68  DenseDMatrixImpl(const DenseDMatrixImpl&) = default;
69  DenseDMatrixImpl(DenseDMatrixImpl&&) noexcept = default;
70  DenseDMatrixImpl& operator=(const DenseDMatrixImpl&) = default;
71  DenseDMatrixImpl& operator=(DenseDMatrixImpl&&) noexcept = default;
72 
73  size_t GetNumRow() const override;
74  size_t GetNumCol() const override;
75  size_t GetNumElem() const override;
76  DMatrixType GetType() const override;
77 
78  template <typename OutputType>
79  void FillRow(size_t row_id, OutputType* out) const;
80  template <typename OutputType>
81  void ClearRow(size_t row_id, OutputType* out) const;
82 
83  friend class DenseDMatrix;
84 };
85 
86 class CSRDMatrix : public DMatrix {
87  private:
88  TypeInfo element_type_;
89  public:
90  template<typename ElementType>
91  static std::unique_ptr<CSRDMatrix> Create(
92  std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
93  size_t num_row, size_t num_col);
94  template<typename ElementType>
95  static std::unique_ptr<CSRDMatrix> Create(
96  const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row,
97  size_t num_col);
98  static std::unique_ptr<CSRDMatrix> Create(
99  TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr,
100  size_t num_row, size_t num_col);
101  size_t GetNumRow() const override = 0;
102  size_t GetNumCol() const override = 0;
103  size_t GetNumElem() const override = 0;
104  DMatrixType GetType() const override = 0;
105  TypeInfo GetElementType() const override;
106 };
107 
108 template<typename ElementType>
109 class CSRDMatrixImpl : public CSRDMatrix {
110  public:
112  std::vector<ElementType> data;
114  std::vector<uint32_t> col_ind;
116  std::vector<size_t> row_ptr;
118  size_t num_row;
120  size_t num_col;
121 
122  CSRDMatrixImpl() = delete;
123  CSRDMatrixImpl(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
124  std::vector<size_t> row_ptr, size_t num_row, size_t num_col);
125  CSRDMatrixImpl(const CSRDMatrixImpl&) = default;
126  CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default;
127  CSRDMatrixImpl& operator=(const CSRDMatrixImpl&) = default;
128  CSRDMatrixImpl& operator=(CSRDMatrixImpl&&) noexcept = default;
129 
130  size_t GetNumRow() const override;
131  size_t GetNumCol() const override;
132  size_t GetNumElem() const override;
133  DMatrixType GetType() const override;
134 
135  template <typename OutputType>
136  void FillRow(size_t row_id, OutputType* out) const;
137  template <typename OutputType>
138  void ClearRow(size_t row_id, OutputType* out) const;
139 
140  friend class CSRDMatrix;
141 };
142 
143 } // namespace treelite
144 
145 #endif // TREELITE_DATA_H_
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:120
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:58
std::vector< size_t > row_ptr
pointer to row headers; length is num_row + 1.
Definition: data.h:116
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:114
std::vector< ElementType > data
feature values
Definition: data.h:56
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:112
Defines TypeInfo class and utilities.
size_t num_row
number of rows
Definition: data.h:60
size_t num_row
number of rows
Definition: data.h:118
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:62