treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/common.h>
12 #include <dmlc/logging.h>
13 #include <vector>
14 #include <limits>
15 
16 namespace treelite {
17 
19 class Tree {
20  public:
22  class Node {
23  public:
24  Node() : sindex_(0) {}
26  inline int cleft() const {
27  return this->cleft_;
28  }
30  inline int cright() const {
31  return this->cright_;
32  }
34  inline int cdefault() const {
35  return this->default_left() ? this->cleft() : this->cright();
36  }
38  inline unsigned split_index() const {
39  return sindex_ & ((1U << 31) - 1U);
40  }
42  inline bool default_left() const {
43  return (sindex_ >> 31) != 0;
44  }
46  inline bool is_leaf() const {
47  return cleft_ == -1;
48  }
50  inline tl_float leaf_value() const {
51  return (this->info_).leaf_value;
52  }
57  inline const std::vector<tl_float>& leaf_vector() const {
58  return this->leaf_vector_;
59  }
63  inline bool has_leaf_vector() const {
64  return !(this->leaf_vector_.empty());
65  }
67  inline tl_float threshold() const {
68  return (this->info_).threshold;
69  }
71  inline int parent() const {
72  return parent_ & ((1U << 31) - 1);
73  }
75  inline bool is_left_child() const {
76  return (parent_ & (1U << 31)) != 0;
77  }
79  inline bool is_root() const {
80  return parent_ == -1;
81  }
83  inline Operator comparison_op() const {
84  return cmp_;
85  }
87  inline const std::vector<uint32_t>& left_categories() const {
88  return left_categories_;
89  }
91  inline SplitFeatureType split_type() const {
92  return split_type_;
93  }
103  bool default_left, Operator cmp) {
104  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
105  if (default_left) split_index |= (1U << 31);
106  this->sindex_ = split_index;
107  (this->info_).threshold = threshold;
108  this->cmp_ = cmp;
109  this->split_type_ = SplitFeatureType::kNumerical;
110  }
119  inline void set_categorical_split(unsigned split_index, bool default_left,
120  const std::vector<uint32_t>& left_categories) {
121  CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
122  if (default_left) split_index |= (1U << 31);
123  this->sindex_ = split_index;
124  this->left_categories_ = left_categories;
125  std::sort(this->left_categories_.begin(), this->left_categories_.end());
126  this->split_type_ = SplitFeatureType::kCategorical;
127  }
132  inline void set_leaf(tl_float value) {
133  (this->info_).leaf_value = value;
134  this->cleft_ = -1;
135  this->cright_ = -1;
136  this->split_type_ = SplitFeatureType::kNone;
137  }
143  inline void set_leaf_vector(const std::vector<tl_float>& leaf_vector) {
144  this->leaf_vector_ = leaf_vector;
145  this->cleft_ = -1;
146  this->cright_ = -1;
147  this->split_type_ = SplitFeatureType::kNone;
148  }
154  inline void set_parent(int pidx, bool is_left_child = true) {
155  if (is_left_child) pidx |= (1U << 31);
156  this->parent_ = pidx;
157  }
158 
159  private:
160  friend class Tree;
162  union Info {
163  tl_float leaf_value; // for leaf nodes
164  tl_float threshold; // for non-leaf nodes
165  };
170  std::vector<tl_float> leaf_vector_;
175  int parent_;
177  int cleft_, cright_;
179  SplitFeatureType split_type_;
184  unsigned sindex_;
186  Info info_;
192  Operator cmp_;
200  std::vector<uint32_t> left_categories_;
201  };
202 
203  private:
204  // vector of nodes
205  std::vector<Node> nodes;
206  // allocate a new node
207  inline int AllocNode() {
208  int nd = num_nodes++;
209  CHECK_LT(num_nodes, std::numeric_limits<int>::max())
210  << "number of nodes in the tree exceed 2^31";
211  nodes.resize(num_nodes);
212  return nd;
213  }
214 
215  public:
223  inline Node& operator[](int nid) {
224  return nodes[nid];
225  }
231  inline const Node& operator[](int nid) const {
232  return nodes[nid];
233  }
235  inline void Init() {
236  num_nodes = 1;
237  nodes.resize(1);
238  nodes[0].set_leaf(0.0f);
239  nodes[0].set_parent(-1);
240  }
245  inline void AddChilds(int nid) {
246  const int cleft = this->AllocNode();
247  const int cright = this->AllocNode();
248  nodes[nid].cleft_ = cleft;
249  nodes[nid].cright_ = cright;
250  nodes[cleft].set_parent(nid, true);
251  nodes[cright].set_parent(nid, false);
252  }
253 
258  inline std::vector<unsigned> GetCategoricalFeatures() const {
259  std::unordered_map<unsigned, bool> tmp;
260  for (int nid = 0; nid < num_nodes; ++nid) {
261  const Node& node = nodes[nid];
262  const SplitFeatureType type = node.split_type();
263  if (type != SplitFeatureType::kNone) {
264  const bool flag = (type == SplitFeatureType::kCategorical);
265  const unsigned split_index = node.split_index();
266  if (tmp.count(split_index) == 0) {
267  tmp[split_index] = flag;
268  } else {
269  CHECK_EQ(tmp[split_index], flag) << "Feature " << split_index
270  << " cannot be simultaneously be categorical and numerical.";
271  }
272  }
273  }
274  std::vector<unsigned> result;
275  for (const auto& kv : tmp) {
276  if (kv.second) {
277  result.push_back(kv.first);
278  }
279  }
280  std::sort(result.begin(), result.end());
281  return result;
282  }
283 };
284 
285 struct ModelParam : public dmlc::Parameter<ModelParam> {
307  std::string pred_transform;
322  float global_bias;
325  // declare parameters
326  DMLC_DECLARE_PARAMETER(ModelParam) {
327  DMLC_DECLARE_FIELD(pred_transform).set_default("identity")
328  .describe("name of prediction transform function");
329  DMLC_DECLARE_FIELD(sigmoid_alpha).set_default(1.0f)
330  .set_lower_bound(0.0f)
331  .describe("scaling parameter for sigmoid function");
332  DMLC_DECLARE_FIELD(global_bias).set_default(0.0f)
333  .describe("global bias of the model");
334  }
335 };
336 
337 inline void InitParamAndCheck(ModelParam* param,
338  const std::vector<std::pair<std::string, std::string>> cfg) {
339  auto unknown = param->InitAllowUnknown(cfg);
340  if (unknown.size() > 0) {
341  std::ostringstream oss;
342  for (const auto& kv : unknown) {
343  oss << kv.first << ", ";
344  }
345  LOG(INFO) << "\033[1;31mWarning: Unknown parameters found; "
346  << "they have been ignored\u001B[0m: " << oss.str();
347  }
348 }
349 
351 struct Model {
353  std::vector<Tree> trees;
367 
369  Model() {
370  param.Init(std::vector<std::pair<std::string, std::string>>());
371  }
372  Model(const Model&) = delete;
373  Model& operator=(const Model&) = delete;
374  Model(Model&&) = default;
375  Model& operator=(Model&&) = default;
376 };
377 
378 } // namespace treelite
379 #endif // TREELITE_TREE_H_
bool is_left_child() const
whether current node is left child
Definition: tree.h:75
int num_output_group
number of output groups – for multi-class classification; set to 1 for everything else ...
Definition: tree.h:361
SplitFeatureType
feature split type
Definition: base.h:19
void Init()
initialize the model with a single root node
Definition: tree.h:235
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
void set_leaf(tl_float value)
set the leaf value of the node
Definition: tree.h:132
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:353
void set_parent(int pidx, bool is_left_child=true)
set parent of the node
Definition: tree.h:154
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:315
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:366
Operator comparison_op() const
get comparison operator
Definition: tree.h:83
void set_categorical_split(unsigned split_index, bool default_left, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree.h:119
in-memory representation of a decision tree
Definition: tree.h:19
const std::vector< uint32_t > & left_categories() const
get categories for left child node
Definition: tree.h:87
float global_bias
global bias of the model
Definition: tree.h:322
tl_float threshold() const
Definition: tree.h:67
bool has_leaf_vector() const
Definition: tree.h:63
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:57
int cright() const
index of right child
Definition: tree.h:30
std::string pred_transform
name of prediction transform function
Definition: tree.h:307
int cdefault() const
index of default child when feature is missing
Definition: tree.h:34
void set_numerical_split(unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
Definition: tree.h:102
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:364
void set_leaf_vector(const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree.h:143
int num_nodes
number of nodes
Definition: tree.h:217
const Node & operator[](int nid) const
get node given nid (const version)
Definition: tree.h:231
defines configuration macros of treelite
tl_float leaf_value() const
Definition: tree.h:50
bool default_left() const
when feature is unknown, whether goes to left child
Definition: tree.h:42
Node & operator[](int nid)
get node given nid
Definition: tree.h:223
bool is_root() const
whether current node is root
Definition: tree.h:79
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:245
Some useful utilities.
int cleft() const
index of left child
Definition: tree.h:26
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:258
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
int parent() const
get parent of the node
Definition: tree.h:71
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:358
Model()
disable copy; use default move
Definition: tree.h:369
Operator
comparison operators
Definition: base.h:23