8 #include <dmlc/registry.h> 19 enum class _Status : int8_t {
20 kEmpty, kNumericalTest, kCategoricalTest, kLeaf
29 std::vector<treelite::tl_float> leaf_vector;
51 std::vector<uint32_t> left_categories;
54 : status(_Status::kEmpty),
55 parent(nullptr), left_child(nullptr), right_child(nullptr) {}
60 std::unordered_map<int, std::unique_ptr<_Node>> nodes;
61 inline _Tree() : root(nullptr), nodes() {}
69 DMLC_REGISTRY_FILE_TAG(builder);
77 std::vector<TreeBuilder> trees;
80 bool random_forest_flag;
81 std::vector<std::pair<std::string, std::string>> cfg;
83 bool random_forest_flag)
84 : trees(), num_feature(num_feature),
85 num_output_group(num_output_group),
86 random_forest_flag(random_forest_flag), cfg() {
87 CHECK_GT(num_feature, 0) <<
"ModelBuilder: num_feature must be positive";
88 CHECK_GT(num_output_group, 0)
89 <<
"ModelBuilder: num_output_group must be positive";
93 TreeBuilder::TreeBuilder()
95 TreeBuilder::~TreeBuilder() {}
98 TreeBuilder::CreateNode(
int node_key) {
99 auto& nodes = pimpl->tree.nodes;
100 CHECK_EQ(nodes.count(node_key), 0) <<
"CreateNode: nodes with duplicate keys are not allowed";
101 nodes[node_key].reset(
new _Node());
105 TreeBuilder::DeleteNode(
int node_key) {
106 auto& tree = pimpl->tree;
107 auto& nodes = tree.nodes;
108 CHECK_GT(nodes.count(node_key), 0) <<
"DeleteNode: no node found with node_key";
109 _Node* node = nodes[node_key].get();
110 if (tree.root == node) {
113 if (node->left_child !=
nullptr) {
114 node->left_child->parent =
nullptr;
116 if (node->right_child !=
nullptr) {
117 node->right_child->parent =
nullptr;
119 if (node == tree.root) {
123 nodes.erase(node_key);
128 TreeBuilder::SetRootNode(
int node_key) {
129 auto& tree = pimpl->tree;
130 auto& nodes = tree.nodes;
131 CHECK_GT(nodes.count(node_key), 0) <<
"SetRootNode: no node found with node_key";
132 _Node* node = nodes[node_key].get();
133 CHECK(!node->parent) <<
"SetRootNode: a root node cannot have a parent";
138 TreeBuilder::SetNumericalTestNode(
int node_key,
140 const char* opname,
tl_float threshold,
141 bool default_left,
int left_child_key,
142 int right_child_key) {
143 CHECK_GT(
optable.count(opname), 0) <<
"No operator \"" << opname <<
"\" exists";
145 SetNumericalTestNode(node_key, feature_id, op, threshold, default_left,
146 left_child_key, right_child_key);
150 TreeBuilder::SetNumericalTestNode(
int node_key,
153 bool default_left,
int left_child_key,
154 int right_child_key) {
155 auto& tree = pimpl->tree;
156 auto& nodes = tree.nodes;
157 CHECK_GT(nodes.count(node_key), 0) <<
"SetNumericalTestNode: no node found with node_key";
158 CHECK_GT(nodes.count(left_child_key), 0)
159 <<
"SetNumericalTestNode: no node found with left_child_key";
160 CHECK_GT(nodes.count(right_child_key), 0)
161 <<
"SetNumericalTestNode: no node found with right_child_key";
162 _Node* node = nodes[node_key].get();
163 _Node* left_child = nodes[left_child_key].get();
164 _Node* right_child = nodes[right_child_key].get();
165 CHECK(node->status == _Node::_Status::kEmpty)
166 <<
"SetNumericalTestNode: cannot modify a non-empty node";
167 CHECK(!left_child->parent)
168 <<
"SetNumericalTestNode: node designated as left child already has a parent";
169 CHECK(!right_child->parent)
170 <<
"SetNumericalTestNode: node designated as right child already has a parent";
171 CHECK(left_child != tree.root && right_child != tree.root)
172 <<
"SetNumericalTestNode: the root node cannot be a child";
173 node->status = _Node::_Status::kNumericalTest;
174 node->left_child = nodes[left_child_key].get();
175 node->left_child->parent = node;
176 node->right_child = nodes[right_child_key].get();
177 node->right_child->parent = node;
178 node->feature_id = feature_id;
179 node->default_left = default_left;
180 node->info.threshold = threshold;
185 TreeBuilder::SetCategoricalTestNode(
int node_key,
187 const std::vector<uint32_t>& left_categories,
188 bool default_left,
int left_child_key,
189 int right_child_key) {
190 auto &tree = pimpl->tree;
191 auto &nodes = tree.nodes;
192 CHECK_GT(nodes.count(node_key), 0) <<
"SetCategoricalTestNode: no node found with node_key";
193 CHECK_GT(nodes.count(left_child_key), 0)
194 <<
"SetCategoricalTestNode: no node found with left_child_key";
195 CHECK_GT(nodes.count(right_child_key), 0)
196 <<
"SetCategoricalTestNode: no node found with right_child_key";
197 _Node* node = nodes[node_key].get();
198 _Node* left_child = nodes[left_child_key].get();
199 _Node* right_child = nodes[right_child_key].get();
200 CHECK(node->status == _Node::_Status::kEmpty)
201 <<
"SetCategoricalTestNode: cannot modify a non-empty node";
202 CHECK(!left_child->parent)
203 <<
"SetCategoricalTestNode: node designated as left child already has a parent";
204 CHECK(!right_child->parent)
205 <<
"SetCategoricalTestNode: node designated as right child already has a parent";
206 CHECK(left_child != tree.root && right_child != tree.root)
207 <<
"SetCategoricalTestNode: the root node cannot be a child";
208 node->status = _Node::_Status::kCategoricalTest;
209 node->left_child = nodes[left_child_key].get();
210 node->left_child->parent = node;
211 node->right_child = nodes[right_child_key].get();
212 node->right_child->parent = node;
213 node->feature_id = feature_id;
214 node->default_left = default_left;
215 node->left_categories = left_categories;
219 TreeBuilder::SetLeafNode(
int node_key,
tl_float leaf_value) {
220 auto& tree = pimpl->tree;
221 auto& nodes = tree.nodes;
222 CHECK_GT(nodes.count(node_key), 0) <<
"SetLeafNode: no node found with node_key";
223 _Node* node = nodes[node_key].get();
224 CHECK(node->status == _Node::_Status::kEmpty) <<
"SetLeafNode: cannot modify a non-empty node";
225 node->status = _Node::_Status::kLeaf;
226 node->info.leaf_value = leaf_value;
230 TreeBuilder::SetLeafVectorNode(
int node_key,
const std::vector<tl_float>& leaf_vector) {
231 auto& tree = pimpl->tree;
232 auto& nodes = tree.nodes;
233 CHECK_GT(nodes.count(node_key), 0) <<
"SetLeafVectorNode: no node found with node_key";
234 _Node* node = nodes[node_key].get();
235 CHECK(node->status == _Node::_Status::kEmpty)
236 <<
"SetLeafVectorNode: cannot modify a non-empty node";
237 node->status = _Node::_Status::kLeaf;
238 node->leaf_vector = leaf_vector;
241 ModelBuilder::ModelBuilder(
int num_feature,
int num_output_group,
242 bool random_forest_flag)
245 random_forest_flag)) {}
246 ModelBuilder::~ModelBuilder() =
default;
250 pimpl->cfg.emplace_back(name, value);
255 if (tree_builder ==
nullptr) {
256 LOG(FATAL) <<
"InsertTree: not a valid tree builder";
259 if (tree_builder->ensemble_id !=
nullptr) {
260 LOG(FATAL) <<
"InsertTree: tree is already part of another ensemble";
265 for (
const auto& kv : tree_builder->pimpl->tree.nodes) {
266 const _Node::_Status status = kv.second->status;
267 if (status == _Node::_Status::kNumericalTest ||
268 status == _Node::_Status::kCategoricalTest) {
269 const int fid =
static_cast<int>(kv.second->feature_id);
270 if (fid < 0 || fid >= pimpl->num_feature) {
271 LOG(FATAL) <<
"InsertTree: tree has an invalid split at node " 272 << kv.first <<
": feature id " << kv.second->feature_id <<
" is out of bound";
279 auto& trees = pimpl->trees;
281 trees.push_back(std::move(*tree_builder));
282 tree_builder->ensemble_id =
static_cast<void*
>(
this);
283 return static_cast<int>(trees.size());
285 if (static_cast<size_t>(index) <= trees.size()) {
286 trees.insert(trees.begin() + index, std::move(*tree_builder));
287 tree_builder->ensemble_id =
static_cast<void*
>(
this);
290 LOG(FATAL) <<
"CreateTree: index out of bound";
298 return &pimpl->trees.at(index);
303 return &pimpl->trees.at(index);
308 auto& trees = pimpl->trees;
309 CHECK_LT(static_cast<size_t>(index), trees.size()) <<
"DeleteTree: index out of bound";
310 trees.erase(trees.begin() + index);
320 InitParamAndCheck(&model.
param, pimpl->cfg);
326 int8_t flag_leaf_vector = -1;
328 for (
const auto& _tree_builder : pimpl->trees) {
329 const auto& _tree = _tree_builder.pimpl->tree;
330 CHECK(_tree.root) <<
"CommitModel: a tree has no root node";
331 CHECK(_tree.root->status != _Node::_Status::kEmpty)
332 <<
"SetRootNode: cannot set an empty node as root";
333 model.
trees.emplace_back();
339 std::queue<std::pair<const _Node*, int>> Q;
340 Q.push({_tree.root, 0});
344 std::tie(node, nid) = Q.front();
346 CHECK(node->status != _Node::_Status::kEmpty)
347 <<
"CommitModel: encountered an empty node in the middle of a tree";
348 if (node->status == _Node::_Status::kNumericalTest) {
349 CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
350 CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
351 CHECK(node->left_child->parent == node) <<
"CommitModel: left child has wrong parent";
352 CHECK(node->right_child->parent == node) <<
"CommitModel: right child has wrong parent";
355 node->default_left, node->op);
356 Q.push({node->left_child, tree.
LeftChild(nid)});
357 Q.push({node->right_child, tree.
RightChild(nid)});
358 }
else if (node->status == _Node::_Status::kCategoricalTest) {
359 CHECK(node->left_child) <<
"CommitModel: a test node lacks a left child";
360 CHECK(node->right_child) <<
"CommitModel: a test node lacks a right child";
361 CHECK(node->left_child->parent == node) <<
"CommitModel: left child has wrong parent";
362 CHECK(node->right_child->parent == node) <<
"CommitModel: right child has wrong parent";
365 false, node->left_categories);
366 Q.push({node->left_child, tree.
LeftChild(nid)});
367 Q.push({node->right_child, tree.
RightChild(nid)});
369 CHECK(node->left_child ==
nullptr && node->right_child ==
nullptr)
370 <<
"CommitModel: a leaf node cannot have children";
371 if (!node->leaf_vector.empty()) {
372 CHECK_NE(flag_leaf_vector, 0)
373 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node uses a leaf vector, " 374 <<
"*every* leaf node must use a leaf vector";
375 flag_leaf_vector = 1;
377 <<
"CommitModel: The length of leaf vector must be identical to the number of output " 381 CHECK_NE(flag_leaf_vector, 1)
382 <<
"CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf " 383 <<
"vector, *no other* leaf node can use a leaf vector";
384 flag_leaf_vector = 0;
385 tree.
SetLeaf(nid, node->info.leaf_value);
390 if (flag_leaf_vector == 0) {
394 <<
"To use a random forest for multi-class classification, each leaf node must output a " 395 <<
"leaf vector specifying a probability distribution";
397 <<
"For multi-class classifiers with gradient boosted trees, the number of trees must be " 398 <<
"evenly divisible by the number of output groups";
400 }
else if (flag_leaf_vector == 1) {
403 <<
"In multi-class classifiers with gradient boosted trees, each leaf node must output a " 404 <<
"single floating-point value.";
406 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
408 *out_model = std::move(model);
void CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
int num_output_group
number of output groups, used for multi-class classification; set to 1 for everything else ...
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
ModelParam param
extra parameters
model structure for tree ensemble
void SetModelParam(const char *name, const char *value)
Set a model parameter.
in-memory representation of a decision tree
TreeBuilder * GetTree(int index)
Get a pointer to the tree at the given index in the ensemble.
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
void SetLeafVector(int nid, const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
void DeleteTree(int index)
Remove a tree from the ensemble.
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
Setters: create a numerical split
int LeftChild(int nid) const
Getters: index of the node's left child
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature - 1].
Operator
comparison operators