treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
#include <treelite/base.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <algorithm>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
19 
20 namespace treelite {
21 
23 class Tree {
24  public:
26  class Node {
27  public:
28  Node() : sindex_(0), missing_category_to_zero_(false) {}
30  inline int cleft() const {
31  return this->cleft_;
32  }
34  inline int cright() const {
35  return this->cright_;
36  }
38  inline int cdefault() const {
39  return this->default_left() ? this->cleft() : this->cright();
40  }
42  inline unsigned split_index() const {
43  return sindex_ & ((1U << 31) - 1U);
44  }
46  inline bool default_left() const {
47  return (sindex_ >> 31) != 0;
48  }
50  inline bool is_leaf() const {
51  return cleft_ == -1;
52  }
54  inline tl_float leaf_value() const {
55  return (this->info_).leaf_value;
56  }
61  inline const std::vector<tl_float>& leaf_vector() const {
62  return this->leaf_vector_;
63  }
67  inline bool has_leaf_vector() const {
68  return !(this->leaf_vector_.empty());
69  }
71  inline tl_float threshold() const {
72  return (this->info_).threshold;
73  }
75  inline int parent() const {
76  return parent_ & ((1U << 31) - 1);
77  }
79  inline bool is_left_child() const {
80  return (parent_ & (1U << 31)) != 0;
81  }
83  inline bool is_root() const {
84  return parent_ == -1;
85  }
87  inline Operator comparison_op() const {
88  return cmp_;
89  }
91  inline const std::vector<uint32_t>& left_categories() const {
92  return left_categories_;
93  }
95  inline SplitFeatureType split_type() const {
96  return split_type_;
97  }
99  inline bool has_data_count() const {
100  return data_count_.has_value();
101  }
103  inline size_t data_count() const {
104  return data_count_.value();
105  }
107  inline bool has_sum_hess() const {
108  return sum_hess_.has_value();
109  }
111  inline double sum_hess() const {
112  return sum_hess_.value();
113  }
115  inline bool has_gain() const {
116  return gain_.has_value();
117  }
119  inline double gain() const {
120  return gain_.value();
121  }
124  inline bool missing_category_to_zero() const {
125  return missing_category_to_zero_;
126  }
136  bool default_left, Operator cmp) {
137  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
138  if (default_left) split_index |= (1U << 31);
139  this->sindex_ = split_index;
140  (this->info_).threshold = threshold;
141  this->cmp_ = cmp;
142  this->split_type_ = SplitFeatureType::kNumerical;
143  }
152  inline void set_categorical_split(unsigned split_index, bool default_left,
154  const std::vector<uint32_t>& left_categories) {
155  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
156  if (default_left) split_index |= (1U << 31);
157  this->sindex_ = split_index;
158  this->left_categories_ = left_categories;
159  std::sort(this->left_categories_.begin(), this->left_categories_.end());
160  this->split_type_ = SplitFeatureType::kCategorical;
161  this->missing_category_to_zero_ = missing_category_to_zero;
162  }
167  inline void set_leaf(tl_float value) {
168  (this->info_).leaf_value = value;
169  this->cleft_ = -1;
170  this->cright_ = -1;
171  this->split_type_ = SplitFeatureType::kNone;
172  }
178  inline void set_leaf_vector(const std::vector<tl_float>& leaf_vector) {
179  this->leaf_vector_ = leaf_vector;
180  this->cleft_ = -1;
181  this->cright_ = -1;
182  this->split_type_ = SplitFeatureType::kNone;
183  }
188  inline void set_sum_hess(double sum_hess) {
189  this->sum_hess_ = sum_hess;
190  }
195  inline void set_data_count(size_t data_count) {
196  this->data_count_ = data_count;
197  }
202  inline void set_gain(double gain) {
203  this->gain_ = gain;
204  }
210  inline void set_parent(int pidx, bool is_left_child = true) {
211  if (is_left_child) pidx |= (1U << 31);
212  this->parent_ = pidx;
213  }
214 
215  private:
216  friend class Tree;
218  union Info {
219  tl_float leaf_value; // for leaf nodes
220  tl_float threshold; // for non-leaf nodes
221  };
226  std::vector<tl_float> leaf_vector_;
231  int parent_;
233  int cleft_, cright_;
235  SplitFeatureType split_type_;
240  unsigned sindex_;
242  Info info_;
248  Operator cmp_;
256  std::vector<uint32_t> left_categories_;
257  /* \brief Whether to convert missing value to zero.
258  * Only applicable when split_type_ is set to kCategorical.
259  * When this flag is set, it overrides the behavior of default_left().
260  */
261  bool missing_category_to_zero_;
266  dmlc::optional<size_t> data_count_;
273  dmlc::optional<double> sum_hess_;
277  dmlc::optional<double> gain_;
278  };
279 
280  private:
281  // vector of nodes
282  std::vector<Node> nodes;
283  // allocate a new node
284  inline int AllocNode() {
285  int nd = num_nodes++;
286  CHECK_LT(num_nodes, std::numeric_limits<int>::max())
287  << "number of nodes in the tree exceed 2^31";
288  nodes.resize(num_nodes);
289  return nd;
290  }
291 
292  public:
300  inline Node& operator[](int nid) {
301  return nodes[nid];
302  }
308  inline const Node& operator[](int nid) const {
309  return nodes[nid];
310  }
312  inline void Init() {
313  num_nodes = 1;
314  nodes.resize(1);
315  nodes[0].set_leaf(0.0f);
316  nodes[0].set_parent(-1);
317  }
322  inline void AddChilds(int nid) {
323  const int cleft = this->AllocNode();
324  const int cright = this->AllocNode();
325  nodes[nid].cleft_ = cleft;
326  nodes[nid].cright_ = cright;
327  nodes[cleft].set_parent(nid, true);
328  nodes[cright].set_parent(nid, false);
329  }
330 
335  inline std::vector<unsigned> GetCategoricalFeatures() const {
336  std::unordered_map<unsigned, bool> tmp;
337  for (int nid = 0; nid < num_nodes; ++nid) {
338  const Node& node = nodes[nid];
339  const SplitFeatureType type = node.split_type();
340  if (type != SplitFeatureType::kNone) {
341  const bool flag = (type == SplitFeatureType::kCategorical);
342  const unsigned split_index = node.split_index();
343  if (tmp.count(split_index) == 0) {
344  tmp[split_index] = flag;
345  } else {
346  CHECK_EQ(tmp[split_index], flag) << "Feature " << split_index
347  << " cannot be simultaneously be categorical and numerical.";
348  }
349  }
350  }
351  std::vector<unsigned> result;
352  for (const auto& kv : tmp) {
353  if (kv.second) {
354  result.push_back(kv.first);
355  }
356  }
357  std::sort(result.begin(), result.end());
358  return result;
359  }
360 };
361 
362 struct ModelParam : public dmlc::Parameter<ModelParam> {
384  std::string pred_transform;
399  float global_bias;
402  // declare parameters
403  DMLC_DECLARE_PARAMETER(ModelParam) {
404  DMLC_DECLARE_FIELD(pred_transform).set_default("identity")
405  .describe("name of prediction transform function");
406  DMLC_DECLARE_FIELD(sigmoid_alpha).set_default(1.0f)
407  .set_lower_bound(0.0f)
408  .describe("scaling parameter for sigmoid function");
409  DMLC_DECLARE_FIELD(global_bias).set_default(0.0f)
410  .describe("global bias of the model");
411  }
412 };
413 
414 inline void InitParamAndCheck(ModelParam* param,
415  const std::vector<std::pair<std::string, std::string>> cfg) {
416  auto unknown = param->InitAllowUnknown(cfg);
417  if (unknown.size() > 0) {
418  std::ostringstream oss;
419  for (const auto& kv : unknown) {
420  oss << kv.first << ", ";
421  }
422  LOG(INFO) << "\033[1;31mWarning: Unknown parameters found; "
423  << "they have been ignored\u001B[0m: " << oss.str();
424  }
425 }
426 
428 struct Model {
430  std::vector<Tree> trees;
444 
446  Model() {
447  param.Init(std::vector<std::pair<std::string, std::string>>());
448  }
449  Model(const Model&) = delete;
450  Model& operator=(const Model&) = delete;
451  Model(Model&&) = default;
452  Model& operator=(Model&&) = default;
453 };
454 
455 } // namespace treelite
456 #endif // TREELITE_TREE_H_
int cright() const
index of right child
Definition: tree.h:34
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:438
SplitFeatureType
feature split type
Definition: base.h:19
void set_data_count(size_t data_count)
set the data count of the node
Definition: tree.h:195
void Init()
initialize the model with a single root node
Definition: tree.h:312
thin wrapper for tree ensemble model
Definition: tree.h:428
void set_leaf(tl_float value)
set the leaf value of the node
Definition: tree.h:167
tree node
Definition: tree.h:26
std::vector< Tree > trees
member trees
Definition: tree.h:430
void set_parent(int pidx, bool is_left_child=true)
set parent of the node
Definition: tree.h:210
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:392
int cleft() const
index of left child
Definition: tree.h:30
void set_gain(double gain)
set the gain value of the node
Definition: tree.h:202
ModelParam param
extra parameters
Definition: tree.h:443
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:61
in-memory representation of a decision tree
Definition: tree.h:23
void set_categorical_split(unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree.h:152
tl_float leaf_value() const
Definition: tree.h:54
float global_bias
global bias of the model
Definition: tree.h:399
const std::vector< uint32_t > & left_categories() const
get categories for left child node
Definition: tree.h:91
bool has_sum_hess() const
test whether this node has hessian sum
Definition: tree.h:107
Operator comparison_op() const
get comparison operator
Definition: tree.h:87
size_t data_count() const
get data count
Definition: tree.h:103
bool has_leaf_vector() const
Definition: tree.h:67
unsigned split_index() const
feature index of split condition
Definition: tree.h:42
bool default_left() const
whether to go to the left child when the feature value is unknown
Definition: tree.h:46
std::string pred_transform
name of prediction transform function
Definition: tree.h:384
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:50
void set_numerical_split(unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
Definition: tree.h:135
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:441
void set_leaf_vector(const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree.h:178
int num_nodes
number of nodes
Definition: tree.h:294
double tl_float
float type to be used internally
Definition: base.h:17
tl_float threshold() const
Definition: tree.h:71
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:335
bool has_data_count() const
test whether this node has data count
Definition: tree.h:99
double gain() const
get gain value
Definition: tree.h:119
defines configuration macros of treelite
Node & operator[](int nid)
get node given nid
Definition: tree.h:300
bool missing_category_to_zero() const
test whether missing values should be converted into zero; only applicable for categorical splits ...
Definition: tree.h:124
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:322
bool is_root() const
whether current node is root
Definition: tree.h:83
void set_sum_hess(double sum_hess)
set the hessian sum of the node
Definition: tree.h:188
bool has_gain() const
test whether this node has gain value
Definition: tree.h:115
int cdefault() const
index of default child when feature is missing
Definition: tree.h:38
bool is_left_child() const
whether current node is left child
Definition: tree.h:79
double sum_hess() const
get hessian sum
Definition: tree.h:111
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:95
const Node & operator[](int nid) const
get node given nid (const version)
Definition: tree.h:308
int parent() const
get parent of the node
Definition: tree.h:75
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435
Model()
disable copy; use default move
Definition: tree.h:446
Operator
comparison operators
Definition: base.h:23