Treelite
builder.cc
Go to the documentation of this file.
1 
8 #include <dmlc/registry.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <memory>
12 #include <queue>
13 
14 /* data structures in the anonymous namespace below are for internal use only and have internal linkage */
15 namespace {
16 
17 struct NodeDraft {
18  enum class Status : int8_t {
19  kEmpty, kNumericalTest, kCategoricalTest, kLeaf
20  };
21  /*
22  * leaf vector: only used for random forests with multi-class classification
23  */
24  std::vector<treelite::frontend::Value> leaf_vector;
25  Status status;
26  /* pointers to parent, left and right children */
27  NodeDraft* parent;
28  NodeDraft* left_child;
29  NodeDraft* right_child;
30  // split feature index
31  unsigned feature_id;
32  // default direction for missing values
33  bool default_left;
34  // leaf value (only for leaf nodes)
35  treelite::frontend::Value leaf_value;
36  // threshold (only for non-leaf nodes)
37  treelite::frontend::Value threshold;
38  // (for numerical split)
39  // operator to use for expression of form [fval] OP [threshold]
40  // If the expression evaluates to true, take the left child;
41  // otherwise, take the right child.
43  // (for categorical split)
44  // list of all categories that belong to the left child node.
45  // All others not in the list belong to the right child node.
46  // Categories are integers ranging from 0 to (n-1), where n is the number of
47  // categories in that particular feature. Let's assume n <= 64.
48  std::vector<uint32_t> left_categories;
49 
50  inline NodeDraft()
51  : status(Status::kEmpty), parent(nullptr), left_child(nullptr), right_child(nullptr) {}
52 };
53 
54 struct TreeDraft {
55  NodeDraft* root;
56  std::unordered_map<int, std::unique_ptr<NodeDraft>> nodes;
57  treelite::TypeInfo threshold_type;
58  treelite::TypeInfo leaf_output_type;
59  inline TreeDraft(treelite::TypeInfo threshold_type, treelite::TypeInfo leaf_output_type)
60  : root(nullptr), nodes(), threshold_type(threshold_type), leaf_output_type(leaf_output_type) {}
61 };
62 
63 } // anonymous namespace
64 
65 namespace treelite {
66 namespace frontend {
67 
68 DMLC_REGISTRY_FILE_TAG(builder);
69 
71  TreeDraft tree;
72  inline TreeBuilderImpl(TypeInfo threshold_type, TypeInfo leaf_output_type)
73  : tree(threshold_type, leaf_output_type) {}
74 };
75 
77  std::vector<TreeBuilder> trees;
78  int num_feature;
79  int num_class;
80  bool average_tree_output;
81  TypeInfo threshold_type;
82  TypeInfo leaf_output_type;
83  std::vector<std::pair<std::string, std::string>> cfg;
84  inline ModelBuilderImpl(int num_feature, int num_class, bool average_tree_output,
85  TypeInfo threshold_type, TypeInfo leaf_output_type)
86  : trees(), num_feature(num_feature), num_class(num_class),
87  average_tree_output(average_tree_output), threshold_type(threshold_type),
88  leaf_output_type(leaf_output_type), cfg() {
89  CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive";
90  CHECK_GT(num_class, 0) << "ModelBuilder: num_class must be positive";
91  CHECK(threshold_type != TypeInfo::kInvalid)
92  << "ModelBuilder: threshold_type can't be invalid";
93  CHECK(leaf_output_type != TypeInfo::kInvalid)
94  << "ModelBuilder: leaf_output_type can't be invalid";
95  }
96  // Templatized implementation of CommitModel()
97  template <typename ThresholdType, typename LeafOutputType>
98  void CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_model);
99 };
100 
101 template <typename ThresholdType, typename LeafOutputType>
102 void SetLeafVector(Tree<ThresholdType, LeafOutputType>* tree, int nid,
103  const std::vector<Value>& leaf_vector) {
104  const size_t leaf_vector_size = leaf_vector.size();
105  const TypeInfo expected_leaf_type = TypeToInfo<LeafOutputType>();
106  std::vector<LeafOutputType> out_leaf_vector;
107  for (size_t i = 0; i < leaf_vector_size; ++i) {
108  const Value& leaf_value = leaf_vector[i];
109  CHECK(leaf_value.GetValueType() == expected_leaf_type)
110  << "Leaf value at index " << i << " has incorrect type. Expected: "
111  << TypeInfoToString(expected_leaf_type) << ", Given: "
112  << TypeInfoToString(leaf_value.GetValueType());
113  out_leaf_vector.push_back(leaf_value.Get<LeafOutputType>());
114  }
115  tree->SetLeafVector(nid, out_leaf_vector);
116 }
117 
118 Value::Value() : handle_(nullptr), type_(TypeInfo::kInvalid) {}
119 
120 template <typename T>
121 Value
122 Value::Create(T init_value) {
123  Value value;
124  std::unique_ptr<T> ptr = std::make_unique<T>(init_value);
125  value.handle_.reset(ptr.release());
126  value.type_ = TypeToInfo<T>();
127  return value;
128 }
129 
/*!
 * \brief Functor for DispatchWithTypeInfo: copies the ValueType object pointed to by
 *        `init_value` into a new type-erased shared handle.
 */
