Treelite
builder.cc
Go to the documentation of this file.
1 
8 #include <dmlc/registry.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <memory>
12 #include <queue>
13 
14 /* data structures with underscore prefixes are internal use only and
15  do not have external linkage */
16 namespace {
17 
18 struct _Node {
19  enum class _Status : int8_t {
20  kEmpty, kNumericalTest, kCategoricalTest, kLeaf
21  };
22  union _Info {
23  treelite::tl_float leaf_value; // for leaf nodes
24  treelite::tl_float threshold; // for non-leaf nodes
25  };
26  /*
27  * leaf vector: only used for random forests with multi-class classification
28  */
29  std::vector<treelite::tl_float> leaf_vector;
30  _Status status;
31  /* pointers to parent, left and right children */
32  _Node* parent;
33  _Node* left_child;
34  _Node* right_child;
35  // split feature index
36  unsigned feature_id;
37  // default direction for missing values
38  bool default_left;
39  // extra info: leaf value or threshold
40  _Info info;
41  // (for numerical split)
42  // operator to use for expression of form [fval] OP [threshold]
43  // If the expression evaluates to true, take the left child;
44  // otherwise, take the right child.
46  // (for categorical split)
47  // list of all categories that belong to the left child node.
48  // All others not in the list belong to the right child node.
49  // Categories are integers ranging from 0 to (n-1), where n is the number of
50  // categories in that particular feature. Let's assume n <= 64.
51  std::vector<uint32_t> left_categories;
52 
53  inline _Node()
54  : status(_Status::kEmpty),
55  parent(nullptr), left_child(nullptr), right_child(nullptr) {}
56 };
57 
58 struct _Tree {
59  _Node* root;
60  std::unordered_map<int, std::unique_ptr<_Node>> nodes;
61  inline _Tree() : root(nullptr), nodes() {}
62 };
63 
64 } // anonymous namespace
65 
66 namespace treelite {
67 namespace frontend {
68 
69 DMLC_REGISTRY_FILE_TAG(builder);
70 
72  _Tree tree;
73  inline TreeBuilderImpl() : tree() {}
74 };
75 
77  std::vector<TreeBuilder> trees;
78  int num_feature;
79  int num_output_group;
80  bool random_forest_flag;
81  std::vector<std::pair<std::string, std::string>> cfg;
82  inline ModelBuilderImpl(int num_feature, int num_output_group,
83  bool random_forest_flag)
84  : trees(), num_feature(num_feature),
85  num_output_group(num_output_group),
86  random_forest_flag(random_forest_flag), cfg() {
87  CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive";
88  CHECK_GT(num_output_group, 0)
89  << "ModelBuilder: num_output_group must be positive";
90  }
91 };
92 
93 TreeBuilder::TreeBuilder()
94  : pimpl(new TreeBuilderImpl()), ensemble_id(nullptr) {}
95 TreeBuilder::~TreeBuilder() {}
96 
97 void
98 TreeBuilder::CreateNode(int node_key) {
99  auto& nodes = pimpl->tree.nodes;
100  CHECK_EQ(nodes.count(node_key), 0) << "CreateNode: nodes with duplicate keys are not allowed";
101  nodes[node_key].reset(new _Node());
102 }
103 
104 void
105 TreeBuilder::DeleteNode(int node_key) {
106  auto& tree = pimpl->tree;
107  auto& nodes = tree.nodes;
108  CHECK_GT(nodes.count(node_key), 0) << "DeleteNode: no node found with node_key";
109  _Node* node = nodes[node_key].get();
110  if (tree.root == node) { // deleting root
111  tree.root = nullptr;
112  }
113  if (node->left_child != nullptr) { // deleting a parent
114  node->left_child->parent = nullptr;
115  }
116  if (node->right_child != nullptr) { // deleting a parent
117  node->right_child->parent = nullptr;
118  }
119  if (node == tree.root) { // deleting root
120  tree.root = nullptr;
121  nodes.clear();
122  } else {
123  nodes.erase(node_key);
124  }
125 }
126 
127 void
128 TreeBuilder::SetRootNode(int node_key) {
129  auto& tree = pimpl->tree;
130  auto& nodes = tree.nodes;
131  CHECK_GT(nodes.count(node_key), 0) << "SetRootNode: no node found with node_key";
132  _Node* node = nodes[node_key].get();
133  CHECK(!node->parent) << "SetRootNode: a root node cannot have a parent";
134  tree.root = node;
135 }
136 
137 void
138 TreeBuilder::SetNumericalTestNode(int node_key,
139  unsigned feature_id,
140  const char* opname, tl_float threshold,
141  bool default_left, int left_child_key,
142  int right_child_key) {
143  CHECK_GT(optable.count(opname), 0) << "No operator \"" << opname << "\" exists";
144  Operator op = optable.at(opname);
145  SetNumericalTestNode(node_key, feature_id, op, threshold, default_left,
146  left_child_key, right_child_key);
147 }
148 
149 void
150 TreeBuilder::SetNumericalTestNode(int node_key,
151  unsigned feature_id,
152  Operator op, tl_float threshold,
153  bool default_left, int left_child_key,
154  int right_child_key) {
155  auto& tree = pimpl->tree;
156  auto& nodes = tree.nodes;
157  CHECK_GT(nodes.count(node_key), 0) << "SetNumericalTestNode: no node found with node_key";
158  CHECK_GT(nodes.count(left_child_key), 0)
159  << "SetNumericalTestNode: no node found with left_child_key";
160  CHECK_GT(nodes.count(right_child_key), 0)
161  << "SetNumericalTestNode: no node found with right_child_key";
162  _Node* node = nodes[node_key].get();
163  _Node* left_child = nodes[left_child_key].get();
164  _Node* right_child = nodes[right_child_key].get();
165  CHECK(node->status == _Node::_Status::kEmpty)
166  << "SetNumericalTestNode: cannot modify a non-empty node";
167  CHECK(!left_child->parent)
168  << "SetNumericalTestNode: node designated as left child already has a parent";
169  CHECK(!right_child->parent)
170  << "SetNumericalTestNode: node designated as right child already has a parent";
171  CHECK(left_child != tree.root && right_child != tree.root)
172  << "SetNumericalTestNode: the root node cannot be a child";
173  node->status = _Node::_Status::kNumericalTest;
174  node->left_child = nodes[left_child_key].get();
175  node->left_child->parent = node;
176  node->right_child = nodes[right_child_key].get();
177  node->right_child->parent = node;
178  node->feature_id = feature_id;
179  node->default_left = default_left;
180  node->info.threshold = threshold;
181  node->op = op;
182 }
183 
184 void
185 TreeBuilder::SetCategoricalTestNode(int node_key,
186  unsigned feature_id,
187  const std::vector<uint32_t>& left_categories,
188  bool default_left, int left_child_key,
189  int right_child_key) {
190  auto &tree = pimpl->tree;
191  auto &nodes = tree.nodes;
192  CHECK_GT(nodes.count(node_key), 0) << "SetCategoricalTestNode: no node found with node_key";
193  CHECK_GT(nodes.count(left_child_key), 0)
194  << "SetCategoricalTestNode: no node found with left_child_key";
195  CHECK_GT(nodes.count(right_child_key), 0)
196  << "SetCategoricalTestNode: no node found with right_child_key";
197  _Node* node = nodes[node_key].get();
198  _Node* left_child = nodes[left_child_key].get();
199  _Node* right_child = nodes[right_child_key].get();
200  CHECK(node->status == _Node::_Status::kEmpty)
201  << "SetCategoricalTestNode: cannot modify a non-empty node";
202  CHECK(!left_child->parent)
203  << "SetCategoricalTestNode: node designated as left child already has a parent";
204  CHECK(!right_child->parent)
205  << "SetCategoricalTestNode: node designated as right child already has a parent";
206  CHECK(left_child != tree.root && right_child != tree.root)
207  << "SetCategoricalTestNode: the root node cannot be a child";
208  node->status = _Node::_Status::kCategoricalTest;
209  node->left_child = nodes[left_child_key].get();
210  node->left_child->parent = node;
211  node->right_child = nodes[right_child_key].get();
212  node->right_child->parent = node;
213  node->feature_id = feature_id;
214  node->default_left = default_left;
215  node->left_categories = left_categories;
216 }
217 
218 void
219 TreeBuilder::SetLeafNode(int node_key, tl_float leaf_value) {
220  auto& tree = pimpl->tree;
221  auto& nodes = tree.nodes;
222  CHECK_GT(nodes.count(node_key), 0) << "SetLeafNode: no node found with node_key";
223  _Node* node = nodes[node_key].get();
224  CHECK(node->status == _Node::_Status::kEmpty) << "SetLeafNode: cannot modify a non-empty node";
225  node->status = _Node::_Status::kLeaf;
226  node->info.leaf_value = leaf_value;
227 }
228 
229 void
230 TreeBuilder::SetLeafVectorNode(int node_key, const std::vector<tl_float>& leaf_vector) {
231  auto& tree = pimpl->tree;
232  auto& nodes = tree.nodes;
233  CHECK_GT(nodes.count(node_key), 0) << "SetLeafVectorNode: no node found with node_key";
234  _Node* node = nodes[node_key].get();
235  CHECK(node->status == _Node::_Status::kEmpty)
236  << "SetLeafVectorNode: cannot modify a non-empty node";
237  node->status = _Node::_Status::kLeaf;
238  node->leaf_vector = leaf_vector;
239 }
240 
241 ModelBuilder::ModelBuilder(int num_feature, int num_output_group,
242  bool random_forest_flag)
243  : pimpl(new ModelBuilderImpl(num_feature,
244  num_output_group,
245  random_forest_flag)) {}
246 ModelBuilder::~ModelBuilder() = default;
247 
248 void
249 ModelBuilder::SetModelParam(const char* name, const char* value) {
250  pimpl->cfg.emplace_back(name, value);
251 }
252 
253 int
254 ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) {
255  if (tree_builder == nullptr) {
256  LOG(FATAL) << "InsertTree: not a valid tree builder";
257  return -1;
258  }
259  if (tree_builder->ensemble_id != nullptr) {
260  LOG(FATAL) << "InsertTree: tree is already part of another ensemble";
261  return -1;
262  }
263 
264  // check bounds for feature indices
265  for (const auto& kv : tree_builder->pimpl->tree.nodes) {
266  const _Node::_Status status = kv.second->status;
267  if (status == _Node::_Status::kNumericalTest ||
268  status == _Node::_Status::kCategoricalTest) {
269  const int fid = static_cast<int>(kv.second->feature_id);
270  if (fid < 0 || fid >= pimpl->num_feature) {
271  LOG(FATAL) << "InsertTree: tree has an invalid split at node "
272  << kv.first << ": feature id " << kv.second->feature_id << " is out of bound";
273  return -1;
274  }
275  }
276  }
277 
278  // perform insertion
279  auto& trees = pimpl->trees;
280  if (index == -1) {
281  trees.push_back(std::move(*tree_builder));
282  tree_builder->ensemble_id = static_cast<void*>(this);
283  return static_cast<int>(trees.size());
284  } else {
285  if (static_cast<size_t>(index) <= trees.size()) {
286  trees.insert(trees.begin() + index, std::move(*tree_builder));
287  tree_builder->ensemble_id = static_cast<void*>(this);
288  return index;
289  } else {
290  LOG(FATAL) << "CreateTree: index out of bound";
291  return -1;
292  }
293  }
294 }
295 
298  return &pimpl->trees.at(index);
299 }
300 
301 const TreeBuilder*
302 ModelBuilder::GetTree(int index) const {
303  return &pimpl->trees.at(index);
304 }
305 
306 void
308  auto& trees = pimpl->trees;
309  CHECK_LT(static_cast<size_t>(index), trees.size()) << "DeleteTree: index out of bound";
310  trees.erase(trees.begin() + index);
311 }
312 
313 void
315  Model model;
316  model.num_feature = pimpl->num_feature;
317  model.num_output_group = pimpl->num_output_group;
318  model.random_forest_flag = pimpl->random_forest_flag;
319  // extra parameters
320  InitParamAndCheck(&model.param, pimpl->cfg);
321 
322  // flag to check consistent use of leaf vector
323  // 0: no leaf should use leaf vector
324  // 1: every leaf should use leaf vector
325  // -1: indeterminate
326  int8_t flag_leaf_vector = -1;
327 
328  for (const auto& _tree_builder : pimpl->trees) {
329  const auto& _tree = _tree_builder.pimpl->tree;
330  CHECK(_tree.root) << "CommitModel: a tree has no root node";
331  CHECK(_tree.root->status != _Node::_Status::kEmpty)
332  << "SetRootNode: cannot set an empty node as root";
333  model.trees.emplace_back();
334  Tree& tree = model.trees.back();
335  tree.Init();
336 
337  // assign node ID's so that a breadth-wise traversal would yield
338  // the monotonic sequence 0, 1, 2, ...
339  std::queue<std::pair<const _Node*, int>> Q; // (internal pointer, ID)
340  Q.push({_tree.root, 0}); // assign 0 to root
341  while (!Q.empty()) {
342  const _Node* node;
343  int nid;
344  std::tie(node, nid) = Q.front();
345  Q.pop();
346  CHECK(node->status != _Node::_Status::kEmpty)
347  << "CommitModel: encountered an empty node in the middle of a tree";
348  if (node->status == _Node::_Status::kNumericalTest) {
349  CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
350  CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
351  CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent";
352  CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent";
353  tree.AddChilds(nid);
354  tree.SetNumericalSplit(nid, node->feature_id, node->info.threshold,
355  node->default_left, node->op);
356  Q.push({node->left_child, tree.LeftChild(nid)});
357  Q.push({node->right_child, tree.RightChild(nid)});
358  } else if (node->status == _Node::_Status::kCategoricalTest) {
359  CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
360  CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
361  CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent";
362  CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent";
363  tree.AddChilds(nid);
364  tree.SetCategoricalSplit(nid, node->feature_id, node->default_left,
365  false, node->left_categories);
366  Q.push({node->left_child, tree.LeftChild(nid)});
367  Q.push({node->right_child, tree.RightChild(nid)});
368  } else { // leaf node
369  CHECK(node->left_child == nullptr && node->right_child == nullptr)
370  << "CommitModel: a leaf node cannot have children";
371  if (!node->leaf_vector.empty()) { // leaf vector exists
372  CHECK_NE(flag_leaf_vector, 0)
373  << "CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, "
374  << "*every* leaf node must use a leaf vector";
375  flag_leaf_vector = 1; // now every leaf must use leaf vector
376  CHECK_EQ(node->leaf_vector.size(), model.num_output_group)
377  << "CommitModel: The length of leaf vector must be identical to the number of output "
378  << "groups";
379  tree.SetLeafVector(nid, node->leaf_vector);
380  } else { // ordinary leaf
381  CHECK_NE(flag_leaf_vector, 1)
382  << "CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf "
383  << "vector, *no other* leaf node can use a leaf vector";
384  flag_leaf_vector = 0; // now no leaf can use leaf vector
385  tree.SetLeaf(nid, node->info.leaf_value);
386  }
387  }
388  }
389  }
390  if (flag_leaf_vector == 0) {
391  if (model.num_output_group > 1) {
392  // multi-class classification with gradient boosted trees
393  CHECK(!model.random_forest_flag)
394  << "To use a random forest for multi-class classification, each leaf node must output a "
395  << "leaf vector specifying a probability distribution";
396  CHECK_EQ(pimpl->trees.size() % model.num_output_group, 0)
397  << "For multi-class classifiers with gradient boosted trees, the number of trees must be "
398  << "evenly divisible by the number of output groups";
399  }
400  } else if (flag_leaf_vector == 1) {
401  // multi-class classification with a random forest
402  CHECK(model.random_forest_flag)
403  << "In multi-class classifiers with gradient boosted trees, each leaf node must output a "
404  << "single floating-point value.";
405  } else {
406  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
407  }
408  *out_model = std::move(model);
409 }
410 
411 } // namespace frontend
412 } // namespace treelite
void CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
Definition: builder.cc:314
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:419
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:476
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
tree builder class
Definition: frontend.h:70
std::vector< Tree > trees
member trees
Definition: tree.h:411
ModelParam param
extra parameters
Definition: tree.h:424
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:249
in-memory representation of a decision tree
Definition: tree.h:80
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:297
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
Definition: builder.cc:254
void SetLeafVector(int nid, const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree_impl.h:691
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree_impl.h:649
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:422
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:307
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
Definition: tree_impl.h:682
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:635
int LeftChild(int nid) const
Getters.
Definition: tree_impl.h:524
int RightChild(int nid) const
index of the node's right child
Definition: tree_impl.h:529
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:488
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416
Operator
comparison operators
Definition: base.h:24