Treelite
data.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_DATA_H_
8 #define TREELITE_DATA_H_
9 
10 #include <dmlc/data.h>
11 #include <treelite/typeinfo.h>
12 #include <vector>
13 #include <type_traits>
14 #include <memory>
15 
16 namespace treelite {
17 
/*!
 * \brief Tag identifying the storage layout of a DMatrix.
 *
 * A matrix reports its layout through DMatrix::GetType().
 */
enum class DMatrixType : uint8_t {
  /*! \brief dense storage */
  kDense = 0,
  /*! \brief sparse storage in Compressed Sparse Row (CSR) format */
  kSparseCSR = 1,
};
22 
23 class DMatrix {
24  public:
25  virtual size_t GetNumRow() const = 0;
26  virtual size_t GetNumCol() const = 0;
27  virtual size_t GetNumElem() const = 0;
28  virtual DMatrixType GetType() const = 0;
29  virtual TypeInfo GetElementType() const = 0;
30  DMatrix() = default;
31  virtual ~DMatrix() = default;
32 };
33 
34 class DenseDMatrix : public DMatrix {
35  private:
36  TypeInfo element_type_;
37  public:
38  template<typename ElementType>
39  static std::unique_ptr<DenseDMatrix> Create(
40  std::vector<ElementType> data, ElementType missing_value, size_t num_row, size_t num_col);
41  template<typename ElementType>
42  static std::unique_ptr<DenseDMatrix> Create(
43  const void* data, const void* missing_value, size_t num_row, size_t num_col);
44  static std::unique_ptr<DenseDMatrix> Create(
45  TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col);
46  size_t GetNumRow() const override = 0;
47  size_t GetNumCol() const override = 0;
48  size_t GetNumElem() const override = 0;
49  DMatrixType GetType() const override = 0;
50  TypeInfo GetElementType() const override;
51 };
52 
53 template<typename ElementType>
55  public:
57  std::vector<ElementType> data;
59  ElementType missing_value;
61  size_t num_row;
63  size_t num_col;
64 
65  DenseDMatrixImpl() = delete;
66  DenseDMatrixImpl(std::vector<ElementType> data, ElementType missing_value, size_t num_row,
67  size_t num_col);
68  ~DenseDMatrixImpl() = default;
69  DenseDMatrixImpl(const DenseDMatrixImpl&) = default;
70  DenseDMatrixImpl(DenseDMatrixImpl&&) noexcept = default;
71  DenseDMatrixImpl& operator=(const DenseDMatrixImpl&) = default;
72  DenseDMatrixImpl& operator=(DenseDMatrixImpl&&) noexcept = default;
73 
74  size_t GetNumRow() const override;
75  size_t GetNumCol() const override;
76  size_t GetNumElem() const override;
77  DMatrixType GetType() const override;
78 
79  template <typename OutputType>
80  void FillRow(size_t row_id, OutputType* out) const;
81  template <typename OutputType>
82  void ClearRow(size_t row_id, OutputType* out) const;
83 
84  friend class DenseDMatrix;
85 };
86 
87 class CSRDMatrix : public DMatrix {
88  private:
89  TypeInfo element_type_;
90  public:
91  template<typename ElementType>
92  static std::unique_ptr<CSRDMatrix> Create(
93  std::vector<ElementType> data, std::vector<uint32_t> col_ind, std::vector<size_t> row_ptr,
94  size_t num_row, size_t num_col);
95  template<typename ElementType>
96  static std::unique_ptr<CSRDMatrix> Create(
97  const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row,
98  size_t num_col);
99  static std::unique_ptr<CSRDMatrix> Create(
100  TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr,
101  size_t num_row, size_t num_col);
102  static std::unique_ptr<CSRDMatrix> Create(
103  const char* filename, const char* format, const char* data_type, int nthread, int verbose);
104  size_t GetNumRow() const override = 0;
105  size_t GetNumCol() const override = 0;
106  size_t GetNumElem() const override = 0;
107  DMatrixType GetType() const override = 0;
108  TypeInfo GetElementType() const override;
109 };
110 
111 template<typename ElementType>
112 class CSRDMatrixImpl : public CSRDMatrix {
113  public:
115  std::vector<ElementType> data;
117  std::vector<uint32_t> col_ind;
119  std::vector<size_t> row_ptr;
121  size_t num_row;
123  size_t num_col;
124 
125  CSRDMatrixImpl() = delete;
126  CSRDMatrixImpl(std::vector<ElementType> data, std::vector<uint32_t> col_ind,
127  std::vector<size_t> row_ptr, size_t num_row, size_t num_col);
128  CSRDMatrixImpl(const CSRDMatrixImpl&) = default;
129  CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default;
130  CSRDMatrixImpl& operator=(const CSRDMatrixImpl&) = default;
131  CSRDMatrixImpl& operator=(CSRDMatrixImpl&&) noexcept = default;
132 
133  size_t GetNumRow() const override;
134  size_t GetNumCol() const override;
135  size_t GetNumElem() const override;
136  DMatrixType GetType() const override;
137 
138  template <typename OutputType>
139  void FillRow(size_t row_id, OutputType* out) const;
140  template <typename OutputType>
141  void ClearRow(size_t row_id, OutputType* out) const;
142 
143  friend class CSRDMatrix;
144 };
145 
146 } // namespace treelite
147 
148 #endif // TREELITE_DATA_H_
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:123
ElementType missing_value
value representing the missing value (usually NaN)
Definition: data.h:59
std::vector< size_t > row_ptr
row pointers: offsets to the start of each row in data and col_ind; length is num_row + 1.
Definition: data.h:119
std::vector< uint32_t > col_ind
feature indices. col_ind[i] indicates the feature index associated with data[i].
Definition: data.h:117
std::vector< ElementType > data
feature values
Definition: data.h:57
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< ElementType > data
feature values
Definition: data.h:115
Defines TypeInfo class and utilities.
size_t num_row
number of rows
Definition: data.h:61
size_t num_row
number of rows
Definition: data.h:121
size_t num_col
number of columns (i.e. # of features used)
Definition: data.h:63