8 #include <dmlc/registry.h> 18 enum class Status : int8_t {
19 kEmpty, kNumericalTest, kCategoricalTest, kLeaf
24 std::vector<treelite::frontend::Value> leaf_vector;
28 NodeDraft* left_child;
29 NodeDraft* right_child;
48 std::vector<uint32_t> left_categories;
51 : status(Status::kEmpty), parent(nullptr), left_child(nullptr), right_child(nullptr) {}
56 std::unordered_map<int, std::unique_ptr<NodeDraft>> nodes;
60 : root(nullptr), nodes(), threshold_type(threshold_type), leaf_output_type(leaf_output_type) {}
68 DMLC_REGISTRY_FILE_TAG(builder);
73 : tree(threshold_type, leaf_output_type) {}
77 std::vector<TreeBuilder> trees;
80 bool average_tree_output;
83 std::vector<std::pair<std::string, std::string>> cfg;
84 inline ModelBuilderImpl(
int num_feature,
int num_class,
bool average_tree_output,
86 : trees(), num_feature(num_feature), num_class(num_class),
87 average_tree_output(average_tree_output), threshold_type(threshold_type),
88 leaf_output_type(leaf_output_type), cfg() {
89 CHECK_GT(num_feature, 0) <<
"ModelBuilder: num_feature must be positive";
90 CHECK_GT(num_class, 0) <<
"ModelBuilder: num_class must be positive";
91 CHECK(threshold_type != TypeInfo::kInvalid)
92 <<
"ModelBuilder: threshold_type can't be invalid";
93 CHECK(leaf_output_type != TypeInfo::kInvalid)
94 <<
"ModelBuilder: leaf_output_type can't be invalid";
97 template <
typename ThresholdType,
typename LeafOutputType>
101 template <
typename ThresholdType,
typename LeafOutputType>
103 const std::vector<Value>& leaf_vector) {
104 const size_t leaf_vector_size = leaf_vector.size();
105 const TypeInfo expected_leaf_type = TypeToInfo<LeafOutputType>();
106 std::vector<LeafOutputType> out_leaf_vector;
107 for (
size_t i = 0; i < leaf_vector_size; ++i) {
108 const Value& leaf_value = leaf_vector[i];
109 CHECK(leaf_value.GetValueType() == expected_leaf_type)
110 <<
"Leaf value at index " << i <<
" has incorrect type. Expected: " 113 out_leaf_vector.push_back(leaf_value.Get<LeafOutputType>());
118 Value::Value() : handle_(
nullptr), type_(TypeInfo::kInvalid) {}
120 template <
typename T>
122 Value::Create(T init_value) {
124 std::unique_ptr<T> ptr = std::make_unique<T>(init_value);
125 value.handle_.reset(ptr.release());
126 value.type_ = TypeToInfo<T>();
130 template <
typename ValueType>
133 inline static std::shared_ptr<void> Dispatch(
const void* init_value) {
134 const auto* v_ptr =
static_cast<const ValueType*
>(init_value);
136 ValueType v = *v_ptr;
137 return std::make_shared<ValueType>(v);
142 Value::Create(
const void* init_value,
TypeInfo type) {
144 CHECK(type != TypeInfo::kInvalid) <<
"Type must be valid";
146 value.handle_ = DispatchWithTypeInfo<CreateHandle>(type, init_value);
150 template <
typename T>
154 T* out =
static_cast<T*
>(handle_.get());
159 template <
typename T>
163 const T* out =
static_cast<const T*
>(handle_.get());
169 Value::GetValueType()
const {
174 : pimpl_(new
TreeBuilderImpl(threshold_type, leaf_output_type)), ensemble_id_(nullptr) {}
175 TreeBuilder::~TreeBuilder() =
default;
179 auto& nodes = pimpl_->tree.nodes;
180 CHECK_EQ(nodes.count(node_key), 0) <<
"CreateNode: nodes with duplicate keys are not allowed";
181 nodes[node_key] = std::make_unique<NodeDraft>();
186 auto& tree = pimpl_->tree;
187 auto& nodes = tree.nodes;
188 CHECK_GT(nodes.count(node_key), 0) <<
"DeleteNode: no node found with node_key";
189 NodeDraft* node = nodes[node_key].get();
190 if (tree.root == node) {
193 if (node->left_child !=
nullptr) {
194 node->left_child->parent =
nullptr;
196 if (node->right_child !=
nullptr) {
197 node->right_child->parent =
nullptr;
199 if (node == tree.root) {
203 nodes.erase(node_key);
209 auto& tree = pimpl_->tree;
210 auto& nodes = tree.nodes;
211 CHECK_GT(nodes.count(node_key), 0) <<
"SetRootNode: no node found with node_key";
212 NodeDraft* node = nodes[node_key].get();
213 CHECK(!node->parent) <<
"SetRootNode: a root node cannot have a parent";
219 Value threshold,
bool default_left,
int left_child_key,
220 int right_child_key) {
221 CHECK_GT(
optable.count(opname), 0) <<
"No operator \"" << opname <<
"\" exists";
224 left_child_key, right_child_key);
229 bool default_left,
int left_child_key,
int right_child_key) {
230 auto& tree = pimpl_->tree;
231 auto& nodes = tree.nodes;
232 CHECK(tree.threshold_type == threshold.GetValueType())
233 <<
"SetNumericalTestNode: threshold has an incorrect type. " 236 CHECK_GT(nodes.count(node_key), 0) <<
"SetNumericalTestNode: no node found with node_key";
237 CHECK_GT(nodes.count(left_child_key), 0)
238 <<
"SetNumericalTestNode: no node found with left_child_key";
239 CHECK_GT(nodes.count(right_child_key), 0)
240 <<
"SetNumericalTestNode: no node found with right_child_key";
241 NodeDraft* node = nodes[node_key].get();
242 NodeDraft* left_child = nodes[left_child_key].get();
243 NodeDraft* right_child = nodes[right_child_key].get();
244 CHECK(node->status == NodeDraft::Status::kEmpty)
245 <<
"SetNumericalTestNode: cannot modify a non-empty node";
246 CHECK(!left_child->parent)
247 <<
"SetNumericalTestNode: node designated as left child already has a parent";
248 CHECK(!right_child->parent)
249 <<
"SetNumericalTestNode: node designated as right child already has a parent";
250 CHECK(left_child != tree.root && right_child != tree.root)
251 <<
"SetNumericalTestNode: the root node cannot be a child";
252 node->status = NodeDraft::Status::kNumericalTest;
253 node->left_child = nodes[left_child_key].get();
254 node->left_child->parent = node;
255 node->right_child = nodes[right_child_key].get();
256 node->right_child->parent = node;
257 node->feature_id = feature_id;
258 node->default_left = default_left;
259 node->threshold = std::move(threshold);
265 const std::vector<uint32_t>& left_categories,
bool default_left,
266 int left_child_key,
int right_child_key) {
267 auto &tree = pimpl_->tree;
268 auto &nodes = tree.nodes;
269 CHECK_GT(nodes.count(node_key), 0) <<
"SetCategoricalTestNode: no node found with node_key";
270 CHECK_GT(nodes.count(left_child_key), 0)
271 <<
"SetCategoricalTestNode: no node found with left_child_key";
272 CHECK_GT(nodes.count(right_child_key), 0)
273 <<
"SetCategoricalTestNode: no node found with right_child_key";
274 NodeDraft* node = nodes[node_key].get();
275 NodeDraft* left_child = nodes[left_child_key].get();
276 NodeDraft* right_child = nodes[right_child_key].get();
277 CHECK(node->status == NodeDraft::Status::kEmpty)
278 <<
"SetCategoricalTestNode: cannot modify a non-empty node";
279 CHECK(!left_child->parent)
280 <<
"SetCategoricalTestNode: node designated as left child already has a parent";
281 CHECK(!right_child->parent)
282 <<
"SetCategoricalTestNode: node designated as right child already has a parent";
283 CHECK(left_child != tree.root && right_child != tree.root)
284 <<
"SetCategoricalTestNode: the root node cannot be a child";
285 node->status = NodeDraft::Status::kCategoricalTest;
286 node->left_child = nodes[left_child_key].get();
287 node->left_child->parent = node;
288 node->right_child = nodes[right_child_key].get();
289 node->right_child->parent = node;
290 node->feature_id = feature_id;
291 node->default_left = default_left;
292 node->left_categories = left_categories;
297 auto& tree = pimpl_->tree;
298 auto& nodes = tree.nodes;
299 CHECK(tree.leaf_output_type == leaf_value.GetValueType())
300 <<
"SetLeafNode: leaf_value has an incorrect type. " 303 CHECK_GT(nodes.count(node_key), 0) <<
"SetLeafNode: no node found with node_key";
304 NodeDraft* node = nodes[node_key].get();
305 CHECK(node->status == NodeDraft::Status::kEmpty) <<
"SetLeafNode: cannot modify a non-empty node";
306 node->status = NodeDraft::Status::kLeaf;
307 node->leaf_value = std::move(leaf_value);
312 auto& tree = pimpl_->tree;
313 auto& nodes = tree.nodes;
314 const size_t leaf_vector_len = leaf_vector.size();
315 for (
size_t i = 0; i < leaf_vector_len; ++i) {
316 const Value& leaf_value = leaf_vector[i];
317 CHECK(tree.leaf_output_type == leaf_value.GetValueType())
318 <<
"SetLeafVectorNode: the element " << i <<
" in leaf_vector has an incorrect type. " 322 CHECK_GT(nodes.count(node_key), 0) <<
"SetLeafVectorNode: no node found with node_key";
323 NodeDraft* node = nodes[node_key].get();
324 CHECK(node->status == NodeDraft::Status::kEmpty)
325 <<
"SetLeafVectorNode: cannot modify a non-empty node";
326 node->status = NodeDraft::Status::kLeaf;
327 node->leaf_vector = leaf_vector;
333 threshold_type, leaf_output_type)) {}
334 ModelBuilder::~ModelBuilder() =
default;
338 pimpl_->cfg.emplace_back(name, value);
343 if (tree_builder ==
nullptr) {
344 LOG(FATAL) <<
"InsertTree: not a valid tree builder";
347 if (tree_builder->ensemble_id_ !=
nullptr) {
348 LOG(FATAL) <<
"InsertTree: tree is already part of another ensemble";
351 if (tree_builder->pimpl_->tree.threshold_type != this->pimpl_->threshold_type) {
353 <<
"InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " 355 <<
" type for split thresholds whereas the tree is using " 359 if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) {
361 <<
"InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " 362 <<
"member trees to use " <<
TypeInfoToString(this->pimpl_->leaf_output_type)
363 <<
" type for leaf outputs whereas the tree is using " 369 for (
const auto& kv : tree_builder->pimpl_->tree.nodes) {
370 const NodeDraft::Status status = kv.second->status;
371 if (status == NodeDraft::Status::kNumericalTest ||
372 status == NodeDraft::Status::kCategoricalTest) {
373 const int fid =
static_cast<int>(kv.second->feature_id);
374 if (fid < 0 || fid >= this->pimpl_->num_feature) {
375 LOG(FATAL) <<
"InsertTree: tree has an invalid split at node " 376 << kv.first <<
": feature id " << kv.second->feature_id <<
" is out of bound";
383 auto& trees = pimpl_->trees;
385 trees.push_back(std::move(*tree_builder));
386 tree_builder->ensemble_id_ =
this;
387 return static_cast<int>(trees.size());
389 if (static_cast<size_t>(index) <= trees.size()) {
390 trees.insert(trees.begin() + index, std::move(*tree_builder));
391 tree_builder->ensemble_id_ =
this;
394 LOG(FATAL) <<
"InsertTree: index out of bound";
402 return &pimpl_->trees.at(index);
407 return &pimpl_->trees.at(index);
412 auto& trees = pimpl_->trees;
413 CHECK_LT(static_cast<size_t>(index), trees.size()) <<
"DeleteTree: index out of bound";
414 trees.erase(trees.begin() + index);
417 std::unique_ptr<Model>
419 std::unique_ptr<Model> model_ptr = Model::Create(pimpl_->threshold_type,
420 pimpl_->leaf_output_type);
421 model_ptr->Dispatch([
this](
auto& model) {
422 this->pimpl_->CommitModelImpl(&model);
427 template <
typename ThresholdType,
typename LeafOutputType>
433 model.
task_param.output_type = TaskParameter::OutputType::kFloat;
436 InitParamAndCheck(&model.
param, this->cfg);
442 int8_t flag_leaf_vector = -1;
444 for (
const auto& tree_builder : this->trees) {
445 const auto& _tree = tree_builder.pimpl_->tree;
446 CHECK(_tree.root) <<
"CommitModel: a tree has no root node";
447 CHECK(_tree.root->status != NodeDraft::Status::kEmpty)
448 <<
"SetRootNode: cannot set an empty node as root";
449 model.
trees.emplace_back();
455 std::queue<std::pair<const NodeDraft*, int>> Q;
456 Q.push({_tree.root, 0});
458 const NodeDraft* node;
460 std::tie(node, nid) = Q.front();
462 CHECK(node->status != NodeDraft::Status::kEmpty)
463 <<
"CommitModel: encountered an empty node in the middle of a tree";
464 if (node->status == NodeDraft::Status::kNumericalTest) {
465 CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
466 CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
467 CHECK(node->left_child->parent == node) <<
"CommitModel: left child has wrong parent";
468 CHECK(node->right_child->parent == node) <<
"CommitModel: right child has wrong parent";
470 CHECK(node->threshold.GetValueType() == TypeToInfo<ThresholdType>())
471 <<
"CommitModel: The specified threshold has incorrect type. Expected: " 474 ThresholdType threshold = node->threshold.Get<ThresholdType>();
475 tree.
SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op);
476 Q.push({node->left_child, tree.
LeftChild(nid)});
477 Q.push({node->right_child, tree.
RightChild(nid)});
478 }
else if (node->status == NodeDraft::Status::kCategoricalTest) {
479 CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
480 CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
481 CHECK(node->left_child->parent == node) <<
"CommitModel: left child has wrong parent";
482 CHECK(node->right_child->parent == node) <<
"CommitModel: right child has wrong parent";
486 Q.push({node->left_child, tree.
LeftChild(nid)});
487 Q.push({node->right_child, tree.
RightChild(nid)});
489 CHECK(node->left_child ==
nullptr && node->right_child ==
nullptr)
490 <<
"CommitModel: a leaf node cannot have children";
491 if (!node->leaf_vector.empty()) {
492 CHECK_NE(flag_leaf_vector, 0)
493 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, " 494 <<
"*every* leaf node must use a leaf vector";
495 flag_leaf_vector = 1;
496 CHECK_EQ(node->leaf_vector.size(), model.
task_param.num_class)
497 <<
"CommitModel: The length of leaf vector must be identical to the number of output " 499 SetLeafVector(&tree, nid, node->leaf_vector);
501 CHECK_NE(flag_leaf_vector, 1)
502 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf " 503 <<
"vector, *no other* leaf node can use a leaf vector";
504 flag_leaf_vector = 0;
505 CHECK(node->leaf_value.GetValueType() == TypeToInfo<LeafOutputType>())
506 <<
"CommitModel: The specified leaf value has incorrect type. Expected: " 509 LeafOutputType leaf_value = node->leaf_value.Get<LeafOutputType>();
515 if (flag_leaf_vector == 0) {
519 model.
task_type = TaskType::kMultiClfGrovePerClass;
521 CHECK_EQ(this->trees.size() % model.
task_param.num_class, 0)
522 <<
"For multi-class classifiers with gradient boosted trees, the number of trees must be " 523 <<
"evenly divisible by the number of output groups";
526 model.
task_type = TaskType::kBinaryClfRegr;
529 }
else if (flag_leaf_vector == 1) {
531 model.
task_type = TaskType::kMultiClfProbDistLeaf;
533 CHECK_GT(model.
task_param.num_class, 1) <<
"Expected leaf vectors with length exceeding 1";
536 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
540 template Value Value::Create(uint32_t init_value);
541 template Value Value::Create(
float init_value);
542 template Value Value::Create(
double init_value);
ModelParam param
extra parameters
std::unique_ptr< Model > CommitModel()
finalize the model and produce the in-memory representation
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
void DeleteNode(int node_key)
Remove a node from a tree.
bool average_tree_output
whether to average tree outputs
void SetNumericalTestNode(int node_key, unsigned feature_id, const char *op, Value threshold, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a numerical test node; the test is in the form [feature value] OP [threshold]...
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
void SetLeafNode(int node_key, Value leaf_value)
Turn an empty node into a leaf node.
void SetRootNode(int node_key)
Set a node as the root of a tree.
in-memory representation of a decision tree
TreeBuilder * GetTree(int index)
Get a reference to a tree in the ensemble.
void SetLeafVector(int nid, const std::vector< LeafOutputType > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
TaskType task_type
Task type.
TypeInfo
Types used by thresholds and leaf outputs.
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
void SetCategoricalTestNode(int node_key, unsigned feature_id, const std::vector< uint32_t > &left_categories, bool default_left, int left_child_key, int right_child_key)
Turn an empty node into a categorical test node. A list defines all categories that would be classifi...
TaskParameter task_param
Group of parameters that are specific to the particular task type.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
void DeleteTree(int index)
Remove a tree from the ensemble.
const std::unordered_map< std::string, Operator > optable
conversion table from string to Operator, defined in tables.cc
int LeftChild(int nid) const
Getters.
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
void CreateNode(int node_key)
Create an empty node within a tree.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
void SetLeafVectorNode(int node_key, const std::vector< Value > &leaf_vector)
Turn an empty node into a leaf vector node The leaf vector (collection of multiple leaf weights per l...
ModelBuilder(int num_feature, int num_class, bool average_tree_output, TypeInfo threshold_type, TypeInfo leaf_output_type)
Constructor.
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Operator
comparison operators
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node