treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
#include <treelite/base.h>
#include <treelite/common.h>
#include <dmlc/logging.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
15 
16 namespace treelite {
17 
19 class Tree {
20  public:
22  class Node {
23  public:
24  Node() : sindex_(0) {}
26  inline int cleft() const {
27  return this->cleft_;
28  }
30  inline int cright() const {
31  return this->cright_;
32  }
34  inline int cdefault() const {
35  return this->default_left() ? this->cleft() : this->cright();
36  }
38  inline unsigned split_index() const {
39  return sindex_ & ((1U << 31) - 1U);
40  }
42  inline bool default_left() const {
43  return (sindex_ >> 31) != 0;
44  }
46  inline bool is_leaf() const {
47  return cleft_ == -1;
48  }
50  inline tl_float leaf_value() const {
51  return (this->info_).leaf_value;
52  }
57  inline const std::vector<tl_float>& leaf_vector() const {
58  return this->leaf_vector_;
59  }
63  inline bool has_leaf_vector() const {
64  return !(this->leaf_vector_.empty());
65  }
67  inline tl_float threshold() const {
68  return (this->info_).threshold;
69  }
71  inline int parent() const {
72  return parent_ & ((1U << 31) - 1);
73  }
75  inline bool is_left_child() const {
76  return (parent_ & (1U << 31)) != 0;
77  }
79  inline bool is_root() const {
80  return parent_ == -1;
81  }
83  inline Operator comparison_op() const {
84  return cmp_;
85  }
87  inline const std::vector<uint8_t>& left_categories() const {
88  return left_categories_;
89  }
91  inline SplitFeatureType split_type() const {
92  return split_type_;
93  }
103  bool default_left, Operator cmp) {
104  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
105  if (default_left) split_index |= (1U << 31);
106  this->sindex_ = split_index;
107  (this->info_).threshold = threshold;
108  this->cmp_ = cmp;
109  this->split_type_ = SplitFeatureType::kNumerical;
110  }
119  inline void set_categorical_split(unsigned split_index, bool default_left,
120  const std::vector<uint8_t>& left_categories) {
121  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
122  if (default_left) split_index |= (1U << 31);
123  this->sindex_ = split_index;
124  this->left_categories_ = left_categories;
125  std::sort(this->left_categories_.begin(), this->left_categories_.end());
126  this->split_type_ = SplitFeatureType::kCategorical;
127  }
132  inline void set_leaf(tl_float value) {
133  (this->info_).leaf_value = value;
134  this->cleft_ = -1;
135  this->cright_ = -1;
136  this->split_type_ = SplitFeatureType::kNone;
137  }
143  inline void set_leaf_vector(const std::vector<tl_float>& leaf_vector) {
144  this->leaf_vector_ = leaf_vector;
145  this->cleft_ = -1;
146  this->cright_ = -1;
147  this->split_type_ = SplitFeatureType::kNone;
148  }
154  inline void set_parent(int pidx, bool is_left_child = true) {
155  if (is_left_child) pidx |= (1U << 31);
156  this->parent_ = pidx;
157  }
158 
159  private:
160  friend class Tree;
162  union Info {
163  tl_float leaf_value; // for leaf nodes
164  tl_float threshold; // for non-leaf nodes
165  };
170  std::vector<tl_float> leaf_vector_;
175  int parent_;
177  int cleft_, cright_;
179  SplitFeatureType split_type_;
184  unsigned sindex_;
186  Info info_;
192  Operator cmp_;
199  std::vector<uint8_t> left_categories_;
200  };
201 
202  private:
203  // vector of nodes
204  std::vector<Node> nodes;
205  // allocate a new node
206  inline int AllocNode() {
207  int nd = num_nodes++;
208  CHECK_LT(num_nodes, std::numeric_limits<int>::max())
209  << "number of nodes in the tree exceed 2^31";
210  nodes.resize(num_nodes);
211  return nd;
212  }
213 
214  public:
222  inline Node& operator[](int nid) {
223  return nodes[nid];
224  }
230  inline const Node& operator[](int nid) const {
231  return nodes[nid];
232  }
234  inline void Init() {
235  num_nodes = 1;
236  nodes.resize(1);
237  nodes[0].set_leaf(0.0f);
238  nodes[0].set_parent(-1);
239  }
244  inline void AddChilds(int nid) {
245  const int cleft = this->AllocNode();
246  const int cright = this->AllocNode();
247  nodes[nid].cleft_ = cleft;
248  nodes[nid].cright_ = cright;
249  nodes[cleft].set_parent(nid, true);
250  nodes[cright].set_parent(nid, false);
251  }
252 
257  inline std::vector<unsigned> GetCategoricalFeatures() const {
258  std::unordered_map<unsigned, bool> tmp;
259  for (int nid = 0; nid < num_nodes; ++nid) {
260  const Node& node = nodes[nid];
261  const SplitFeatureType type = node.split_type();
262  if (type != SplitFeatureType::kNone) {
263  const bool flag = (type == SplitFeatureType::kCategorical);
264  const unsigned split_index = node.split_index();
265  if (tmp.count(split_index) == 0) {
266  tmp[split_index] = flag;
267  } else {
268  CHECK_EQ(tmp[split_index], flag) << "Feature " << split_index
269  << " cannot be simultaneously be categorical and numerical.";
270  }
271  }
272  }
273  std::vector<unsigned> result;
274  for (const auto& kv : tmp) {
275  if (kv.second) {
276  result.push_back(kv.first);
277  }
278  }
279  std::sort(result.begin(), result.end());
280  return result;
281  }
282 };
283 
284 struct ModelParam : public dmlc::Parameter<ModelParam> {
306  std::string pred_transform;
321  float global_bias;
324  // declare parameters
325  DMLC_DECLARE_PARAMETER(ModelParam) {
326  DMLC_DECLARE_FIELD(pred_transform).set_default("identity")
327  .describe("name of prediction transform function");
328  DMLC_DECLARE_FIELD(sigmoid_alpha).set_default(1.0f)
329  .set_lower_bound(0.0f)
330  .describe("scaling parameter for sigmoid function");
331  DMLC_DECLARE_FIELD(global_bias).set_default(0.0f)
332  .describe("global bias of the model");
333  }
334 };
335 
336 inline void InitParamAndCheck(ModelParam* param,
337  const std::vector<std::pair<std::string, std::string>> cfg) {
338  auto unknown = param->InitAllowUnknown(cfg);
339  if (unknown.size() > 0) {
340  std::ostringstream oss;
341  for (const auto& kv : unknown) {
342  oss << kv.first << ", ";
343  }
344  LOG(INFO) << "\033[1;31mWarning: Unknown parameters found; "
345  << "they have been ignored\u001B[0m: " << oss.str();
346  }
347 }
348 
350 struct Model {
352  std::vector<Tree> trees;
366 
368  Model() = default;
369  Model(const Model&) = delete;
370  Model& operator=(const Model&) = delete;
371  Model(Model&&) = default;
372  Model& operator=(Model&&) = default;
373 };
374 
375 } // namespace treelite
376 #endif // TREELITE_TREE_H_
bool is_left_child() const
whether current node is left child
Definition: tree.h:75
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:360
SplitFeatureType
feature split type
Definition: base.h:19
void Init()
initialize the model with a single root node
Definition: tree.h:234
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
void set_leaf(tl_float value)
set the leaf value of the node
Definition: tree.h:132
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:352
void set_parent(int pidx, bool is_left_child=true)
set parent of the node
Definition: tree.h:154
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:314
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:365
Operator comparison_op() const
get comparison operator
Definition: tree.h:83
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:321
tl_float threshold() const
Definition: tree.h:67
bool has_leaf_vector() const
Definition: tree.h:63
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:57
int cright() const
index of right child
Definition: tree.h:30
std::string pred_transform
name of prediction transform function
Definition: tree.h:306
int cdefault() const
index of default child when feature is missing
Definition: tree.h:34
const std::vector< uint8_t > & left_categories() const
get categories for left child node
Definition: tree.h:87
void set_numerical_split(unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
Definition: tree.h:102
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:363
void set_categorical_split(unsigned split_index, bool default_left, const std::vector< uint8_t > &left_categories)
create a categorical split
Definition: tree.h:119
void set_leaf_vector(const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree.h:143
int num_nodes
number of nodes
Definition: tree.h:216
const Node & operator[](int nid) const
get node given nid (const version)
Definition: tree.h:230
defines configuration macros of treelite
tl_float leaf_value() const
Definition: tree.h:50
bool default_left() const
when feature is unknown, whether goes to left child
Definition: tree.h:42
Node & operator[](int nid)
get node given nid
Definition: tree.h:222
bool is_root() const
whether current node is root
Definition: tree.h:79
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:244
Some useful utilities.
int cleft() const
index of left child
Definition: tree.h:26
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:257
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
int parent() const
get parent of the node
Definition: tree.h:71
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:357
Operator
comparison operators
Definition: base.h:23