treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/common.h>
12 #include <dmlc/logging.h>
13 #include <algorithm>
14 #include <vector>
15 #include <utility>
16 #include <string>
17 #include <limits>
18 
19 namespace treelite {
20 
22 class Tree {
23  public:
25  class Node {
26  public:
27  Node() : sindex_(0) {}
29  inline int cleft() const {
30  return this->cleft_;
31  }
33  inline int cright() const {
34  return this->cright_;
35  }
37  inline int cdefault() const {
38  return this->default_left() ? this->cleft() : this->cright();
39  }
41  inline unsigned split_index() const {
42  return sindex_ & ((1U << 31) - 1U);
43  }
45  inline bool default_left() const {
46  return (sindex_ >> 31) != 0;
47  }
49  inline bool is_leaf() const {
50  return cleft_ == -1;
51  }
53  inline tl_float leaf_value() const {
54  return (this->info_).leaf_value;
55  }
60  inline const std::vector<tl_float>& leaf_vector() const {
61  return this->leaf_vector_;
62  }
66  inline bool has_leaf_vector() const {
67  return !(this->leaf_vector_.empty());
68  }
70  inline tl_float threshold() const {
71  return (this->info_).threshold;
72  }
74  inline int parent() const {
75  return parent_ & ((1U << 31) - 1);
76  }
78  inline bool is_left_child() const {
79  return (parent_ & (1U << 31)) != 0;
80  }
82  inline bool is_root() const {
83  return parent_ == -1;
84  }
86  inline Operator comparison_op() const {
87  return cmp_;
88  }
90  inline const std::vector<uint32_t>& left_categories() const {
91  return left_categories_;
92  }
94  inline SplitFeatureType split_type() const {
95  return split_type_;
96  }
98  inline bool has_data_count() const {
99  return data_count_.has_value();
100  }
102  inline size_t data_count() const {
103  return data_count_.value();
104  }
106  inline bool has_sum_hess() const {
107  return sum_hess_.has_value();
108  }
110  inline double sum_hess() const {
111  return sum_hess_.value();
112  }
114  inline bool has_gain() const {
115  return gain_.has_value();
116  }
118  inline double gain() const {
119  return gain_.value();
120  }
130  bool default_left, Operator cmp) {
131  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
132  if (default_left) split_index |= (1U << 31);
133  this->sindex_ = split_index;
134  (this->info_).threshold = threshold;
135  this->cmp_ = cmp;
136  this->split_type_ = SplitFeatureType::kNumerical;
137  }
146  inline void set_categorical_split(unsigned split_index, bool default_left,
147  const std::vector<uint32_t>& left_categories) {
148  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
149  if (default_left) split_index |= (1U << 31);
150  this->sindex_ = split_index;
151  this->left_categories_ = left_categories;
152  std::sort(this->left_categories_.begin(), this->left_categories_.end());
153  this->split_type_ = SplitFeatureType::kCategorical;
154  }
159  inline void set_leaf(tl_float value) {
160  (this->info_).leaf_value = value;
161  this->cleft_ = -1;
162  this->cright_ = -1;
163  this->split_type_ = SplitFeatureType::kNone;
164  }
170  inline void set_leaf_vector(const std::vector<tl_float>& leaf_vector) {
171  this->leaf_vector_ = leaf_vector;
172  this->cleft_ = -1;
173  this->cright_ = -1;
174  this->split_type_ = SplitFeatureType::kNone;
175  }
180  inline void set_sum_hess(double sum_hess) {
181  this->sum_hess_ = sum_hess;
182  }
187  inline void set_data_count(size_t data_count) {
188  this->data_count_ = data_count;
189  }
194  inline void set_gain(double gain) {
195  this->gain_ = gain;
196  }
202  inline void set_parent(int pidx, bool is_left_child = true) {
203  if (is_left_child) pidx |= (1U << 31);
204  this->parent_ = pidx;
205  }
206 
207  private:
208  friend class Tree;
210  union Info {
211  tl_float leaf_value; // for leaf nodes
212  tl_float threshold; // for non-leaf nodes
213  };
218  std::vector<tl_float> leaf_vector_;
223  int parent_;
225  int cleft_, cright_;
227  SplitFeatureType split_type_;
232  unsigned sindex_;
234  Info info_;
240  Operator cmp_;
248  std::vector<uint32_t> left_categories_;
253  dmlc::optional<size_t> data_count_;
260  dmlc::optional<double> sum_hess_;
264  dmlc::optional<double> gain_;
265  };
266 
267  private:
268  // vector of nodes
269  std::vector<Node> nodes;
270  // allocate a new node
271  inline int AllocNode() {
272  int nd = num_nodes++;
273  CHECK_LT(num_nodes, std::numeric_limits<int>::max())
274  << "number of nodes in the tree exceed 2^31";
275  nodes.resize(num_nodes);
276  return nd;
277  }
278 
279  public:
287  inline Node& operator[](int nid) {
288  return nodes[nid];
289  }
295  inline const Node& operator[](int nid) const {
296  return nodes[nid];
297  }
299  inline void Init() {
300  num_nodes = 1;
301  nodes.resize(1);
302  nodes[0].set_leaf(0.0f);
303  nodes[0].set_parent(-1);
304  }
309  inline void AddChilds(int nid) {
310  const int cleft = this->AllocNode();
311  const int cright = this->AllocNode();
312  nodes[nid].cleft_ = cleft;
313  nodes[nid].cright_ = cright;
314  nodes[cleft].set_parent(nid, true);
315  nodes[cright].set_parent(nid, false);
316  }
317 
322  inline std::vector<unsigned> GetCategoricalFeatures() const {
323  std::unordered_map<unsigned, bool> tmp;
324  for (int nid = 0; nid < num_nodes; ++nid) {
325  const Node& node = nodes[nid];
326  const SplitFeatureType type = node.split_type();
327  if (type != SplitFeatureType::kNone) {
328  const bool flag = (type == SplitFeatureType::kCategorical);
329  const unsigned split_index = node.split_index();
330  if (tmp.count(split_index) == 0) {
331  tmp[split_index] = flag;
332  } else {
333  CHECK_EQ(tmp[split_index], flag) << "Feature " << split_index
334  << " cannot be simultaneously be categorical and numerical.";
335  }
336  }
337  }
338  std::vector<unsigned> result;
339  for (const auto& kv : tmp) {
340  if (kv.second) {
341  result.push_back(kv.first);
342  }
343  }
344  std::sort(result.begin(), result.end());
345  return result;
346  }
347 };
348 
349 struct ModelParam : public dmlc::Parameter<ModelParam> {
371  std::string pred_transform;
386  float global_bias;
389  // declare parameters
390  DMLC_DECLARE_PARAMETER(ModelParam) {
391  DMLC_DECLARE_FIELD(pred_transform).set_default("identity")
392  .describe("name of prediction transform function");
393  DMLC_DECLARE_FIELD(sigmoid_alpha).set_default(1.0f)
394  .set_lower_bound(0.0f)
395  .describe("scaling parameter for sigmoid function");
396  DMLC_DECLARE_FIELD(global_bias).set_default(0.0f)
397  .describe("global bias of the model");
398  }
399 };
400 
401 inline void InitParamAndCheck(ModelParam* param,
402  const std::vector<std::pair<std::string, std::string>> cfg) {
403  auto unknown = param->InitAllowUnknown(cfg);
404  if (unknown.size() > 0) {
405  std::ostringstream oss;
406  for (const auto& kv : unknown) {
407  oss << kv.first << ", ";
408  }
409  LOG(INFO) << "\033[1;31mWarning: Unknown parameters found; "
410  << "they have been ignored\u001B[0m: " << oss.str();
411  }
412 }
413 
415 struct Model {
417  std::vector<Tree> trees;
431 
433  Model() {
434  param.Init(std::vector<std::pair<std::string, std::string>>());
435  }
436  Model(const Model&) = delete;
437  Model& operator=(const Model&) = delete;
438  Model(Model&&) = default;
439  Model& operator=(Model&&) = default;
440 };
441 
442 } // namespace treelite
443 #endif // TREELITE_TREE_H_
bool is_left_child() const
whether current node is left child
Definition: tree.h:78
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:425
SplitFeatureType
feature split type
Definition: base.h:19
void set_data_count(size_t data_count)
set the data count of the node
Definition: tree.h:187
void Init()
initialize the model with a single root node
Definition: tree.h:299
thin wrapper for tree ensemble model
Definition: tree.h:415
void set_leaf(tl_float value)
set the leaf value of the node
Definition: tree.h:159
tree node
Definition: tree.h:25
std::vector< Tree > trees
member trees
Definition: tree.h:417
void set_parent(int pidx, bool is_left_child=true)
set parent of the node
Definition: tree.h:202
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:379
unsigned split_index() const
feature index of split condition
Definition: tree.h:41
void set_gain(double gain)
set the gain value of the node
Definition: tree.h:194
ModelParam param
extra parameters
Definition: tree.h:430
double sum_hess() const
get hessian sum
Definition: tree.h:110
Operator comparison_op() const
get comparison operator
Definition: tree.h:86
void set_categorical_split(unsigned split_index, bool default_left, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree.h:146
in-memory representation of a decision tree
Definition: tree.h:22
const std::vector< uint32_t > & left_categories() const
get categories for left child node
Definition: tree.h:90
float global_bias
global bias of the model
Definition: tree.h:386
tl_float threshold() const
Definition: tree.h:70
bool has_leaf_vector() const
Definition: tree.h:66
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:60
int cright() const
index of right child
Definition: tree.h:33
size_t data_count() const
get data count
Definition: tree.h:102
std::string pred_transform
name of prediction transform function
Definition: tree.h:371
int cdefault() const
index of default child when feature is missing
Definition: tree.h:37
bool has_gain() const
test whether this node has gain value
Definition: tree.h:114
void set_numerical_split(unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
Definition: tree.h:129
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:428
void set_leaf_vector(const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree.h:170
int num_nodes
number of nodes
Definition: tree.h:281
const Node & operator[](int nid) const
get node given nid (const version)
Definition: tree.h:295
double gain() const
get gain value
Definition: tree.h:118
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
tl_float leaf_value() const
Definition: tree.h:53
bool default_left() const
when feature is unknown, whether goes to left child
Definition: tree.h:45
Node & operator[](int nid)
get node given nid
Definition: tree.h:287
bool is_root() const
whether current node is root
Definition: tree.h:82
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:309
int cleft() const
index of left child
Definition: tree.h:29
void set_sum_hess(double sum_hess)
set the hessian sum of the node
Definition: tree.h:180
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:322
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:94
bool has_sum_hess() const
test whether this node has hessian sum
Definition: tree.h:106
int parent() const
get parent of the node
Definition: tree.h:74
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:49
bool has_data_count() const
test whether this node has data count
Definition: tree.h:98
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:422
Model()
disable copy; use default move
Definition: tree.h:433
Operator
comparison operators
Definition: base.h:23