treelite
builder.cc
Go to the documentation of this file.
1 
8 #include <treelite/frontend.h>
9 #include <treelite/tree.h>
10 #include <dmlc/registry.h>
11 #include <memory>
12 #include <queue>
13 #include "../c_api/c_api_error.h"
14 
/*!
 * \brief Bail out of the enclosing function when a condition does not hold.
 *
 * If `x` evaluates to false, `msg` is stored through TreeliteAPISetLastError(),
 * written to the dmlc log stream, and the enclosing function returns `false`.
 * The macro may therefore only be used inside functions whose return type
 * accepts `false`.
 *
 * The body is wrapped in `do { ... } while (0)` so that the macro expands to
 * exactly one statement; this keeps it safe inside unbraced if/else chains.
 * Invocations must be terminated with a semicolon.
 *
 * \param x   condition that is expected to be true
 * \param msg C-string describing the failure (evaluated twice on failure)
 */
#define CHECK_EARLY_RETURN(x, msg)                               \
  do {                                                           \
    if (!(x)) {                                                  \
      TreeliteAPISetLastError(msg);                              \
      dmlc::LogMessage(__FILE__, __LINE__).stream() << msg;      \
      return false;                                              \
    }                                                            \
  } while (0)
21 
22 /* data structures with underscore prefixes are internal use only and
23  do not have external linkage */
24 namespace {
25 
26 struct _Node {
27  enum class _Status : int8_t {
28  kEmpty, kNumericalTest, kCategoricalTest, kLeaf
29  };
30  union _Info {
31  treelite::tl_float leaf_value; // for leaf nodes
32  treelite::tl_float threshold; // for non-leaf nodes
33  };
34  /*
35  * leaf vector: only used for random forests with multi-class classification
36  */
37  std::vector<treelite::tl_float> leaf_vector;
38  _Status status;
39  /* pointers to parent, left and right children */
40  _Node* parent;
41  _Node* left_child;
42  _Node* right_child;
43  // split feature index
44  unsigned feature_id;
45  // default direction for missing values
46  bool default_left;
47  // extra info: leaf value or threshold
48  _Info info;
49  // (for numerical split)
50  // operator to use for expression of form [fval] OP [threshold]
51  // If the expression evaluates to true, take the left child;
52  // otherwise, take the right child.
54  // (for categorical split)
55  // list of all categories that belong to the left child node.
56  // All others not in the list belong to the right child node.
57  // Categories are integers ranging from 0 to (n-1), where n is the number of
58  // categories in that particular feature. Let's assume n <= 64.
59  std::vector<uint32_t> left_categories;
60 
61  inline _Node()
62  : status(_Status::kEmpty),
63  parent(nullptr), left_child(nullptr), right_child(nullptr) {}
64 };
65 
66 struct _Tree {
67  _Node* root;
68  std::unordered_map<int, std::shared_ptr<_Node>> nodes;
69  inline _Tree() : root(nullptr), nodes() {}
70 };
71 
72 } // namespace anonymous
73 
74 namespace treelite {
75 namespace frontend {
76 
77 DMLC_REGISTRY_FILE_TAG(builder);
78 
80  _Tree tree;
81  inline TreeBuilderImpl() : tree() {}
82 };
83 
85  std::vector<TreeBuilder> trees;
86  int num_feature;
87  int num_output_group;
88  bool random_forest_flag;
89  std::vector<std::pair<std::string, std::string>> cfg;
90  inline ModelBuilderImpl(int num_feature, int num_output_group,
91  bool random_forest_flag)
92  : trees(), num_feature(num_feature),
93  num_output_group(num_output_group),
94  random_forest_flag(random_forest_flag), cfg() {
95  CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive";
96  CHECK_GT(num_output_group, 0)
97  << "ModelBuilder: num_output_group must be positive";
98  }
99 };
100 
101 TreeBuilder::TreeBuilder()
102  : pimpl(common::make_unique<TreeBuilderImpl>()), ensemble_id(nullptr) {}
103 TreeBuilder::~TreeBuilder() {}
104 
105 bool
106 TreeBuilder::CreateNode(int node_key) {
107  auto& nodes = pimpl->tree.nodes;
108  CHECK_EARLY_RETURN(nodes.count(node_key) == 0,
109  "CreateNode: nodes with duplicate keys are not allowed");
110  nodes[node_key] = common::make_unique<_Node>();
111  return true;
112 }
113 
114 bool
115 TreeBuilder::DeleteNode(int node_key) {
116  auto& tree = pimpl->tree;
117  auto& nodes = tree.nodes;
118  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
119  "DeleteNode: no node found with node_key");
120  _Node* node = nodes[node_key].get();
121  if (tree.root == node) { // deleting root
122  tree.root = nullptr;
123  }
124  if (node->left_child != nullptr) { // deleting a parent
125  node->left_child->parent = nullptr;
126  }
127  if (node->right_child != nullptr) { // deleting a parent
128  node->right_child->parent = nullptr;
129  }
130  nodes.erase(node_key);
131  return true;
132 }
133 
134 bool
135 TreeBuilder::SetRootNode(int node_key) {
136  auto& tree = pimpl->tree;
137  auto& nodes = tree.nodes;
138  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
139  "SetRootNode: no node found with node_key");
140  _Node* node = nodes[node_key].get();
141  CHECK_EARLY_RETURN(node->status != _Node::_Status::kLeaf,
142  "SetRootNode: cannot set a leaf node as root");
143  CHECK_EARLY_RETURN(node->parent == nullptr,
144  "SetRootNode: a root node cannot have a parent");
145  tree.root = node;
146  return true;
147 }
148 
149 bool
150 TreeBuilder::SetNumericalTestNode(int node_key,
151  unsigned feature_id,
152  Operator op, tl_float threshold,
153  bool default_left, int left_child_key,
154  int right_child_key) {
155  auto& tree = pimpl->tree;
156  auto& nodes = tree.nodes;
157  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
158  "SetNumericalTestNode: no node found with node_key");
159  CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
160  "SetNumericalTestNode: no node found with left_child_key");
161  CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
162  "SetNumericalTestNode: no node found with right_child_key");
163  _Node* node = nodes[node_key].get();
164  _Node* left_child = nodes[left_child_key].get();
165  _Node* right_child = nodes[right_child_key].get();
166  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
167  "SetNumericalTestNode: cannot modify a non-empty node");
168  CHECK_EARLY_RETURN(left_child->parent == nullptr,
169  "SetNumericalTestNode: node designated as left child already has "
170  "a parent");
171  CHECK_EARLY_RETURN(right_child->parent == nullptr,
172  "SetNumericalTestNode: node designated as right child already has "
173  "a parent");
174  CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
175  "SetNumericalTestNode: the root node cannot be a child");
176  node->status = _Node::_Status::kNumericalTest;
177  node->left_child = nodes[left_child_key].get();
178  node->left_child->parent = node;
179  node->right_child = nodes[right_child_key].get();
180  node->right_child->parent = node;
181  node->feature_id = feature_id;
182  node->default_left = default_left;
183  node->info.threshold = threshold;
184  node->op = op;
185  return true;
186 }
187 
188 bool
189 TreeBuilder::SetCategoricalTestNode(int node_key,
190  unsigned feature_id,
191  const std::vector<uint32_t>& left_categories,
192  bool default_left, int left_child_key,
193  int right_child_key) {
194  auto& tree = pimpl->tree;
195  auto& nodes = tree.nodes;
196  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
197  "SetCategoricalTestNode: no node found with node_key");
198  CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
199  "SetCategoricalTestNode: no node found with left_child_key");
200  CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
201  "SetCategoricalTestNode: no node found with right_child_key");
202  _Node* node = nodes[node_key].get();
203  _Node* left_child = nodes[left_child_key].get();
204  _Node* right_child = nodes[right_child_key].get();
205  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
206  "SetCategoricalTestNode: cannot modify a non-empty node");
207  CHECK_EARLY_RETURN(left_child->parent == nullptr,
208  "SetCategoricalTestNode: node designated as left child already "
209  "has a parent");
210  CHECK_EARLY_RETURN(right_child->parent == nullptr,
211  "SetCategoricalTestNode: node designated as right child already "
212  "has a parent");
213  CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
214  "SetCategoricalTestNode: the root node cannot be a child");
215  node->status = _Node::_Status::kCategoricalTest;
216  node->left_child = nodes[left_child_key].get();
217  node->left_child->parent = node;
218  node->right_child = nodes[right_child_key].get();
219  node->right_child->parent = node;
220  node->feature_id = feature_id;
221  node->default_left = default_left;
222  node->left_categories = left_categories;
223  return true;
224 }
225 
226 bool
227 TreeBuilder::SetLeafNode(int node_key, tl_float leaf_value) {
228  auto& tree = pimpl->tree;
229  auto& nodes = tree.nodes;
230  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
231  "SetLeafNode: no node found with node_key");
232  _Node* node = nodes[node_key].get();
233  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
234  "SetLeafNode: cannot modify a non-empty node");
235  node->status = _Node::_Status::kLeaf;
236  node->info.leaf_value = leaf_value;
237  return true;
238 }
239 
240 bool
241 TreeBuilder::SetLeafVectorNode(int node_key,
242  const std::vector<tl_float>& leaf_vector) {
243  auto& tree = pimpl->tree;
244  auto& nodes = tree.nodes;
245  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
246  "SetLeafVectorNode: no node found with node_key");
247  _Node* node = nodes[node_key].get();
248  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
249  "SetLeafVectorNode: cannot modify a non-empty node");
250  node->status = _Node::_Status::kLeaf;
251  node->leaf_vector = leaf_vector;
252  return true;
253 }
254 
255 ModelBuilder::ModelBuilder(int num_feature, int num_output_group,
256  bool random_forest_flag)
257  : pimpl(common::make_unique<ModelBuilderImpl>(num_feature,
258  num_output_group,
259  random_forest_flag)) {}
260 ModelBuilder::~ModelBuilder() {}
261 
262 void
263 ModelBuilder::SetModelParam(const char* name, const char* value) {
264  pimpl->cfg.emplace_back(name, value);
265 }
266 
267 int
268 ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) {
269  if (tree_builder == nullptr) {
270  const char* msg = "InsertTree: not a valid tree builder";
271  LOG(INFO) << msg;
273  return -1; // fail
274  }
275  if (tree_builder->ensemble_id != nullptr) {
276  const char* msg = "InsertTree: tree is already part of another ensemble";
277  LOG(INFO) << msg;
279  return -1; // fail
280  }
281 
282  // check bounds for feature indices
283  for (const auto& kv : tree_builder->pimpl->tree.nodes) {
284  const _Node::_Status status = kv.second->status;
285  if (status == _Node::_Status::kNumericalTest ||
286  status == _Node::_Status::kCategoricalTest) {
287  const int fid = static_cast<int>(kv.second->feature_id);
288  if (fid < 0 || fid >= pimpl->num_feature) {
289  std::ostringstream oss;
290  oss << "InsertTree: tree has an invalid split at node "
291  << kv.first << ": feature id " << kv.second->feature_id
292  << " is out of bound";
293  const std::string str = oss.str();
294  const char* msg = str.c_str();
295  LOG(INFO) << msg;
297  return -1; // fail
298  }
299  }
300  }
301 
302  // perform insertion
303  auto& trees = pimpl->trees;
304  if (index == -1) {
305  trees.push_back(std::move(*tree_builder));
306  tree_builder->ensemble_id = static_cast<void*>(this);
307  return static_cast<int>(trees.size());
308  } else {
309  if (static_cast<size_t>(index) <= trees.size()) {
310  trees.insert(trees.begin() + index, std::move(*tree_builder));
311  tree_builder->ensemble_id = static_cast<void*>(this);
312  return index;
313  } else {
314  LOG(INFO) << "CreateTree: index out of bound";
315  return -1; // fail
316  }
317  }
318 }
319 
322  return pimpl->trees[index];
323 }
324 
325 const TreeBuilder&
326 ModelBuilder::GetTree(int index) const {
327  return pimpl->trees[index];
328 }
329 
330 bool
332  auto& trees = pimpl->trees;
333  CHECK_EARLY_RETURN(static_cast<size_t>(index) < trees.size(),
334  "DeleteTree: index out of bound");
335  trees.erase(trees.begin() + index);
336  return true;
337 }
338 
339 bool
341  Model model;
342  model.num_feature = pimpl->num_feature;
343  model.num_output_group = pimpl->num_output_group;
344  model.random_forest_flag = pimpl->random_forest_flag;
345  // extra parameters
346  InitParamAndCheck(&model.param, pimpl->cfg);
347 
348  // flag to check consistent use of leaf vector
349  // 0: no leaf should use leaf vector
350  // 1: every leaf should use leaf vector
351  // -1: indeterminate
352  int8_t flag_leaf_vector = -1;
353 
354  for (const auto& _tree_builder : pimpl->trees) {
355  const auto& _tree = _tree_builder.pimpl->tree;
356  CHECK_EARLY_RETURN(_tree.root != nullptr,
357  "CommitModel: a tree has no root node");
358  model.trees.emplace_back();
359  Tree& tree = model.trees.back();
360  tree.Init();
361 
362  // assign node ID's so that a breadth-wise traversal would yield
363  // the monotonic sequence 0, 1, 2, ...
364  std::queue<std::pair<const _Node*, int>> Q; // (internal pointer, ID)
365  Q.push({_tree.root, 0}); // assign 0 to root
366  while (!Q.empty()) {
367  const _Node* node;
368  int nid;
369  std::tie(node, nid) = Q.front(); Q.pop();
370  CHECK_EARLY_RETURN(node->status != _Node::_Status::kEmpty,
371  "CommitModel: encountered an empty node in the middle of a tree");
372  if (node->status == _Node::_Status::kNumericalTest) {
373  CHECK_EARLY_RETURN(node->left_child != nullptr,
374  "CommitModel: a test node lacks a left child");
375  CHECK_EARLY_RETURN(node->right_child != nullptr,
376  "CommitModel: a test node lacks a right child");
377  CHECK_EARLY_RETURN(node->left_child->parent == node,
378  "CommitModel: left child has wrong parent");
379  CHECK_EARLY_RETURN(node->right_child->parent == node,
380  "CommitModel: right child has wrong parent");
381  tree.AddChilds(nid);
382  tree[nid].set_numerical_split(node->feature_id, node->info.threshold,
383  node->default_left, node->op);
384  Q.push({node->left_child, tree[nid].cleft()});
385  Q.push({node->right_child, tree[nid].cright()});
386  } else if (node->status == _Node::_Status::kCategoricalTest) {
387  CHECK_EARLY_RETURN(node->left_child != nullptr,
388  "CommitModel: a test node lacks a left child");
389  CHECK_EARLY_RETURN(node->right_child != nullptr,
390  "CommitModel: a test node lacks a right child");
391  CHECK_EARLY_RETURN(node->left_child->parent == node,
392  "CommitModel: left child has wrong parent");
393  CHECK_EARLY_RETURN(node->right_child->parent == node,
394  "CommitModel: right child has wrong parent");
395  tree.AddChilds(nid);
396  tree[nid].set_categorical_split(node->feature_id, node->default_left,
397  node->left_categories);
398  Q.push({node->left_child, tree[nid].cleft()});
399  Q.push({node->right_child, tree[nid].cright()});
400  } else { // leaf node
401  CHECK_EARLY_RETURN(node->left_child == nullptr
402  && node->right_child == nullptr,
403  "CommitModel: a leaf node cannot have children");
404  if (!node->leaf_vector.empty()) { // leaf vector exists
405  CHECK_EARLY_RETURN(flag_leaf_vector != 0,
406  "CommitModel: Inconsistent use of leaf vector: "
407  "if one leaf node uses a leaf vector, "
408  "*every* leaf node must use a leaf vector");
409  flag_leaf_vector = 1; // now every leaf must use leaf vector
410  CHECK_EARLY_RETURN(node->leaf_vector.size() == model.num_output_group,
411  "CommitModel: The length of leaf vector must be "
412  "identical to the number of output groups");
413  tree[nid].set_leaf_vector(node->leaf_vector);
414  } else { // ordinary leaf
415  CHECK_EARLY_RETURN(flag_leaf_vector != 1,
416  "CommitModel: Inconsistent use of leaf vector: "
417  "if one leaf node does not use a leaf vector, "
418  "*no other* leaf node can use a leaf vector");
419  flag_leaf_vector = 0; // now no leaf can use leaf vector
420  tree[nid].set_leaf(node->info.leaf_value);
421  }
422  }
423  }
424  }
425  if (flag_leaf_vector == 0) {
426  if (model.num_output_group > 1) {
427  // multiclass classification with gradient boosted trees
428  CHECK_EARLY_RETURN(!model.random_forest_flag,
429  "To use a random forest for multi-class classification, each leaf "
430  "node must output a leaf vector specifying a probability "
431  "distribution");
432  CHECK_EARLY_RETURN(pimpl->trees.size() % model.num_output_group == 0,
433  "For multi-class classifiers with gradient boosted trees, the number "
434  "of trees must be evenly divisible by the number of output groups");
435  }
436  } else if (flag_leaf_vector == 1) {
437  // multiclass classification with a random forest
438  CHECK_EARLY_RETURN(model.random_forest_flag,
439  "In multi-class classifiers with gradient boosted trees, each leaf "
440  "node must output a single floating-point value.");
441  } else {
442  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
443  }
444  *out_model = std::move(model);
445  return true;
446 }
447 
448 } // namespace frontend
449 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:235
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
tree builder class
Definition: frontend.h:62
std::vector< Tree > trees
member trees
Definition: tree.h:353
ModelParam param
extra parameters
Definition: tree.h:366
model structure for tree
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:263
in-memory representation of a decision tree
Definition: tree.h:19
bool DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:331
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
Definition: builder.cc:268
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:364
void TreeliteAPISetLastError(const char *msg)
Set the last error message needed by C API.
Definition: c_api_error.cc:20
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:245
bool CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
Definition: builder.cc:340
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:321
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
Definition: tree.h:358
Operator
comparison operators
Definition: base.h:23