18 enum class Status : int8_t {
19 kEmpty, kNumericalTest, kCategoricalTest, kLeaf
24 std::vector<treelite::frontend::Value> leaf_vector;
28 NodeDraft* left_child;
29 NodeDraft* right_child;
48 std::vector<uint32_t> left_categories;
51 : status(Status::kEmpty), parent(nullptr), left_child(nullptr), right_child(nullptr) {}
56 std::unordered_map<int, std::unique_ptr<NodeDraft>> nodes;
60 : root(nullptr), nodes(), threshold_type(threshold_type), leaf_output_type(leaf_output_type) {}
71 : tree(threshold_type, leaf_output_type) {}
75 std::vector<TreeBuilder> trees;
78 bool average_tree_output;
81 std::vector<std::pair<std::string, std::string>> cfg;
82 inline ModelBuilderImpl(
int num_feature,
int num_class,
bool average_tree_output,
84 : trees(), num_feature(num_feature), num_class(num_class),
85 average_tree_output(average_tree_output), threshold_type(threshold_type),
86 leaf_output_type(leaf_output_type), cfg() {
87 TREELITE_CHECK_GT(num_feature, 0) <<
"ModelBuilder: num_feature must be positive";
88 TREELITE_CHECK_GT(num_class, 0) <<
"ModelBuilder: num_class must be positive";
89 TREELITE_CHECK(threshold_type != TypeInfo::kInvalid)
90 <<
"ModelBuilder: threshold_type can't be invalid";
91 TREELITE_CHECK(leaf_output_type != TypeInfo::kInvalid)
92 <<
"ModelBuilder: leaf_output_type can't be invalid";
95 template <
typename ThresholdType,
typename LeafOutputType>
99 template <
typename ThresholdType,
typename LeafOutputType>
101 const std::vector<Value>& leaf_vector) {
102 const size_t leaf_vector_size = leaf_vector.size();
103 const TypeInfo expected_leaf_type = TypeToInfo<LeafOutputType>();
104 std::vector<LeafOutputType> out_leaf_vector;
105 for (
size_t i = 0; i < leaf_vector_size; ++i) {
106 const Value& leaf_value = leaf_vector[i];
107 TREELITE_CHECK(leaf_value.GetValueType() == expected_leaf_type)
108 <<
"Leaf value at index " << i <<
" has incorrect type. Expected: " 111 out_leaf_vector.push_back(leaf_value.Get<LeafOutputType>());
116 Value::Value() : handle_(
nullptr), type_(TypeInfo::kInvalid) {}
118 template <
typename T>
120 Value::Create(T init_value) {
122 std::unique_ptr<T> ptr = std::make_unique<T>(init_value);
123 value.handle_.reset(ptr.release());
124 value.type_ = TypeToInfo<T>();
128 template <
typename ValueType>
131 inline static std::shared_ptr<void> Dispatch(
const void* init_value) {
132 const auto* v_ptr =
static_cast<const ValueType*
>(init_value);
133 TREELITE_CHECK(v_ptr);
134 ValueType v = *v_ptr;
135 return std::make_shared<ValueType>(v);
140 Value::Create(
const void* init_value,
TypeInfo type) {
142 TREELITE_CHECK(type != TypeInfo::kInvalid) <<
"Type must be valid";
144 value.handle_ = DispatchWithTypeInfo<CreateHandle>(type, init_value);
148 template <
typename T>
151 TREELITE_CHECK(handle_);
152 T* out =
static_cast<T*
>(handle_.get());
157 template <
typename T>
160 TREELITE_CHECK(handle_);
161 const T* out =
static_cast<const T*
>(handle_.get());
167 Value::GetValueType()
const {
172 : pimpl_(new
TreeBuilderImpl(threshold_type, leaf_output_type)), ensemble_id_(nullptr) {}
173 TreeBuilder::~TreeBuilder() =
default;
177 auto& nodes = pimpl_->tree.nodes;
178 TREELITE_CHECK_EQ(nodes.count(node_key), 0)
179 <<
"CreateNode: nodes with duplicate keys are not allowed";
180 nodes[node_key] = std::make_unique<NodeDraft>();
185 auto& tree = pimpl_->tree;
186 auto& nodes = tree.nodes;
187 TREELITE_CHECK_GT(nodes.count(node_key), 0) <<
"DeleteNode: no node found with node_key";
188 NodeDraft* node = nodes[node_key].get();
189 if (tree.root == node) {
192 if (node->left_child !=
nullptr) {
193 node->left_child->parent =
nullptr;
195 if (node->right_child !=
nullptr) {
196 node->right_child->parent =
nullptr;
198 if (node == tree.root) {
202 nodes.erase(node_key);
208 auto& tree = pimpl_->tree;
209 auto& nodes = tree.nodes;
210 TREELITE_CHECK_GT(nodes.count(node_key), 0) <<
"SetRootNode: no node found with node_key";
211 NodeDraft* node = nodes[node_key].get();
212 TREELITE_CHECK(!node->parent) <<
"SetRootNode: a root node cannot have a parent";
218 Value threshold,
bool default_left,
int left_child_key,
219 int right_child_key) {
220 TREELITE_CHECK_GT(
optable.count(opname), 0) <<
"No operator \"" << opname <<
"\" exists";
223 left_child_key, right_child_key);
228 bool default_left,
int left_child_key,
int right_child_key) {
229 auto& tree = pimpl_->tree;
230 auto& nodes = tree.nodes;
231 TREELITE_CHECK(tree.threshold_type == threshold.GetValueType())
232 <<
"SetNumericalTestNode: threshold has an incorrect type. " 235 TREELITE_CHECK_GT(nodes.count(node_key), 0)
236 <<
"SetNumericalTestNode: no node found with node_key";
237 TREELITE_CHECK_GT(nodes.count(left_child_key), 0)
238 <<
"SetNumericalTestNode: no node found with left_child_key";
239 TREELITE_CHECK_GT(nodes.count(right_child_key), 0)
240 <<
"SetNumericalTestNode: no node found with right_child_key";
241 NodeDraft* node = nodes[node_key].get();
242 NodeDraft* left_child = nodes[left_child_key].get();
243 NodeDraft* right_child = nodes[right_child_key].get();
244 TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
245 <<
"SetNumericalTestNode: cannot modify a non-empty node";
246 TREELITE_CHECK(!left_child->parent)
247 <<
"SetNumericalTestNode: node designated as left child already has a parent";
248 TREELITE_CHECK(!right_child->parent)
249 <<
"SetNumericalTestNode: node designated as right child already has a parent";
250 TREELITE_CHECK(left_child != tree.root && right_child != tree.root)
251 <<
"SetNumericalTestNode: the root node cannot be a child";
252 node->status = NodeDraft::Status::kNumericalTest;
253 node->left_child = nodes[left_child_key].get();
254 node->left_child->parent = node;
255 node->right_child = nodes[right_child_key].get();
256 node->right_child->parent = node;
257 node->feature_id = feature_id;
258 node->default_left = default_left;
259 node->threshold = std::move(threshold);
265 const std::vector<uint32_t>& left_categories,
bool default_left,
266 int left_child_key,
int right_child_key) {
267 auto &tree = pimpl_->tree;
268 auto &nodes = tree.nodes;
269 TREELITE_CHECK_GT(nodes.count(node_key), 0)
270 <<
"SetCategoricalTestNode: no node found with node_key";
271 TREELITE_CHECK_GT(nodes.count(left_child_key), 0)
272 <<
"SetCategoricalTestNode: no node found with left_child_key";
273 TREELITE_CHECK_GT(nodes.count(right_child_key), 0)
274 <<
"SetCategoricalTestNode: no node found with right_child_key";
275 NodeDraft* node = nodes[node_key].get();
276 NodeDraft* left_child = nodes[left_child_key].get();
277 NodeDraft* right_child = nodes[right_child_key].get();
278 TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
279 <<
"SetCategoricalTestNode: cannot modify a non-empty node";
280 TREELITE_CHECK(!left_child->parent)
281 <<
"SetCategoricalTestNode: node designated as left child already has a parent";
282 TREELITE_CHECK(!right_child->parent)
283 <<
"SetCategoricalTestNode: node designated as right child already has a parent";
284 TREELITE_CHECK(left_child != tree.root && right_child != tree.root)
285 <<
"SetCategoricalTestNode: the root node cannot be a child";
286 node->status = NodeDraft::Status::kCategoricalTest;
287 node->left_child = nodes[left_child_key].get();
288 node->left_child->parent = node;
289 node->right_child = nodes[right_child_key].get();
290 node->right_child->parent = node;
291 node->feature_id = feature_id;
292 node->default_left = default_left;
293 node->left_categories = left_categories;
298 auto& tree = pimpl_->tree;
299 auto& nodes = tree.nodes;
300 TREELITE_CHECK(tree.leaf_output_type == leaf_value.GetValueType())
301 <<
"SetLeafNode: leaf_value has an incorrect type. " 304 TREELITE_CHECK_GT(nodes.count(node_key), 0) <<
"SetLeafNode: no node found with node_key";
305 NodeDraft* node = nodes[node_key].get();
306 TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
307 <<
"SetLeafNode: cannot modify a non-empty node";
308 node->status = NodeDraft::Status::kLeaf;
309 node->leaf_value = std::move(leaf_value);
314 auto& tree = pimpl_->tree;
315 auto& nodes = tree.nodes;
316 const size_t leaf_vector_len = leaf_vector.size();
317 for (
size_t i = 0; i < leaf_vector_len; ++i) {
318 const Value& leaf_value = leaf_vector[i];
319 TREELITE_CHECK(tree.leaf_output_type == leaf_value.GetValueType())
320 <<
"SetLeafVectorNode: the element " << i <<
" in leaf_vector has an incorrect type. " 324 TREELITE_CHECK_GT(nodes.count(node_key), 0)
325 <<
"SetLeafVectorNode: no node found with node_key";
326 NodeDraft* node = nodes[node_key].get();
327 TREELITE_CHECK(node->status == NodeDraft::Status::kEmpty)
328 <<
"SetLeafVectorNode: cannot modify a non-empty node";
329 node->status = NodeDraft::Status::kLeaf;
330 node->leaf_vector = leaf_vector;
336 threshold_type, leaf_output_type)) {}
337 ModelBuilder::~ModelBuilder() =
default;
341 pimpl_->cfg.emplace_back(name, value);
346 if (tree_builder ==
nullptr) {
347 TREELITE_LOG(FATAL) <<
"InsertTree: not a valid tree builder";
350 if (tree_builder->ensemble_id_ !=
nullptr) {
351 TREELITE_LOG(FATAL) <<
"InsertTree: tree is already part of another ensemble";
354 if (tree_builder->pimpl_->tree.threshold_type != this->pimpl_->threshold_type) {
356 <<
"InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " 358 <<
" type for split thresholds whereas the tree is using " 362 if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) {
364 <<
"InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " 365 <<
"member trees to use " <<
TypeInfoToString(this->pimpl_->leaf_output_type)
366 <<
" type for leaf outputs whereas the tree is using " 372 for (
const auto& kv : tree_builder->pimpl_->tree.nodes) {
373 const NodeDraft::Status status = kv.second->status;
374 if (status == NodeDraft::Status::kNumericalTest ||
375 status == NodeDraft::Status::kCategoricalTest) {
376 const int fid =
static_cast<int>(kv.second->feature_id);
377 if (fid < 0 || fid >= this->pimpl_->num_feature) {
378 TREELITE_LOG(FATAL) <<
"InsertTree: tree has an invalid split at node " 379 << kv.first <<
": feature id " 380 << kv.second->feature_id <<
" is out of bound";
387 auto& trees = pimpl_->trees;
389 trees.push_back(std::move(*tree_builder));
390 tree_builder->ensemble_id_ =
this;
391 return static_cast<int>(trees.size());
393 if (static_cast<size_t>(index) <= trees.size()) {
394 trees.insert(trees.begin() + index, std::move(*tree_builder));
395 tree_builder->ensemble_id_ =
this;
398 TREELITE_LOG(FATAL) <<
"InsertTree: index out of bound";
406 return &pimpl_->trees.at(index);
411 return &pimpl_->trees.at(index);
416 auto& trees = pimpl_->trees;
417 TREELITE_CHECK_LT(static_cast<size_t>(index), trees.size())
418 <<
"DeleteTree: index out of bound";
419 trees.erase(trees.begin() + index);
422 std::unique_ptr<Model>
424 std::unique_ptr<Model> model_ptr = Model::Create(pimpl_->threshold_type,
425 pimpl_->leaf_output_type);
426 model_ptr->Dispatch([
this](
auto& model) {
427 this->pimpl_->CommitModelImpl(&model);
432 template <
typename ThresholdType,
typename LeafOutputType>
438 model.
task_param.output_type = TaskParam::OutputType::kFloat;
441 InitParamAndCheck(&model.
param, this->cfg);
447 int8_t flag_leaf_vector = -1;
449 for (
const auto& tree_builder : this->trees) {
450 const auto& _tree = tree_builder.pimpl_->tree;
451 TREELITE_CHECK(_tree.root) <<
"CommitModel: a tree has no root node";
452 TREELITE_CHECK(_tree.root->status != NodeDraft::Status::kEmpty)
453 <<
"SetRootNode: cannot set an empty node as root";
454 model.
trees.emplace_back();
460 std::queue<std::pair<const NodeDraft*, int>> Q;
461 Q.push({_tree.root, 0});
463 const NodeDraft* node;
465 std::tie(node, nid) = Q.front();
467 TREELITE_CHECK(node->status != NodeDraft::Status::kEmpty)
468 <<
"CommitModel: encountered an empty node in the middle of a tree";
469 if (node->status == NodeDraft::Status::kNumericalTest) {
470 TREELITE_CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
471 TREELITE_CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
472 TREELITE_CHECK(node->left_child->parent == node)
473 <<
"CommitModel: left child has wrong parent";
474 TREELITE_CHECK(node->right_child->parent == node)
475 <<
"CommitModel: right child has wrong parent";
477 TREELITE_CHECK(node->threshold.GetValueType() == TypeToInfo<ThresholdType>())
478 <<
"CommitModel: The specified threshold has incorrect type. Expected: " 481 ThresholdType threshold = node->threshold.Get<ThresholdType>();
482 tree.
SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op);
483 Q.push({node->left_child, tree.
LeftChild(nid)});
484 Q.push({node->right_child, tree.
RightChild(nid)});
485 }
else if (node->status == NodeDraft::Status::kCategoricalTest) {
486 TREELITE_CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
487 TREELITE_CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
488 TREELITE_CHECK(node->left_child->parent == node)
489 <<
"CommitModel: left child has wrong parent";
490 TREELITE_CHECK(node->right_child->parent == node)
491 <<
"CommitModel: right child has wrong parent";
495 Q.push({node->left_child, tree.
LeftChild(nid)});
496 Q.push({node->right_child, tree.
RightChild(nid)});
498 TREELITE_CHECK(node->left_child ==
nullptr && node->right_child ==
nullptr)
499 <<
"CommitModel: a leaf node cannot have children";
500 if (!node->leaf_vector.empty()) {
501 TREELITE_CHECK_NE(flag_leaf_vector, 0)
502 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, " 503 <<
"*every* leaf node must use a leaf vector";
504 flag_leaf_vector = 1;
505 TREELITE_CHECK_EQ(node->leaf_vector.size(), model.
task_param.num_class)
506 <<
"CommitModel: The length of leaf vector must be identical to the number of output " 508 SetLeafVector(&tree, nid, node->leaf_vector);
510 TREELITE_CHECK_NE(flag_leaf_vector, 1)
511 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf " 512 <<
"vector, *no other* leaf node can use a leaf vector";
513 flag_leaf_vector = 0;
514 TREELITE_CHECK(node->leaf_value.GetValueType() == TypeToInfo<LeafOutputType>())
515 <<
"CommitModel: The specified leaf value has incorrect type. Expected: " 518 LeafOutputType leaf_value = node->leaf_value.Get<LeafOutputType>();
524 if (flag_leaf_vector == 0) {
528 model.
task_type = TaskType::kMultiClfGrovePerClass;
530 TREELITE_CHECK_EQ(this->trees.size() % model.
task_param.num_class, 0)
531 <<
"For multi-class classifiers with gradient boosted trees, the number of trees must be " 532 <<
"evenly divisible by the number of output groups";
535 model.
task_type = TaskType::kBinaryClfRegr;
538 }
else if (flag_leaf_vector == 1) {
540 model.
task_type = TaskType::kMultiClfProbDistLeaf;
542 TREELITE_CHECK_GT(model.
task_param.num_class, 1)
543 <<
"Expected leaf vectors with length exceeding 1";
546 TREELITE_LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
550 template Value Value::Create(uint32_t init_value);
551 template Value Value::Create(
float init_value);
552 template Value Value::Create(
double init_value);
ModelParam param
extra parameters
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
void DeleteNode(int node_key)
Remove a node from a tree.
bool average_tree_output
whether to average tree outputs
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
void SetRootNode(int node_key)
Set a node as the root of a tree.
in-memory representation of a decision tree
logging facility for Treelite
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
void SetLeafVector(int nid, const std::vector< LeafOutputType > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
TaskType task_type
Task type.
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
void SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classifi...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
void DeleteTree(int index)
Remove a tree from the ensemble.
const std::unordered_map< std::string, Operator > optable
conversion table from string to Operator, defined in tables.cc
int LeftChild(int nid) const
Getters.
TaskParam task_param
Group of parameters that are specific to the particular task type.
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
void CreateNode(int node_key)
Create an empty node within a tree.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
void SetLeafVectorNode(int node_key, const std::vector< Value > &leaf_vector)
Turn an empty node into a leaf vector node The leaf vector (collection of multiple leaf weights per l...
ModelBuilder(int num_feature, int num_class, bool average_tree_output, TypeInfo threshold_type, TypeInfo leaf_output_type)
Constructor.
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Operator
comparison operators
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node