Treelite
builder.cc
Go to the documentation of this file.
1 
8 #include <treelite/frontend.h>
9 #include <treelite/tree.h>
10 #include <treelite/logging.h>
11 #include <memory>
12 #include <queue>
13 
14 /* data structures with underscore prefixes are internal use only and don't have external linkage */
15 namespace {
16 
17 struct NodeDraft {
18  enum class Status : int8_t {
19  kEmpty, kNumericalTest, kCategoricalTest, kLeaf
20  };
21  /*
22  * leaf vector: only used for random forests with multi-class classification
23  */
24  std::vector<treelite::frontend::Value> leaf_vector;
25  Status status;
26  /* pointers to parent, left and right children */
27  NodeDraft* parent;
28  NodeDraft* left_child;
29  NodeDraft* right_child;
30  // split feature index
31  unsigned feature_id;
32  // default direction for missing values
33  bool default_left;
34  // leaf value (only for leaf nodes)
35  treelite::frontend::Value leaf_value;
36  // threshold (only for non-leaf nodes)
37  treelite::frontend::Value threshold;
38  // (for numerical split)
39  // operator to use for expression of form [fval] OP [threshold]
40  // If the expression evaluates to true, take the left child;
41  // otherwise, take the right child.
43  // (for categorical split)
44  // list of all categories that belong to the left child node.
45  // All others not in the list belong to the right child node.
46  // Categories are integers ranging from 0 to (n-1), where n is the number of
47  // categories in that particular feature. Let's assume n <= 64.
48  std::vector<uint32_t> left_categories;
49 
50  inline NodeDraft()
51  : status(Status::kEmpty), parent(nullptr), left_child(nullptr), right_child(nullptr) {}
52 };
53 
54 struct TreeDraft {
55  NodeDraft* root;
56  std::unordered_map<int, std::unique_ptr<NodeDraft>> nodes;
57  treelite::TypeInfo threshold_type;
58  treelite::TypeInfo leaf_output_type;
59  inline TreeDraft(treelite::TypeInfo threshold_type, treelite::TypeInfo leaf_output_type)
60  : root(nullptr), nodes(), threshold_type(threshold_type), leaf_output_type(leaf_output_type) {}
61 };
62 
63 } // anonymous namespace
64 
65 namespace treelite {
66 namespace frontend {
67 
69  TreeDraft tree;
70  inline TreeBuilderImpl(TypeInfo threshold_type, TypeInfo leaf_output_type)
71  : tree(threshold_type, leaf_output_type) {}
72 };
73 
75  std::vector<TreeBuilder> trees;
76  int num_feature;
77  int num_class;
78  bool average_tree_output;
79  TypeInfo threshold_type;
80  TypeInfo leaf_output_type;
81  std::vector<std::pair<std::string, std::string>> cfg;
82  inline ModelBuilderImpl(int num_feature, int num_class, bool average_tree_output,
83  TypeInfo threshold_type, TypeInfo leaf_output_type)
84  : trees(), num_feature(num_feature), num_class(num_class),
85  average_tree_output(average_tree_output), threshold_type(threshold_type),
86  leaf_output_type(leaf_output_type), cfg() {
87  TREELITE_CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive";
88  TREELITE_CHECK_GT(num_class, 0) << "ModelBuilder: num_class must be positive";
89  TREELITE_CHECK(threshold_type != TypeInfo::kInvalid)
90  << "ModelBuilder: threshold_type can't be invalid";
91  TREELITE_CHECK(leaf_output_type != TypeInfo::kInvalid)
92  << "ModelBuilder: leaf_output_type can't be invalid";
93  }
94  // Templatized implementation of CommitModel()
95  template <typename ThresholdType, typename LeafOutputType>
96  void CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_model);
97 };
98 
99 template <typename ThresholdType, typename LeafOutputType>
100 void SetLeafVector(Tree<ThresholdType, LeafOutputType>* tree, int nid,
101  const std::vector<Value>& leaf_vector) {
102  const size_t leaf_vector_size = leaf_vector.size();
103  const TypeInfo expected_leaf_type = TypeToInfo<LeafOutputType>();
104  std::vector<LeafOutputType> out_leaf_vector;
105  for (size_t i = 0; i < leaf_vector_size; ++i) {
106  const Value& leaf_value = leaf_vector[i];
107  TREELITE_CHECK(leaf_value.GetValueType() == expected_leaf_type)
108  << "Leaf value at index " << i << " has incorrect type. Expected: "
109  << TypeInfoToString(expected_leaf_type) << ", Given: "
110  << TypeInfoToString(leaf_value.GetValueType());
111  out_leaf_vector.push_back(leaf_value.Get<LeafOutputType>());
112  }
113  tree->SetLeafVector(nid, out_leaf_vector);
114 }
115 
116 Value::Value() : handle_(nullptr), type_(TypeInfo::kInvalid) {}
117 
118 template <typename T>
119 Value
120 Value::Create(T init_value) {
121  Value value;
122  std::unique_ptr<T> ptr = std::make_unique<T>(init_value);
123  value.handle_.reset(ptr.release());
124  value.type_ = TypeToInfo<T>();
125  return value;
126 }
127 
/*
 * Functor used with DispatchWithTypeInfo<>: copies the ValueType object that
 * init_value points to into a freshly allocated, type-erased shared handle.
 */
