Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <algorithm>
12 #include <map>
13 #include <string>
14 #include <vector>
15 #include <utility>
16 #include <type_traits>
17 #include <limits>
18 #include <cstring>
19 #include <cstdio>
20 
21 #define __TREELITE_STR(x) #x  // stringizes the argument exactly as written. NOTE(review): identifiers containing a double underscore are reserved to the implementation
22 #define _TREELITE_STR(x) __TREELITE_STR(x)  // extra indirection so a macro argument is expanded before it is stringized
23 
24 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256  // size, in chars, of the ModelParam::pred_transform buffer
25 
26 /* Forward declarations */
27 namespace dmlc {
28 
29 class Stream;  // declared only; this header uses it solely through pointers (ReferenceSerialize(dmlc::Stream*))
30 float stof(const std::string& value, size_t* pos);  // declaration only. NOTE(review): presumably parses a float from value and reports the stop position via pos, like std::stof -- confirm
31 
32 } // namespace dmlc
33 
34 namespace treelite {
35 
36 struct PyBufferFrame {  // element type of the frame lists used by GetPyBuffer() / InitFromPyBuffer(). NOTE(review): presumably mirrors one Python buffer-protocol view -- confirm
37  void* buf;  // untyped pointer to the frame's memory
38  char* format;  // NOTE(review): presumably a struct-style format string describing one item -- confirm
39  size_t itemsize;  // NOTE(review): presumably the size of one item in bytes -- confirm
40  size_t nitem;  // NOTE(review): presumably the number of items reachable through buf -- confirm
41 };
42 
43 template <typename T>  // NOTE(review): the class-head line (tree.h:44) is missing from this listing; the members below belong to ContiguousArray<T>
45  public:
47  ~ContiguousArray();
48  // NOTE: use Clone to make deep copy; copy constructors disabled
49  ContiguousArray(const ContiguousArray&) = delete;
50  ContiguousArray& operator=(const ContiguousArray&) = delete;
51  ContiguousArray(ContiguousArray&& other) noexcept;  // moving stays available (and noexcept) even though copying is disabled
52  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
53  inline ContiguousArray Clone() const;  // the explicit deep-copy entry point named in the NOTE above
54  inline void UseForeignBuffer(void* prealloc_buf, size_t size);  // NOTE(review): presumably points the array at caller-provided memory without owning it (see owned_buffer_) -- confirm
55  inline T* Data();  // NOTE(review): Data/End/Back presumably behave like std::vector data()/end()/back() -- confirm in tree_impl.h
56  inline const T* Data() const;
57  inline T* End();
58  inline const T* End() const;
59  inline T& Back();
60  inline const T& Back() const;
61  inline size_t Size() const;
62  inline void Reserve(size_t newsize);  // NOTE(review): Reserve/Resize/Clear/PushBack presumably follow std::vector semantics -- confirm in tree_impl.h
63  inline void Resize(size_t newsize);
64  inline void Resize(size_t newsize, T t);
65  inline void Clear();
66  inline void PushBack(T t);  // takes T by value; T is required to be POD by the static_assert below
67  inline void Extend(const std::vector<T>& other);  // NOTE(review): presumably appends every element of other -- confirm
68  inline T& operator[](size_t idx);
69  inline const T& operator[](size_t idx) const;
70  static_assert(std::is_pod<T>::value, "T must be POD");  // compile-time restriction on the element type
71 
72  private:
73  T* buffer_;  // raw pointer to the element storage
74  size_t size_;  // NOTE(review): presumably the number of elements in use -- confirm
75  size_t capacity_;  // NOTE(review): presumably the number of elements buffer_ can hold -- confirm
76  bool owned_buffer_;  // NOTE(review): presumably false after UseForeignBuffer(), so the destructor leaves buffer_ alone -- confirm
77 };
78 
80 class Tree {  // in-memory representation of a decision tree
81  public:
83  struct Node {  // tree node; kept a fixed-size POD by the static_asserts after this struct. Some members (info_, split_type_, cmp_, the *_present_ flags) are missing from this listing
86  inline void Init();
88  union Info {  // stores either a leaf value or a decision threshold
89  tl_float leaf_value; // for leaf nodes
90  tl_float threshold; // for non-leaf nodes
91  };
93  int32_t cleft_, cright_;  // left and right children, held as 32-bit integers rather than raw pointers
98  uint32_t sindex_;  // feature index used for the split; highest bit indicates default direction for missing values
105  uint64_t data_count_;  // number of data points whose traversal paths include this node
112  double sum_hess_;  // sum of hessian values for all data points whose traversal paths include this node
116  double gain_;  // change in loss that is attributed to a particular split
125  /*! \brief Whether to convert missing value to zero.
126  * Only applicable when split_type_ is set to kCategorical.
127  * When this flag is set, it overrides the behavior of default_left().
128  */
129  bool missing_category_to_zero_;
136  // padding
137  uint16_t pad_;
138  };
139 
140  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
141  static_assert(sizeof(Node) == 48, "Node must be 48 bytes");  // pins the exact size so any layout change fails at compile time
142 
143  Tree() = default;
144  ~Tree() = default;
145  Tree(const Tree&) = delete;  // implicit copies are disabled. NOTE(review): Clone() is presumably the explicit deep copy -- confirm
146  Tree& operator=(const Tree&) = delete;
147  Tree(Tree&&) = default;
148  Tree& operator=(Tree&&) = default;
149  inline Tree Clone() const;
150 
151  inline std::vector<PyBufferFrame> GetPyBuffer();  // NOTE(review): presumably exposes the tree's arrays as buffer frames -- confirm in tree_impl.h
152  inline void InitFromPyBuffer(std::vector<PyBufferFrame> frames);  // NOTE(review): presumably the inverse of GetPyBuffer() -- confirm in tree_impl.h
153 
154  private:
155  // vector of nodes
156  ContiguousArray<Node> nodes_;
157  ContiguousArray<tl_float> leaf_vector_;  // NOTE(review): presumably all leaf vectors stored back to back, sliced by leaf_vector_offset_ -- confirm
158  ContiguousArray<size_t> leaf_vector_offset_;
159  ContiguousArray<uint32_t> left_categories_;  // NOTE(review): presumably all left-category lists stored back to back, sliced by left_categories_offset_ -- confirm
160  ContiguousArray<size_t> left_categories_offset_;
161 
162  // allocate a new node
163  inline int AllocNode();
164 
165  public:
169  inline void Init();
174  inline void AddChilds(int nid);
175 
180  inline std::vector<unsigned> GetCategoricalFeatures() const;
181 
187  inline int LeftChild(int nid) const;  // NOTE(review): in the accessors below, nid is presumably a node's index in nodes_ -- confirm
192  inline int RightChild(int nid) const;
197  inline int DefaultChild(int nid) const;
202  inline uint32_t SplitIndex(int nid) const;
207  inline bool DefaultLeft(int nid) const;
212  inline bool IsLeaf(int nid) const;
217  inline tl_float LeafValue(int nid) const;
222  inline std::vector<tl_float> LeafVector(int nid) const;  // returns a std::vector by value, not a view into leaf_vector_
227  inline bool HasLeafVector(int nid) const;
232  inline tl_float Threshold(int nid) const;
237  inline Operator ComparisonOp(int nid) const;
245  inline std::vector<uint32_t> LeftCategories(int nid) const;  // returns a std::vector by value, not a view into left_categories_
250  inline SplitFeatureType SplitType(int nid) const;
255  inline bool HasDataCount(int nid) const;  // NOTE(review): the Has*() queries presumably report the matching *_present_ flag of the node -- confirm
260  inline uint64_t DataCount(int nid) const;
265  inline bool HasSumHess(int nid) const;
270  inline double SumHess(int nid) const;
275  inline bool HasGain(int nid) const;
280  inline double Gain(int nid) const;
286  inline bool MissingCategoryToZero(int nid) const;
287 
298  inline void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold,
299  bool default_left, Operator cmp);
309  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
310  bool missing_category_to_zero,
311  const std::vector<uint32_t>& left_categories);
317  inline void SetLeaf(int nid, tl_float value);
323  inline void SetLeafVector(int nid, const std::vector<tl_float>& leaf_vector);
329  inline void SetSumHess(int nid, double sum_hess);
335  inline void SetDataCount(int nid, uint64_t data_count);
341  inline void SetGain(int nid, double gain);
342 
343  void ReferenceSerialize(dmlc::Stream* fo) const;  // NOTE(review): presumably writes this tree to the stream fo -- confirm against the definition
344 };
345 
346 struct ModelParam {  // parameter block embedded in Model; the sigmoid_alpha declaration (tree.h:376) is missing from this listing
368  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};  // fixed-size char buffer; the constructor fills it with "identity"
383  float global_bias;  // global bias of the model
386  ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {  // defaults: sigmoid_alpha = 1.0, global_bias = 0.0
387  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));  // zero the whole buffer so every byte after the string is NUL
388  std::strncpy(pred_transform, "identity", sizeof(pred_transform));  // "identity" is far shorter than the buffer, so the copy cannot be truncated
389  }
390  ~ModelParam() = default;
391  ModelParam(const ModelParam&) = default;
392  ModelParam& operator=(const ModelParam&) = default;
393  ModelParam(ModelParam&&) = default;
394  ModelParam& operator=(ModelParam&&) = default;
395 
396  template<typename Container>
397  inline std::vector<std::pair<std::string, std::string>>
398  InitAllowUnknown(const Container &kwargs);  // NOTE(review): presumably applies recognized key/value pairs from kwargs and returns the unrecognized ones -- confirm in tree_impl.h
399  inline std::map<std::string, std::string> __DICT__() const;  // NOTE(review): presumably returns the parameters as a name-to-value map -- confirm in tree_impl.h
400 };
401 
402 static_assert(std::is_standard_layout<ModelParam>::value,
403  "ModelParam must be in the standard layout");  // compile-time guard: the build fails if ModelParam stops being standard-layout
404 
405 inline void InitParamAndCheck(ModelParam* param,  // declaration only. NOTE(review): presumably fills *param from cfg and rejects unknown keys -- confirm in tree_impl.h
406  const std::vector<std::pair<std::string, std::string>>& cfg);  // cfg: list of (name, value) string pairs
407 
409 struct Model {  // thin wrapper for tree ensemble model
411  std::vector<Tree> trees;  // member trees. Other data members (num_feature, num_output_group, random_forest_flag, param; tree.h:416-424) are missing from this listing
425 
427  Model() = default;
428  ~Model() = default;
429  Model(const Model&) = delete;  // implicit copies are disabled. NOTE(review): Clone() is presumably the explicit deep copy -- confirm
430  Model& operator=(const Model&) = delete;
431  Model(Model&&) = default;
432  Model& operator=(Model&&) = default;
433 
434  void ReferenceSerialize(dmlc::Stream* fo) const;  // NOTE(review): presumably writes the model to the stream fo -- confirm against the definition
435 
436  inline std::vector<PyBufferFrame> GetPyBuffer();  // NOTE(review): presumably exposes the model's storage as buffer frames -- confirm in tree_impl.h
437  inline void InitFromPyBuffer(std::vector<PyBufferFrame> frames);  // NOTE(review): presumably the inverse of GetPyBuffer() -- confirm in tree_impl.h
438  inline Model Clone() const;
439 };
440 
441 } // namespace treelite
442 
443 #include "tree_impl.h"
444 
445 #endif // TREELITE_TREE_H_
SplitFeatureType split_type_
feature split type
Definition: tree.h:118
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:135
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:419
SplitFeatureType
feature split type
Definition: base.h:20
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:105
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:124
std::vector< Tree > trees
member trees
Definition: tree.h:411
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:376
ModelParam param
extra parameters
Definition: tree.h:424
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:131
tree node
Definition: tree.h:83
int32_t cleft_
pointer to left and right children
Definition: tree.h:93
in-memory representation of a decision tree
Definition: tree.h:80
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:112
float global_bias
global bias of the model
Definition: tree.h:383
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:116
store either leaf value or decision threshold
Definition: tree.h:88
Definition: tree.h:27
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:422
int num_nodes
number of nodes
Definition: tree.h:167
defines configuration macros of Treelite
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:133
uint32_t sindex_
feature index used for the split; highest bit indicates default direction for missing values ...
Definition: tree.h:98
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416
Info info_
storage for leaf value or decision threshold
Definition: tree.h:100
Operator
comparison operators
Definition: base.h:24