treelite
builder.cc
Go to the documentation of this file.
1 
8 #include <dmlc/registry.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <memory>
12 #include <queue>
13 #include "../c_api/c_api_error.h"
14 
/*
 * If condition `x` is false: record `msg` as the last API error, log it, and
 * make the enclosing function return false.
 * The body is wrapped in do { ... } while (0) so the macro behaves as a single
 * statement: call sites need a trailing `;`, and it composes safely with an
 * unbraced if/else around it (no dangling-else capture).
 * Note: `msg` is evaluated twice; pass a side-effect-free expression.
 */
#define CHECK_EARLY_RETURN(x, msg)                            \
  do {                                                        \
    if (!(x)) {                                               \
      TreeliteAPISetLastError(msg);                           \
      dmlc::LogMessage(__FILE__, __LINE__).stream() << msg;   \
      return false;                                           \
    }                                                         \
  } while (0)
21 
22 /* data structures with underscore prefixes are internal use only and
23  do not have external linkage */
24 namespace {
25 
26 struct _Node {
27  enum class _Status : int8_t {
28  kEmpty, kNumericalTest, kCategoricalTest, kLeaf
29  };
30  union _Info {
31  treelite::tl_float leaf_value; // for leaf nodes
32  treelite::tl_float threshold; // for non-leaf nodes
33  };
34  /*
35  * leaf vector: only used for random forests with multi-class classification
36  */
37  std::vector<treelite::tl_float> leaf_vector;
38  _Status status;
39  /* pointers to parent, left and right children */
40  _Node* parent;
41  _Node* left_child;
42  _Node* right_child;
43  // split feature index
44  unsigned feature_id;
45  // default direction for missing values
46  bool default_left;
47  // extra info: leaf value or threshold
48  _Info info;
49  // (for numerical split)
50  // operator to use for expression of form [fval] OP [threshold]
51  // If the expression evaluates to true, take the left child;
52  // otherwise, take the right child.
54  // (for categorical split)
55  // list of all categories that belong to the left child node.
56  // All others not in the list belong to the right child node.
57  // Categories are integers ranging from 0 to (n-1), where n is the number of
58  // categories in that particular feature. Let's assume n <= 64.
59  std::vector<uint32_t> left_categories;
60 
61  inline _Node()
62  : status(_Status::kEmpty),
63  parent(nullptr), left_child(nullptr), right_child(nullptr) {}
64 };
65 
66 struct _Tree {
67  _Node* root;
68  std::unordered_map<int, std::shared_ptr<_Node>> nodes;
69  inline _Tree() : root(nullptr), nodes() {}
70 };
71 
72 } // anonymous namespace
73 
74 namespace treelite {
75 namespace frontend {
76 
// Declares a dmlc registry file tag named `builder` for this translation unit
// (presumably so that a matching DMLC_REGISTRY_LINK_TAG(builder) elsewhere can
// force this file to be linked; confirm against dmlc/registry.h).
DMLC_REGISTRY_FILE_TAG(builder);
78 
80  _Tree tree;
81  inline TreeBuilderImpl() : tree() {}
82 };
83 
85  std::vector<TreeBuilder> trees;
86  int num_feature;
87  int num_output_group;
88  bool random_forest_flag;
89  std::vector<std::pair<std::string, std::string>> cfg;
90  inline ModelBuilderImpl(int num_feature, int num_output_group,
91  bool random_forest_flag)
92  : trees(), num_feature(num_feature),
93  num_output_group(num_output_group),
94  random_forest_flag(random_forest_flag), cfg() {
95  CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive";
96  CHECK_GT(num_output_group, 0)
97  << "ModelBuilder: num_output_group must be positive";
98  }
99 };
100 
101 TreeBuilder::TreeBuilder()
102  : pimpl(common::make_unique<TreeBuilderImpl>()), ensemble_id(nullptr) {}
103 TreeBuilder::~TreeBuilder() {}
104 
105 bool
106 TreeBuilder::CreateNode(int node_key) {
107  auto& nodes = pimpl->tree.nodes;
108  CHECK_EARLY_RETURN(nodes.count(node_key) == 0,
109  "CreateNode: nodes with duplicate keys are not allowed");
110  nodes[node_key] = common::make_unique<_Node>();
111  return true;
112 }
113 
114 bool
115 TreeBuilder::DeleteNode(int node_key) {
116  auto& tree = pimpl->tree;
117  auto& nodes = tree.nodes;
118  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
119  "DeleteNode: no node found with node_key");
120  _Node* node = nodes[node_key].get();
121  if (tree.root == node) { // deleting root
122  tree.root = nullptr;
123  }
124  if (node->left_child != nullptr) { // deleting a parent
125  node->left_child->parent = nullptr;
126  }
127  if (node->right_child != nullptr) { // deleting a parent
128  node->right_child->parent = nullptr;
129  }
130  nodes.erase(node_key);
131  return true;
132 }
133 
134 bool
135 TreeBuilder::SetRootNode(int node_key) {
136  auto& tree = pimpl->tree;
137  auto& nodes = tree.nodes;
138  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
139  "SetRootNode: no node found with node_key");
140  _Node* node = nodes[node_key].get();
141  CHECK_EARLY_RETURN(node->parent == nullptr,
142  "SetRootNode: a root node cannot have a parent");
143  tree.root = node;
144  return true;
145 }
146 
147 bool
148 TreeBuilder::SetNumericalTestNode(int node_key,
149  unsigned feature_id,
150  Operator op, tl_float threshold,
151  bool default_left, int left_child_key,
152  int right_child_key) {
153  auto& tree = pimpl->tree;
154  auto& nodes = tree.nodes;
155  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
156  "SetNumericalTestNode: no node found with node_key");
157  CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
158  "SetNumericalTestNode: no node found with left_child_key");
159  CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
160  "SetNumericalTestNode: no node found with right_child_key");
161  _Node* node = nodes[node_key].get();
162  _Node* left_child = nodes[left_child_key].get();
163  _Node* right_child = nodes[right_child_key].get();
164  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
165  "SetNumericalTestNode: cannot modify a non-empty node");
166  CHECK_EARLY_RETURN(left_child->parent == nullptr,
167  "SetNumericalTestNode: node designated as left child already has "
168  "a parent");
169  CHECK_EARLY_RETURN(right_child->parent == nullptr,
170  "SetNumericalTestNode: node designated as right child already has "
171  "a parent");
172  CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
173  "SetNumericalTestNode: the root node cannot be a child");
174  node->status = _Node::_Status::kNumericalTest;
175  node->left_child = nodes[left_child_key].get();
176  node->left_child->parent = node;
177  node->right_child = nodes[right_child_key].get();
178  node->right_child->parent = node;
179  node->feature_id = feature_id;
180  node->default_left = default_left;
181  node->info.threshold = threshold;
182  node->op = op;
183  return true;
184 }
185 
186 bool
187 TreeBuilder::SetCategoricalTestNode(int node_key,
188  unsigned feature_id,
189  const std::vector<uint32_t>& left_categories,
190  bool default_left, int left_child_key,
191  int right_child_key) {
192  auto& tree = pimpl->tree;
193  auto& nodes = tree.nodes;
194  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
195  "SetCategoricalTestNode: no node found with node_key");
196  CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
197  "SetCategoricalTestNode: no node found with left_child_key");
198  CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
199  "SetCategoricalTestNode: no node found with right_child_key");
200  _Node* node = nodes[node_key].get();
201  _Node* left_child = nodes[left_child_key].get();
202  _Node* right_child = nodes[right_child_key].get();
203  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
204  "SetCategoricalTestNode: cannot modify a non-empty node");
205  CHECK_EARLY_RETURN(left_child->parent == nullptr,
206  "SetCategoricalTestNode: node designated as left child already "
207  "has a parent");
208  CHECK_EARLY_RETURN(right_child->parent == nullptr,
209  "SetCategoricalTestNode: node designated as right child already "
210  "has a parent");
211  CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
212  "SetCategoricalTestNode: the root node cannot be a child");
213  node->status = _Node::_Status::kCategoricalTest;
214  node->left_child = nodes[left_child_key].get();
215  node->left_child->parent = node;
216  node->right_child = nodes[right_child_key].get();
217  node->right_child->parent = node;
218  node->feature_id = feature_id;
219  node->default_left = default_left;
220  node->left_categories = left_categories;
221  return true;
222 }
223 
224 bool
225 TreeBuilder::SetLeafNode(int node_key, tl_float leaf_value) {
226  auto& tree = pimpl->tree;
227  auto& nodes = tree.nodes;
228  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
229  "SetLeafNode: no node found with node_key");
230  _Node* node = nodes[node_key].get();
231  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
232  "SetLeafNode: cannot modify a non-empty node");
233  node->status = _Node::_Status::kLeaf;
234  node->info.leaf_value = leaf_value;
235  return true;
236 }
237 
238 bool
239 TreeBuilder::SetLeafVectorNode(int node_key,
240  const std::vector<tl_float>& leaf_vector) {
241  auto& tree = pimpl->tree;
242  auto& nodes = tree.nodes;
243  CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
244  "SetLeafVectorNode: no node found with node_key");
245  _Node* node = nodes[node_key].get();
246  CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
247  "SetLeafVectorNode: cannot modify a non-empty node");
248  node->status = _Node::_Status::kLeaf;
249  node->leaf_vector = leaf_vector;
250  return true;
251 }
252 
253 ModelBuilder::ModelBuilder(int num_feature, int num_output_group,
254  bool random_forest_flag)
255  : pimpl(common::make_unique<ModelBuilderImpl>(num_feature,
256  num_output_group,
257  random_forest_flag)) {}
258 ModelBuilder::~ModelBuilder() {}
259 
260 void
261 ModelBuilder::SetModelParam(const char* name, const char* value) {
262  pimpl->cfg.emplace_back(name, value);
263 }
264 
265 int
266 ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) {
267  if (tree_builder == nullptr) {
268  const char* msg = "InsertTree: not a valid tree builder";
269  LOG(INFO) << msg;
270  TreeliteAPISetLastError(msg);
271  return -1; // fail
272  }
273  if (tree_builder->ensemble_id != nullptr) {
274  const char* msg = "InsertTree: tree is already part of another ensemble";
275  LOG(INFO) << msg;
276  TreeliteAPISetLastError(msg);
277  return -1; // fail
278  }
279 
280  // check bounds for feature indices
281  for (const auto& kv : tree_builder->pimpl->tree.nodes) {
282  const _Node::_Status status = kv.second->status;
283  if (status == _Node::_Status::kNumericalTest ||
284  status == _Node::_Status::kCategoricalTest) {
285  const int fid = static_cast<int>(kv.second->feature_id);
286  if (fid < 0 || fid >= pimpl->num_feature) {
287  std::ostringstream oss;
288  oss << "InsertTree: tree has an invalid split at node "
289  << kv.first << ": feature id " << kv.second->feature_id
290  << " is out of bound";
291  const std::string str = oss.str();
292  const char* msg = str.c_str();
293  LOG(INFO) << msg;
294  TreeliteAPISetLastError(msg);
295  return -1; // fail
296  }
297  }
298  }
299 
300  // perform insertion
301  auto& trees = pimpl->trees;
302  if (index == -1) {
303  trees.push_back(std::move(*tree_builder));
304  tree_builder->ensemble_id = static_cast<void*>(this);
305  return static_cast<int>(trees.size());
306  } else {
307  if (static_cast<size_t>(index) <= trees.size()) {
308  trees.insert(trees.begin() + index, std::move(*tree_builder));
309  tree_builder->ensemble_id = static_cast<void*>(this);
310  return index;
311  } else {
312  LOG(INFO) << "CreateTree: index out of bound";
313  return -1; // fail
314  }
315  }
316 }
317 
320  return pimpl->trees[index];
321 }
322 
323 const TreeBuilder&
324 ModelBuilder::GetTree(int index) const {
325  return pimpl->trees[index];
326 }
327 
328 bool
330  auto& trees = pimpl->trees;
331  CHECK_EARLY_RETURN(static_cast<size_t>(index) < trees.size(),
332  "DeleteTree: index out of bound");
333  trees.erase(trees.begin() + index);
334  return true;
335 }
336 
337 bool
339  Model model;
340  model.num_feature = pimpl->num_feature;
341  model.num_output_group = pimpl->num_output_group;
342  model.random_forest_flag = pimpl->random_forest_flag;
343  // extra parameters
344  InitParamAndCheck(&model.param, pimpl->cfg);
345 
346  // flag to check consistent use of leaf vector
347  // 0: no leaf should use leaf vector
348  // 1: every leaf should use leaf vector
349  // -1: indeterminate
350  int8_t flag_leaf_vector = -1;
351 
352  for (const auto& _tree_builder : pimpl->trees) {
353  const auto& _tree = _tree_builder.pimpl->tree;
354  CHECK_EARLY_RETURN(_tree.root != nullptr,
355  "CommitModel: a tree has no root node");
356  model.trees.emplace_back();
357  Tree& tree = model.trees.back();
358  tree.Init();
359 
360  // assign node ID's so that a breadth-wise traversal would yield
361  // the monotonic sequence 0, 1, 2, ...
362  std::queue<std::pair<const _Node*, int>> Q; // (internal pointer, ID)
363  Q.push({_tree.root, 0}); // assign 0 to root
364  while (!Q.empty()) {
365  const _Node* node;
366  int nid;
367  std::tie(node, nid) = Q.front(); Q.pop();
368  CHECK_EARLY_RETURN(node->status != _Node::_Status::kEmpty,
369  "CommitModel: encountered an empty node in the middle of a tree");
370  if (node->status == _Node::_Status::kNumericalTest) {
371  CHECK_EARLY_RETURN(node->left_child != nullptr,
372  "CommitModel: a test node lacks a left child");
373  CHECK_EARLY_RETURN(node->right_child != nullptr,
374  "CommitModel: a test node lacks a right child");
375  CHECK_EARLY_RETURN(node->left_child->parent == node,
376  "CommitModel: left child has wrong parent");
377  CHECK_EARLY_RETURN(node->right_child->parent == node,
378  "CommitModel: right child has wrong parent");
379  tree.AddChilds(nid);
380  tree[nid].set_numerical_split(node->feature_id, node->info.threshold,
381  node->default_left, node->op);
382  Q.push({node->left_child, tree[nid].cleft()});
383  Q.push({node->right_child, tree[nid].cright()});
384  } else if (node->status == _Node::_Status::kCategoricalTest) {
385  CHECK_EARLY_RETURN(node->left_child != nullptr,
386  "CommitModel: a test node lacks a left child");
387  CHECK_EARLY_RETURN(node->right_child != nullptr,
388  "CommitModel: a test node lacks a right child");
389  CHECK_EARLY_RETURN(node->left_child->parent == node,
390  "CommitModel: left child has wrong parent");
391  CHECK_EARLY_RETURN(node->right_child->parent == node,
392  "CommitModel: right child has wrong parent");
393  tree.AddChilds(nid);
394  tree[nid].set_categorical_split(node->feature_id, node->default_left,
395  false, node->left_categories);
396  Q.push({node->left_child, tree[nid].cleft()});
397  Q.push({node->right_child, tree[nid].cright()});
398  } else { // leaf node
399  CHECK_EARLY_RETURN(node->left_child == nullptr
400  && node->right_child == nullptr,
401  "CommitModel: a leaf node cannot have children");
402  if (!node->leaf_vector.empty()) { // leaf vector exists
403  CHECK_EARLY_RETURN(flag_leaf_vector != 0,
404  "CommitModel: Inconsistent use of leaf vector: "
405  "if one leaf node uses a leaf vector, "
406  "*every* leaf node must use a leaf vector");
407  flag_leaf_vector = 1; // now every leaf must use leaf vector
408  CHECK_EARLY_RETURN(node->leaf_vector.size() == model.num_output_group,
409  "CommitModel: The length of leaf vector must be "
410  "identical to the number of output groups");
411  tree[nid].set_leaf_vector(node->leaf_vector);
412  } else { // ordinary leaf
413  CHECK_EARLY_RETURN(flag_leaf_vector != 1,
414  "CommitModel: Inconsistent use of leaf vector: "
415  "if one leaf node does not use a leaf vector, "
416  "*no other* leaf node can use a leaf vector");
417  flag_leaf_vector = 0; // now no leaf can use leaf vector
418  tree[nid].set_leaf(node->info.leaf_value);
419  }
420  }
421  }
422  }
423  if (flag_leaf_vector == 0) {
424  if (model.num_output_group > 1) {
425  // multiclass classification with gradient boosted trees
426  CHECK_EARLY_RETURN(!model.random_forest_flag,
427  "To use a random forest for multi-class classification, each leaf "
428  "node must output a leaf vector specifying a probability "
429  "distribution");
430  CHECK_EARLY_RETURN(pimpl->trees.size() % model.num_output_group == 0,
431  "For multi-class classifiers with gradient boosted trees, the number "
432  "of trees must be evenly divisible by the number of output groups");
433  }
434  } else if (flag_leaf_vector == 1) {
435  // multiclass classification with a random forest
436  CHECK_EARLY_RETURN(model.random_forest_flag,
437  "In multi-class classifiers with gradient boosted trees, each leaf "
438  "node must output a single floating-point value.");
439  } else {
440  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
441  }
442  *out_model = std::move(model);
443  return true;
444 }
445 
446 } // namespace frontend
447 } // namespace treelite
int num_output_group
number of output groups (used for multi-class classification); set to 1 for everything else ...
Definition: tree.h:437
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:311
thin wrapper for tree ensemble model
Definition: tree.h:427
tree builder class
Definition: frontend.h:70
std::vector< Tree > trees
member trees
Definition: tree.h:429
ModelParam param
extra parameters
Definition: tree.h:442
model structure for tree
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:261
in-memory representation of a decision tree
Definition: tree.h:22
bool DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:329
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
Definition: builder.cc:266
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:440
double tl_float
float type to be used internally
Definition: base.h:17
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:321
bool CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
Definition: builder.cc:338
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:319
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
Definition: tree.h:434
Operator
comparison operators
Definition: base.h:23