template <typename ValueType>
class CreateHandle {
 public:
  inline static std::shared_ptr<void> Dispatch(const void* init_value) {
    const auto* v_ptr = static_cast<const ValueType*>(init_value);
    CHECK(v_ptr);
    return std::make_shared<ValueType>(*v_ptr);
  }
};
140 
141 Value
142 Value::Create(const void* init_value, TypeInfo type) {
143  Value value;
144  CHECK(type != TypeInfo::kInvalid) << "Type must be valid";
145  value.type_ = type;
146  value.handle_ = DispatchWithTypeInfo<CreateHandle>(type, init_value);
147  return value;
148 }
149 
150 template <typename T>
151 T&
152 Value::Get() {
153  CHECK(handle_);
154  T* out = static_cast<T*>(handle_.get());
155  CHECK(out);
156  return *out;
157 }
158 
159 template <typename T>
160 const T&
161 Value::Get() const {
162  CHECK(handle_);
163  const T* out = static_cast<const T*>(handle_.get());
164  CHECK(out);
165  return *out;
166 }
167 
168 TypeInfo
169 Value::GetValueType() const {
170  return type_;
171 }
172 
173 TreeBuilder::TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type)
174  : pimpl_(new TreeBuilderImpl(threshold_type, leaf_output_type)), ensemble_id_(nullptr) {}
175 TreeBuilder::~TreeBuilder() = default;
176 
177 void
178 TreeBuilder::CreateNode(int node_key) {
179  auto& nodes = pimpl_->tree.nodes;
180  CHECK_EQ(nodes.count(node_key), 0) << "CreateNode: nodes with duplicate keys are not allowed";
181  nodes[node_key] = std::make_unique<NodeDraft>();
182 }
183 
184 void
185 TreeBuilder::DeleteNode(int node_key) {
186  auto& tree = pimpl_->tree;
187  auto& nodes = tree.nodes;
188  CHECK_GT(nodes.count(node_key), 0) << "DeleteNode: no node found with node_key";
189  NodeDraft* node = nodes[node_key].get();
190  if (tree.root == node) { // deleting root
191  tree.root = nullptr;
192  }
193  if (node->left_child != nullptr) { // deleting a parent
194  node->left_child->parent = nullptr;
195  }
196  if (node->right_child != nullptr) { // deleting a parent
197  node->right_child->parent = nullptr;
198  }
199  if (node == tree.root) { // deleting root
200  tree.root = nullptr;
201  nodes.clear();
202  } else {
203  nodes.erase(node_key);
204  }
205 }
206 
207 void
209  auto& tree = pimpl_->tree;
210  auto& nodes = tree.nodes;
211  CHECK_GT(nodes.count(node_key), 0) << "SetRootNode: no node found with node_key";
212  NodeDraft* node = nodes[node_key].get();
213  CHECK(!node->parent) << "SetRootNode: a root node cannot have a parent";
214  tree.root = node;
215 }
216 
217 void
218 TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, const char* opname,
219  Value threshold, bool default_left, int left_child_key,
220  int right_child_key) {
221  CHECK_GT(optable.count(opname), 0) << "No operator \"" << opname << "\" exists";
222  Operator op = optable.at(opname);
223  SetNumericalTestNode(node_key, feature_id, op, std::move(threshold), default_left,
224  left_child_key, right_child_key);
225 }
226 
227 void
228 TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold,
229  bool default_left, int left_child_key, int right_child_key) {
230  auto& tree = pimpl_->tree;
231  auto& nodes = tree.nodes;
232  CHECK(tree.threshold_type == threshold.GetValueType())
233  << "SetNumericalTestNode: threshold has an incorrect type. "
234  << "Expected: " << TypeInfoToString(tree.threshold_type)
235  << ", Given: " << TypeInfoToString(threshold.GetValueType());
236  CHECK_GT(nodes.count(node_key), 0) << "SetNumericalTestNode: no node found with node_key";
237  CHECK_GT(nodes.count(left_child_key), 0)
238  << "SetNumericalTestNode: no node found with left_child_key";
239  CHECK_GT(nodes.count(right_child_key), 0)
240  << "SetNumericalTestNode: no node found with right_child_key";
241  NodeDraft* node = nodes[node_key].get();
242  NodeDraft* left_child = nodes[left_child_key].get();
243  NodeDraft* right_child = nodes[right_child_key].get();
244  CHECK(node->status == NodeDraft::Status::kEmpty)
245  << "SetNumericalTestNode: cannot modify a non-empty node";
246  CHECK(!left_child->parent)
247  << "SetNumericalTestNode: node designated as left child already has a parent";
248  CHECK(!right_child->parent)
249  << "SetNumericalTestNode: node designated as right child already has a parent";
250  CHECK(left_child != tree.root && right_child != tree.root)
251  << "SetNumericalTestNode: the root node cannot be a child";
252  node->status = NodeDraft::Status::kNumericalTest;
253  node->left_child = nodes[left_child_key].get();
254  node->left_child->parent = node;
255  node->right_child = nodes[right_child_key].get();
256  node->right_child->parent = node;
257  node->feature_id = feature_id;
258  node->default_left = default_left;
259  node->threshold = std::move(threshold);
260  node->op = op;
261 }
262 
263 void
264 TreeBuilder::SetCategoricalTestNode(int node_key, unsigned feature_id,
265  const std::vector<uint32_t>& left_categories, bool default_left,
266  int left_child_key, int right_child_key) {
267  auto &tree = pimpl_->tree;
268  auto &nodes = tree.nodes;
269  CHECK_GT(nodes.count(node_key), 0) << "SetCategoricalTestNode: no node found with node_key";
270  CHECK_GT(nodes.count(left_child_key), 0)
271  << "SetCategoricalTestNode: no node found with left_child_key";
272  CHECK_GT(nodes.count(right_child_key), 0)
273  << "SetCategoricalTestNode: no node found with right_child_key";
274  NodeDraft* node = nodes[node_key].get();
275  NodeDraft* left_child = nodes[left_child_key].get();
276  NodeDraft* right_child = nodes[right_child_key].get();
277  CHECK(node->status == NodeDraft::Status::kEmpty)
278  << "SetCategoricalTestNode: cannot modify a non-empty node";
279  CHECK(!left_child->parent)
280  << "SetCategoricalTestNode: node designated as left child already has a parent";
281  CHECK(!right_child->parent)
282  << "SetCategoricalTestNode: node designated as right child already has a parent";
283  CHECK(left_child != tree.root && right_child != tree.root)
284  << "SetCategoricalTestNode: the root node cannot be a child";
285  node->status = NodeDraft::Status::kCategoricalTest;
286  node->left_child = nodes[left_child_key].get();
287  node->left_child->parent = node;
288  node->right_child = nodes[right_child_key].get();
289  node->right_child->parent = node;
290  node->feature_id = feature_id;
291  node->default_left = default_left;
292  node->left_categories = left_categories;
293 }
294 
295 void
296 TreeBuilder::SetLeafNode(int node_key, Value leaf_value) {
297  auto& tree = pimpl_->tree;
298  auto& nodes = tree.nodes;
299  CHECK(tree.leaf_output_type == leaf_value.GetValueType())
300  << "SetLeafNode: leaf_value has an incorrect type. "
301  << "Expected: " << TypeInfoToString(tree.leaf_output_type)
302  << ", Given: " << TypeInfoToString(leaf_value.GetValueType());
303  CHECK_GT(nodes.count(node_key), 0) << "SetLeafNode: no node found with node_key";
304  NodeDraft* node = nodes[node_key].get();
305  CHECK(node->status == NodeDraft::Status::kEmpty) << "SetLeafNode: cannot modify a non-empty node";
306  node->status = NodeDraft::Status::kLeaf;
307  node->leaf_value = std::move(leaf_value);
308 }
309 
310 void
311 TreeBuilder::SetLeafVectorNode(int node_key, const std::vector<Value>& leaf_vector) {
312  auto& tree = pimpl_->tree;
313  auto& nodes = tree.nodes;
314  const size_t leaf_vector_len = leaf_vector.size();
315  for (size_t i = 0; i < leaf_vector_len; ++i) {
316  const Value& leaf_value = leaf_vector[i];
317  CHECK(tree.leaf_output_type == leaf_value.GetValueType())
318  << "SetLeafVectorNode: the element " << i << " in leaf_vector has an incorrect type. "
319  << "Expected: " << TypeInfoToString(tree.leaf_output_type)
320  << ", Given: " << TypeInfoToString(leaf_value.GetValueType());
321  }
322  CHECK_GT(nodes.count(node_key), 0) << "SetLeafVectorNode: no node found with node_key";
323  NodeDraft* node = nodes[node_key].get();
324  CHECK(node->status == NodeDraft::Status::kEmpty)
325  << "SetLeafVectorNode: cannot modify a non-empty node";
326  node->status = NodeDraft::Status::kLeaf;
327  node->leaf_vector = leaf_vector;
328 }
329 
330 ModelBuilder::ModelBuilder(int num_feature, int num_class, bool average_tree_output,
331  TypeInfo threshold_type, TypeInfo leaf_output_type)
332  : pimpl_(new ModelBuilderImpl(num_feature, num_class, average_tree_output,
333  threshold_type, leaf_output_type)) {}
334 ModelBuilder::~ModelBuilder() = default;
335 
336 void
337 ModelBuilder::SetModelParam(const char* name, const char* value) {
338  pimpl_->cfg.emplace_back(name, value);
339 }
340 
341 int
342 ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) {
343  if (tree_builder == nullptr) {
344  LOG(FATAL) << "InsertTree: not a valid tree builder";
345  return -1;
346  }
347  if (tree_builder->ensemble_id_ != nullptr) {
348  LOG(FATAL) << "InsertTree: tree is already part of another ensemble";
349  return -1;
350  }
351  if (tree_builder->pimpl_->tree.threshold_type != this->pimpl_->threshold_type) {
352  LOG(FATAL)
353  << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all "
354  << "member trees to use " << TypeInfoToString(this->pimpl_->threshold_type)
355  << " type for split thresholds whereas the tree is using "
356  << TypeInfoToString(tree_builder->pimpl_->tree.threshold_type);
357  return -1;
358  }
359  if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) {
360  LOG(FATAL)
361  << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all "
362  << "member trees to use " << TypeInfoToString(this->pimpl_->leaf_output_type)
363  << " type for leaf outputs whereas the tree is using "
364  << TypeInfoToString(tree_builder->pimpl_->tree.leaf_output_type);
365  return -1;
366  }
367 
368  // check bounds for feature indices
369  for (const auto& kv : tree_builder->pimpl_->tree.nodes) {
370  const NodeDraft::Status status = kv.second->status;
371  if (status == NodeDraft::Status::kNumericalTest ||
372  status == NodeDraft::Status::kCategoricalTest) {
373  const int fid = static_cast<int>(kv.second->feature_id);
374  if (fid < 0 || fid >= this->pimpl_->num_feature) {
375  LOG(FATAL) << "InsertTree: tree has an invalid split at node "
376  << kv.first << ": feature id " << kv.second->feature_id << " is out of bound";
377  return -1;
378  }
379  }
380  }
381 
382  // perform insertion
383  auto& trees = pimpl_->trees;
384  if (index == -1) {
385  trees.push_back(std::move(*tree_builder));
386  tree_builder->ensemble_id_ = this;
387  return static_cast<int>(trees.size());
388  } else {
389  if (static_cast<size_t>(index) <= trees.size()) {
390  trees.insert(trees.begin() + index, std::move(*tree_builder));
391  tree_builder->ensemble_id_ = this;
392  return index;
393  } else {
394  LOG(FATAL) << "InsertTree: index out of bound";
395  return -1;
396  }
397  }
398 }
399 
402  return &pimpl_->trees.at(index);
403 }
404 
405 const TreeBuilder*
406 ModelBuilder::GetTree(int index) const {
407  return &pimpl_->trees.at(index);
408 }
409 
410 void
412  auto& trees = pimpl_->trees;
413  CHECK_LT(static_cast<size_t>(index), trees.size()) << "DeleteTree: index out of bound";
414  trees.erase(trees.begin() + index);
415 }
416 
417 std::unique_ptr<Model>
419  std::unique_ptr<Model> model_ptr = Model::Create(pimpl_->threshold_type,
420  pimpl_->leaf_output_type);
421  model_ptr->Dispatch([this](auto& model) {
422  this->pimpl_->CommitModelImpl(&model);
423  });
424  return model_ptr;
425 }
426 
427 template <typename ThresholdType, typename LeafOutputType>
428 void
429 ModelBuilderImpl::CommitModelImpl(ModelImpl<ThresholdType, LeafOutputType>* out_model) {
430  ModelImpl<ThresholdType, LeafOutputType>& model = *out_model;
431  model.num_feature = this->num_feature;
432  model.average_tree_output = this->average_tree_output;
433  model.task_param.output_type = TaskParameter::OutputType::kFloat;
434  model.task_param.num_class = this->num_class;
435  // extra parameters
436  InitParamAndCheck(&model.param, this->cfg);
437 
438  // flag to check consistent use of leaf vector
439  // 0: no leaf should use leaf vector
440  // 1: every leaf should use leaf vector
441  // -1: indeterminate
442  int8_t flag_leaf_vector = -1;
443 
444  for (const auto& tree_builder : this->trees) {
445  const auto& _tree = tree_builder.pimpl_->tree;
446  CHECK(_tree.root) << "CommitModel: a tree has no root node";
447  CHECK(_tree.root->status != NodeDraft::Status::kEmpty)
448  << "SetRootNode: cannot set an empty node as root";
449  model.trees.emplace_back();
450  Tree<ThresholdType, LeafOutputType>& tree = model.trees.back();
451  tree.Init();
452 
453  // assign node ID's so that a breadth-wise traversal would yield
454  // the monotonic sequence 0, 1, 2, ...
455  std::queue<std::pair<const NodeDraft*, int>> Q; // (internal pointer, ID)
456  Q.push({_tree.root, 0}); // assign 0 to root
457  while (!Q.empty()) {
458  const NodeDraft* node;
459  int nid;
460  std::tie(node, nid) = Q.front();
461  Q.pop();
462  CHECK(node->status != NodeDraft::Status::kEmpty)
463  << "CommitModel: encountered an empty node in the middle of a tree";
464  if (node->status == NodeDraft::Status::kNumericalTest) {
465  CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
466  CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
467  CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent";
468  CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent";
469  tree.AddChilds(nid);
470  CHECK(node->threshold.GetValueType() == TypeToInfo<ThresholdType>())
471  << "CommitModel: The specified threshold has incorrect type. Expected: "
472  << TypeInfoToString(TypeToInfo<ThresholdType>())
473  << " Given: " << TypeInfoToString(node->threshold.GetValueType());
474  ThresholdType threshold = node->threshold.Get<ThresholdType>();
475  tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op);
476  Q.push({node->left_child, tree.LeftChild(nid)});
477  Q.push({node->right_child, tree.RightChild(nid)});
478  } else if (node->status == NodeDraft::Status::kCategoricalTest) {
479  CHECK(node->left_child) << "CommitModel: a test node lacks a left child";
480  CHECK(node->right_child) << "CommitModel: a test node lacks a right child";
481  CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent";
482  CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent";
483  tree.AddChilds(nid);
484  tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, node->left_categories,
485  false);
486  Q.push({node->left_child, tree.LeftChild(nid)});
487  Q.push({node->right_child, tree.RightChild(nid)});
488  } else { // leaf node
489  CHECK(node->left_child == nullptr && node->right_child == nullptr)
490  << "CommitModel: a leaf node cannot have children";
491  if (!node->leaf_vector.empty()) { // leaf vector exists
492  CHECK_NE(flag_leaf_vector, 0)
493  << "CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, "
494  << "*every* leaf node must use a leaf vector";
495  flag_leaf_vector = 1; // now every leaf must use leaf vector
496  CHECK_EQ(node->leaf_vector.size(), model.task_param.num_class)
497  << "CommitModel: The length of leaf vector must be identical to the number of output "
498  << "groups";
499  SetLeafVector(&tree, nid, node->leaf_vector);
500  } else { // ordinary leaf
501  CHECK_NE(flag_leaf_vector, 1)
502  << "CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf "
503  << "vector, *no other* leaf node can use a leaf vector";
504  flag_leaf_vector = 0; // now no leaf can use leaf vector
505  CHECK(node->leaf_value.GetValueType() == TypeToInfo<LeafOutputType>())
506  << "CommitModel: The specified leaf value has incorrect type. Expected: "
507  << TypeInfoToString(TypeToInfo<LeafOutputType>())
508  << " Given: " << TypeInfoToString(node->leaf_value.GetValueType());
509  LeafOutputType leaf_value = node->leaf_value.Get<LeafOutputType>();
510  tree.SetLeaf(nid, leaf_value);
511  }
512  }
513  }
514  }
515  if (flag_leaf_vector == 0) {
516  model.task_param.leaf_vector_size = 1;
517  if (model.task_param.num_class > 1) {
518  // multi-class classifier, XGBoost/LightGBM style
519  model.task_type = TaskType::kMultiClfGrovePerClass;
520  model.task_param.grove_per_class = true;
521  CHECK_EQ(this->trees.size() % model.task_param.num_class, 0)
522  << "For multi-class classifiers with gradient boosted trees, the number of trees must be "
523  << "evenly divisible by the number of output groups";
524  } else {
525  // binary classifier or regressor
526  model.task_type = TaskType::kBinaryClfRegr;
527  model.task_param.grove_per_class = false;
528  }
529  } else if (flag_leaf_vector == 1) {
530  // multi-class classifier, sklearn RF style
531  model.task_type = TaskType::kMultiClfProbDistLeaf;
532  model.task_param.grove_per_class = false;
533  CHECK_GT(model.task_param.num_class, 1) << "Expected leaf vectors with length exceeding 1";
534  model.task_param.leaf_vector_size = model.task_param.num_class;
535  } else {
536  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
537  }
538 }
539 
540 template Value Value::Create(uint32_t init_value);
541 template Value Value::Create(float init_value);
542 template Value Value::Create(double init_value);
543 
544 } // namespace frontend
545 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:658
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Definition: builder.cc:418
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:521
void DeleteNode(int node_key)
Remove a node from a tree.
Definition: builder.cc:185
tree builder class
Definition: frontend.h:96
bool average_tree_output
whether to average tree outputs
Definition: tree.h:654
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
Definition: builder.cc:218
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
Definition: builder.cc:337
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
Definition: builder.cc:296
void SetRootNode(int node_key)
Set a node as the root of a tree.
Definition: builder.cc:208
in-memory representation of a decision tree
Definition: tree.h:191
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
Definition: builder.cc:401
void SetLeafVector(int nid, const std::vector< LeafOutputType > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
Definition: tree_impl.h:632
TaskType task_type
Task type.
Definition: tree.h:652
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:673
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
Definition: builder.cc:342
void SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classifi...
Definition: builder.cc:264
TaskParameter task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:656
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
void DeleteTree(int index)
Remove a tree from the ensemble.
Definition: builder.cc:411
const std::unordered_map< std::string, Operator > optable
conversion table from string to Operator, defined in optable.cc
Definition: optable.cc:14
int LeftChild(int nid) const
Getters.
Definition: tree.h:309
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:316
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:534
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
Definition: tree_impl.h:588
void CreateNode(int node_key)
Create an empty node within a tree.
Definition: builder.cc:178
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:650
void SetLeafVectorNode(int node_key, const std::vector< Value > &leaf_vector)
Turn an empty node into a leaf vector node. The leaf vector (collection of multiple leaf weights per l...
Definition: builder.cc:311
ModelBuilder(int num_feature, int num_class, bool average_tree_output, TypeInfo threshold_type, TypeInfo leaf_output_type)
Constructor.
Definition: builder.cc:330
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:572
Operator
comparison operators
Definition: base.h:26
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:622