template <typename ValueType>
class CreateHandle {
 public:
  inline static std::shared_ptr<void> Dispatch(const void* init_value) {
    const auto* v_ptr = static_cast<const ValueType*>(init_value);
    TREELITE_CHECK(v_ptr);  // a null source pointer cannot be copied from
    ValueType v = *v_ptr;
    return std::make_shared<ValueType>(v);
  }
};
138 
139 Value
140 Value::Create(const void* init_value, TypeInfo type) {
141  Value value;
142  TREELITE_CHECK(type != TypeInfo::kInvalid) << "Type must be valid";
143  value.type_ = type;
144  value.handle_ = DispatchWithTypeInfo<CreateHandle>(type, init_value);
145  return value;
146 }
147 
148 template <typename T>
149 T&
150 Value::Get() {
151  TREELITE_CHECK(handle_);
152  T* out = static_cast<T*>(handle_.get());
153  TREELITE_CHECK(out);
154  return *out;
155 }
156 
157 template <typename T>
158 const T&
159 Value::Get() const {
160  TREELITE_CHECK(handle_);
161  const T* out = static_cast<const T*>(handle_.get());
162  TREELITE_CHECK(out);
163  return *out;
164 }
165 
166 TypeInfo
167 Value::GetValueType() const {
168  return type_;
169 }
170 
171 TreeBuilder::TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type)
172  : pimpl_(new TreeBuilderImpl(threshold_type, leaf_output_type)), ensemble_id_(nullptr) {}
173 TreeBuilder::~TreeBuilder() = default;
174 
175 void
176 TreeBuilder::CreateNode(int node_key) {
177  auto& nodes = pimpl_->tree.nodes;
178  TREELITE_CHECK_EQ(nodes.count(node_key), 0)
179  << "CreateNode: nodes with duplicate keys are not allowed";
180  nodes[node_key] = std::make_unique<NodeDraft>();
181 }
182 
183 void
184 TreeBuilder::DeleteNode(int node_key) {
185  auto& tree = pimpl_->tree;
186  auto& nodes = tree.nodes;
187  TREELITE_CHECK_GT(nodes.count(node_key), 0) << "DeleteNode: no node found with node_key";
188  NodeDraft* node = nodes[node_key].get();
189  if (tree.root == node) { // deleting root
190  tree.root = nullptr;
191  }
192  if (node->left_child != nullptr) { // deleting a parent
193  node->left_child->parent = nullptr;
194  }
195  if (node->right_child != nullptr) { // deleting a parent
196  node->right_child->parent = nullptr;
197  }
198  if (node == tree.root) { // deleting root
199  tree.root = nullptr;
200  nodes.clear();
201  } else {
202  nodes.erase(node_key);
203  }
204 }
205 
206 void
208  auto& tree = pimpl_->tree;
209  auto& nodes = tree.nodes;
210  TREELITE_CHECK_GT(nodes.count(node_key), 0) << "SetRootNode: no node found with node_key";
211  NodeDraft* node = nodes[node_key].get();
212  TREELITE_CHECK(!node->parent) << "SetRootNode: a root node cannot have a parent";
213  tree.root = node;
214 }
215 
216 void
217 TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, const char* opname,
218  Value threshold, bool default_left, int left_child_key,
219  int right_child_key) {
220  TREELITE_CHECK_GT(optable.count(opname), 0) << "No operator \"" << opname << "\" exists";
221  Operator op = optable.at(opname);
222  SetNumericalTestNode(node_key, feature_id, op, std::move(threshold), default_left,
223  left_child_key, right_child_key);
224 }
225 
226 void
227 TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
228  bool default_left, int left_child_key, int right_child_key) {
229  auto& tree = pimpl_->tree;
230  auto& nodes = tree.nodes;
231  TREELITE_CHECK(tree.threshold_type == threshold.GetValueType())
232  << "SetNumericalTestNode: threshold has an incorrect type. "
233  << "Expected: " << TypeInfoToString(tree.threshold_type)
234  << ", Given: " << TypeInfoToString(threshold.GetValueType());
235  TREELITE_CHECK_GT(nodes.count(node_key), 0)
236  << "SetNumericalTestNode: no node found with node_key";
237  TREELITE_CHECK_GT(nodes.count(left_child_key), 0)
238  << "SetNumericalTestNode: no node found with left_child_key";
239  TREELITE_CHECK_GT(nodes.count(right_child_key), 0)
240  << "SetNumericalTestNode: no node found with right_child_key";
241  NodeDraft* node = nodes[node_key].get();
242  NodeDraft* left_child = nodes[left_child_key].get();
243  NodeDraft* right_child = nodes[right_child_key].get();
244  TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
245  << "SetNumericalTestNode: cannot modify a non-empty node";
246  TREELITE_CHECK(!left_child->parent)
247  << "SetNumericalTestNode: node designated as left child already has a parent";
248  TREELITE_CHECK(!right_child->parent)
249  << "SetNumericalTestNode: node designated as right child already has a parent";
250  TREELITE_CHECK(left_child != tree.root && right_child != tree.root)
251  << "SetNumericalTestNode: the root node cannot be a child";
252  node->status = NodeDraft::Status::kNumericalTest;
253  node->left_child = nodes[left_child_key].get();
254  node->left_child->parent = node;
255  node->right_child = nodes[right_child_key].get();
256  node->right_child->parent = node;
257  node->feature_id = feature_id;
258  node->default_left = default_left;
259  node->threshold = std::move(threshold);
260  node->op = op;
261 }
262 
263 void
264 TreeBuilder::SetCategoricalTestNode(int node_key, unsigned feature_id,
265  const std::vector<uint32_t>& left_categories, bool default_left,
266  int left_child_key, int right_child_key) {
267  auto &tree = pimpl_->tree;
268  auto &nodes = tree.nodes;
269  TREELITE_CHECK_GT(nodes.count(node_key), 0)
270  << "SetCategoricalTestNode: no node found with node_key";
271  TREELITE_CHECK_GT(nodes.count(left_child_key), 0)
272  << "SetCategoricalTestNode: no node found with left_child_key";
273  TREELITE_CHECK_GT(nodes.count(right_child_key), 0)
274  << "SetCategoricalTestNode: no node found with right_child_key";
275  NodeDraft* node = nodes[node_key].get();
276  NodeDraft* left_child = nodes[left_child_key].get();
277  NodeDraft* right_child = nodes[right_child_key].get();
278  TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
279  << "SetCategoricalTestNode: cannot modify a non-empty node";
280  TREELITE_CHECK(!left_child->parent)
281  << "SetCategoricalTestNode: node designated as left child already has a parent";
282  TREELITE_CHECK(!right_child->parent)
283  << "SetCategoricalTestNode: node designated as right child already has a parent";
284  TREELITE_CHECK(left_child != tree.root && right_child != tree.root)
285  << "SetCategoricalTestNode: the root node cannot be a child";
286  node->status = NodeDraft::Status::kCategoricalTest;
287  node->left_child = nodes[left_child_key].get();
288  node->left_child->parent = node;
289  node->right_child = nodes[right_child_key].get();
290  node->right_child->parent = node;
291  node->feature_id = feature_id;
292  node->default_left = default_left;
293  node->left_categories = left_categories;
294 }
295 
296 void
297 TreeBuilder::SetLeafNode(int node_key, Value leaf_value) {
298  auto& tree = pimpl_->tree;
299  auto& nodes = tree.nodes;
300  TREELITE_CHECK(tree.leaf_output_type == leaf_value.GetValueType())
301  << "SetLeafNode: leaf_value has an incorrect type. "
302  << "Expected: " << TypeInfoToString(tree.leaf_output_type)
303  << ", Given: " << TypeInfoToString(leaf_value.GetValueType());
304  TREELITE_CHECK_GT(nodes.count(node_key), 0) << "SetLeafNode: no node found with node_key";
305  NodeDraft* node = nodes[node_key].get();
306  TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
307  << "SetLeafNode: cannot modify a non-empty node";
308  node->status = NodeDraft::Status::kLeaf;
309  node->leaf_value = std::move(leaf_value);
310 }
311 
312 void
313 TreeBuilder::SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector) {
314  auto& tree = pimpl_->tree;
315  auto& nodes = tree.nodes;
316  const size_t leaf_vector_len = leaf_vector.size();
317  for (size_t i = 0; i < leaf_vector_len; ++i) {
318  const Value& leaf_value = leaf_vector[i];
319  TREELITE_CHECK(tree.leaf_output_type == leaf_value.GetValueType())
320  << "SetLeafVectorNode: the element " << i << " in leaf_vector has an incorrect type. "
321  << "Expected: " << TypeInfoToString(tree.leaf_output_type)
322  << ", Given: " << TypeInfoToString(leaf_value.GetValueType());
323  }
324  TREELITE_CHECK_GT(nodes.count(node_key), 0)
325  << "SetLeafVectorNode: no node found with node_key";
326  NodeDraft* node = nodes[node_key].get();
327  TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
328  << "SetLeafVectorNode: cannot modify a non-empty node";
329  node->status = NodeDraft::Status::kLeaf;
330  node->leaf_vector = leaf_vector;
331 }
332 
333 ModelBuilder::ModelBuilder(int num_feature, int num_class, bool average_tree_output,
334  TypeInfo threshold_type, TypeInfo leaf_output_type)
335  : pimpl_(new ModelBuilderImpl(num_feature, num_class, average_tree_output,
336  threshold_type, leaf_output_type)) {}
337 ModelBuilder::~ModelBuilder() = default;
338 
339 void
340 ModelBuilder::SetModelParam(const char* name, const char* value) {
341  pimpl_->cfg.emplace_back(name, value);
342 }
343 
344 int
345 ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) {
346  if (tree_builder == nullptr) {
347  TREELITE_LOG(FATAL) << "InsertTree: not a valid tree builder";
348  return -1;
349  }
350  if (tree_builder->ensemble_id_ != nullptr) {
351  TREELITE_LOG(FATAL) << "InsertTree: tree is already part of another ensemble";
352  return -1;
353  }
354  if (tree_builder->pimpl_->tree.threshold_type != this->pimpl_->threshold_type) {
355  TREELITE_LOG(FATAL)
356  << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all "
357  << "member trees to use " << TypeInfoToString(this->pimpl_->threshold_type)
358  << " type for split thresholds whereas the tree is using "
359  << TypeInfoToString(tree_builder->pimpl_->tree.threshold_type);
360  return -1;
361  }
362  if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) {
363  TREELITE_LOG(FATAL)
364  << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all "
365  << "member trees to use " << TypeInfoToString(this->pimpl_->leaf_output_type)
366  << " type for leaf outputs whereas the tree is using "
367  << TypeInfoToString(tree_builder->pimpl_->tree.leaf_output_type);
368  return -1;
369  }
370 
371  // check bounds for feature indices
372  for (const auto& kv : tree_builder->pimpl_->tree.nodes) {
373  const NodeDraft::Status status = kv.second->status;
374  if (status == NodeDraft::Status::kNumericalTest ||
375  status == NodeDraft::Status::kCategoricalTest) {
376  const int fid = static_cast<int>(kv.second->feature_id);
377  if (fid < 0 || fid >= this->pimpl_->num_feature) {
378  TREELITE_LOG(FATAL) << "InsertTree: tree has an invalid split at node "
379  << kv.first << ": feature id "
380  << kv.second->feature_id << " is out of bound";
381  return -1;
382  }
383  }
384  }
385 
386  // perform insertion
387  auto& trees = pimpl_->trees;
388  if (index == -1) {
389  trees.push_back(std::move(*tree_builder));
390  tree_builder->ensemble_id_ = this;
391  return static_cast<int>(trees.size());
392  } else {
393  if (static_cast<size_t>(index) <= trees.size()) {
394  trees.insert(trees.begin() + index, std::move(*tree_builder));
395  tree_builder->ensemble_id_ = this;
396  return index;
397  } else {
398  TREELITE_LOG(FATAL) << "InsertTree: index out of bound";
399  return -1;
400  }
401  }
402 }
403 
406  return &pimpl_->trees.at(index);
407 }
408 
409 const TreeBuilder*
410 ModelBuilder::GetTree(int index) const {
411  return &pimpl_->trees.at(index);
412 }
413 
414 void
416  auto& trees = pimpl_->trees;
417  TREELITE_CHECK_LT(static_cast<size_t>(index), trees.size())
418  << "DeleteTree: index out of bound";
419  trees.erase(trees.begin() + index);
420 }
421 
422 std::unique_ptr<Model>
424  std::unique_ptr<Model> model_ptr = Model::Create(pimpl_->threshold_type,
425  pimpl_->leaf_output_type);
426  model_ptr->Dispatch([this](auto& model) {
427  this->pimpl_->CommitModelImpl(&model);
428  });
429  return model_ptr;
430 }
431 
432 template <typename ThresholdType, typename LeafOutputType>
433 void
434 ModelBuilderImpl::CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_model) {
435  ModelImpl<ThresholdType, LeafOutputType>& model = *out_model;
436  model.num_feature = this->num_feature;
437  model.average_tree_output = this->average_tree_output;
438  model.task_param.output_type = TaskParam::OutputType::kFloat;
439  model.task_param.num_class = this->num_class;
440  // extra parameters
441  InitParamAndCheck(&model.param, this->cfg);
442 
443  // flag to check consistent use of leaf vector
444  // 0: no leaf should use leaf vector
445  // 1: every leaf should use leaf vector
446  // -1: indeterminate
447  int8_t flag_leaf_vector = -1;
448 
449  for (const auto& tree_builder : this->trees) {
450  const auto& _tree = tree_builder.pimpl_->tree;
451  TREELITE_CHECK(_tree.root) << "CommitModel: a tree has no root node";
452  TREELITE_CHECK(_tree.root->status != NodeDraft::Status::kEmpty)
453  << "SetRootNode: cannot set an empty node as root";
454  model.trees.emplace_back();
455  Tree<ThresholdType, LeafOutputType>& tree = model.trees.back();
456  tree.Init();
457 
458  // assign node ID's so that a breadth-wise traversal would yield
459  // the monotonic sequence 0, 1, 2, ...
460  std::queue<std::pair<const NodeDraft*, int>> Q; // (internal pointer, ID)
461  Q.push({_tree.root, 0}); // assign 0 to root
462  while (!Q.empty()) {
463  const NodeDraft* node;
464  int nid;
465  std::tie(node, nid) = Q.front();
466  Q.pop();
467  TREELITE_CHECK(node->status != NodeDraft::Status::kEmpty)
468  << "CommitModel: encountered an empty node in the middle of a tree";
469  if (node->status == NodeDraft::Status::kNumericalTest) {
470  TREELITE_CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
471  TREELITE_CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
472  TREELITE_CHECK(node->left_child->parent == node)
473  << "CommitModel: left child has wrong parent";
474  TREELITE_CHECK(node->right_child->parent == node)
475  << "CommitModel: right child has wrong parent";
476  tree.AddChilds(nid);
477  TREELITE_CHECK(node->threshold.GetValueType() == TypeToInfo<ThresholdType>())
478  << "CommitModel: The specified threshold has incorrect type. Expected: "
479  << TypeInfoToString(TypeToInfo<ThresholdType>())
480  << " Given: " << TypeInfoToString(node->threshold.GetValueType());
481  ThresholdType threshold = node->threshold.Get<ThresholdType>();
482  tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op);
483  Q.push({node->left_child, tree.LeftChild(nid)});
484  Q.push({node->right_child, tree.RightChild(nid)});
485  } else if (node->status == NodeDraft::Status::kCategoricalTest) {
486  TREELITE_CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
487  TREELITE_CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
488  TREELITE_CHECK(node->left_child->parent == node)
489  << "CommitModel: left child has wrong parent";
490  TREELITE_CHECK(node->right_child->parent == node)
491  << "CommitModel: right child has wrong parent";
492  tree.AddChilds(nid);
493  tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, node->left_categories,
494  false);
495  Q.push({node->left_child, tree.LeftChild(nid)});
496  Q.push({node->right_child, tree.RightChild(nid)});
497  } else { // leaf node
498  TREELITE_CHECK(node->left_child == nullptr && node->right_child == nullptr)
499  << "CommitModel: a leaf node cannot have children";
500  if (!node->leaf_vector.empty()) { // leaf vector exists
501  TREELITE_CHECK_NE(flag_leaf_vector, 0)
502  << "CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, "
503  << "*every* leaf node must use a leaf vector";
504  flag_leaf_vector = 1; // now every leaf must use leaf vector
505  TREELITE_CHECK_EQ(node->leaf_vector.size(), model.task_param.num_class)
506  << "CommitModel: The length of leaf vector must be identical to the number of output "
507  << "groups";
508  SetLeafVector(&tree, nid, node->leaf_vector);
509  } else { // ordinary leaf
510  TREELITE_CHECK_NE(flag_leaf_vector, 1)
511  << "CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf "
512  << "vector, *no other* leaf node can use a leaf vector";
513  flag_leaf_vector = 0; // now no leaf can use leaf vector
514  TREELITE_CHECK(node->leaf_value.GetValueType() == TypeToInfo<LeafOutputType>())
515  << "CommitModel: The specified leaf value has incorrect type. Expected: "
516  << TypeInfoToString(TypeToInfo<LeafOutputType>())
517  << " Given: " << TypeInfoToString(node->leaf_value.GetValueType());
518  LeafOutputType leaf_value = node->leaf_value.Get<LeafOutputType>();
519  tree.SetLeaf(nid, leaf_value);
520  }
521  }
522  }
523  }
524  if (flag_leaf_vector == 0) {
525  model.task_param.leaf_vector_size = 1;
526  if (model.task_param.num_class > 1) {
527  // multi-class classifier, XGBoost/LightGBM style
528  model.task_type = TaskType::kMultiClfGrovePerClass;
529  model.task_param.grove_per_class = true;
530  TREELITE_CHECK_EQ(this->trees.size() % model.task_param.num_class, 0)
531  << "For multi-class classifiers with gradient boosted trees, the number of trees must be "
532  << "evenly divisible by the number of output groups";
533  } else {
534  // binary classifier or regressor
535  model.task_type = TaskType::kBinaryClfRegr;
536  model.task_param.grove_per_class = false;
537  }
538  } else if (flag_leaf_vector == 1) {
539  // multi-class classifier, sklearn RF style
540  model.task_type = TaskType::kMultiClfProbDistLeaf;
541  model.task_param.grove_per_class = false;
542  TREELITE_CHECK_GT(model.task_param.num_class, 1)
543  << "Expected leaf vectors with length exceeding 1";
544  model.task_param.leaf_vector_size = model.task_param.num_class;
545  } else {
546  TREELITE_LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
547  }
548 }
549 
550 template Value Value::Create(uint32_t init_value);
551 template Value Value::Create(float init_value);
552 template Value Value::Create(double init_value);
553 
554 } // namespace frontend
555 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:702
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Definition: builder.cc:423
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:637
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:184
tree builder class
Definition: frontend.h:210
bool average_tree_output
whether to average tree outputs
Definition: tree.h:698
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:217
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:340
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:297
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:207
in-memory representation of a decision tree
Definition: tree.h:213
logging facility for Treelite
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:405
void SetLeafVector(int nid, const std::vector< LeafOutputType > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree_impl.h:751
TaskType task_type
Task type.
Definition: tree.h:696
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:726
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
Definition: builder.cc:345
void SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classifi...
Definition: builder.cc:264
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:415
const std::unordered_map< std::string, Operator > optable
conversion table from string to Operator, defined in optable.cc
Definition: optable.cc:14
int LeftChild(int nid) const
Getters.
Definition: tree.h:351
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:700
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:358
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:651
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
Definition: tree_impl.h:705
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:176
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:694
void SetLeafVectorNode(int node_key, const std::vector< Value > &leaf_vector)
Turn an empty node into a leaf vector node The leaf vector (collection of multiple leaf weights per l...
Definition: builder.cc:313
ModelBuilder(int num_feature, int num_class, bool average_tree_output, TypeInfo threshold_type, TypeInfo leaf_output_type)
Constructor.
Definition: builder.cc:333
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:689
Operator
comparison operators
Definition: base.h:26
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:741