10 #include <dmlc/registry.h> 13 #include "../c_api/c_api_error.h" 15 #define CHECK_EARLY_RETURN(x, msg) \ 17 TreeliteAPISetLastError(msg); \ 18 dmlc::LogMessage(__FILE__, __LINE__).stream() << msg; \ 27 enum class _Status : int8_t {
28 kEmpty, kNumericalTest, kCategoricalTest, kLeaf
37 std::vector<treelite::tl_float> leaf_vector;
59 std::vector<uint32_t> left_categories;
62 : status(_Status::kEmpty),
63 parent(nullptr), left_child(nullptr), right_child(nullptr) {}
68 std::unordered_map<int, std::shared_ptr<_Node>> nodes;
69 inline _Tree() : root(nullptr), nodes() {}
77 DMLC_REGISTRY_FILE_TAG(builder);
85 std::vector<TreeBuilder> trees;
88 bool random_forest_flag;
89 std::vector<std::pair<std::string, std::string>> cfg;
91 bool random_forest_flag)
92 : trees(), num_feature(num_feature),
93 num_output_group(num_output_group),
94 random_forest_flag(random_forest_flag), cfg() {
95 CHECK_GT(num_feature, 0) <<
"ModelBuilder: num_feature must be positive";
96 CHECK_GT(num_output_group, 0)
97 <<
"ModelBuilder: num_output_group must be positive";
101 TreeBuilder::TreeBuilder()
102 : pimpl(common::make_unique<TreeBuilderImpl>()), ensemble_id(
nullptr) {}
103 TreeBuilder::~TreeBuilder() {}
106 TreeBuilder::CreateNode(
int node_key) {
107 auto& nodes = pimpl->tree.nodes;
108 CHECK_EARLY_RETURN(nodes.count(node_key) == 0,
109 "CreateNode: nodes with duplicate keys are not allowed");
110 nodes[node_key] = common::make_unique<_Node>();
115 TreeBuilder::DeleteNode(
int node_key) {
116 auto& tree = pimpl->tree;
117 auto& nodes = tree.nodes;
118 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
119 "DeleteNode: no node found with node_key");
120 _Node* node = nodes[node_key].get();
121 if (tree.root == node) {
124 if (node->left_child !=
nullptr) {
125 node->left_child->parent =
nullptr;
127 if (node->right_child !=
nullptr) {
128 node->right_child->parent =
nullptr;
130 nodes.erase(node_key);
135 TreeBuilder::SetRootNode(
int node_key) {
136 auto& tree = pimpl->tree;
137 auto& nodes = tree.nodes;
138 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
139 "SetRootNode: no node found with node_key");
140 _Node* node = nodes[node_key].get();
141 CHECK_EARLY_RETURN(node->status != _Node::_Status::kLeaf,
142 "SetRootNode: cannot set a leaf node as root");
143 CHECK_EARLY_RETURN(node->parent ==
nullptr,
144 "SetRootNode: a root node cannot have a parent");
150 TreeBuilder::SetNumericalTestNode(
int node_key,
153 bool default_left,
int left_child_key,
154 int right_child_key) {
155 auto& tree = pimpl->tree;
156 auto& nodes = tree.nodes;
157 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
158 "SetNumericalTestNode: no node found with node_key");
159 CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
160 "SetNumericalTestNode: no node found with left_child_key");
161 CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
162 "SetNumericalTestNode: no node found with right_child_key");
163 _Node* node = nodes[node_key].get();
164 _Node* left_child = nodes[left_child_key].get();
165 _Node* right_child = nodes[right_child_key].get();
166 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
167 "SetNumericalTestNode: cannot modify a non-empty node");
168 CHECK_EARLY_RETURN(left_child->parent ==
nullptr,
169 "SetNumericalTestNode: node designated as left child already has " 171 CHECK_EARLY_RETURN(right_child->parent ==
nullptr,
172 "SetNumericalTestNode: node designated as right child already has " 174 CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
175 "SetNumericalTestNode: the root node cannot be a child");
176 node->status = _Node::_Status::kNumericalTest;
177 node->left_child = nodes[left_child_key].get();
178 node->left_child->parent = node;
179 node->right_child = nodes[right_child_key].get();
180 node->right_child->parent = node;
181 node->feature_id = feature_id;
182 node->default_left = default_left;
183 node->info.threshold = threshold;
189 TreeBuilder::SetCategoricalTestNode(
int node_key,
191 const std::vector<uint32_t>& left_categories,
192 bool default_left,
int left_child_key,
193 int right_child_key) {
194 auto& tree = pimpl->tree;
195 auto& nodes = tree.nodes;
196 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
197 "SetCategoricalTestNode: no node found with node_key");
198 CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
199 "SetCategoricalTestNode: no node found with left_child_key");
200 CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
201 "SetCategoricalTestNode: no node found with right_child_key");
202 _Node* node = nodes[node_key].get();
203 _Node* left_child = nodes[left_child_key].get();
204 _Node* right_child = nodes[right_child_key].get();
205 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
206 "SetCategoricalTestNode: cannot modify a non-empty node");
207 CHECK_EARLY_RETURN(left_child->parent ==
nullptr,
208 "SetCategoricalTestNode: node designated as left child already " 210 CHECK_EARLY_RETURN(right_child->parent ==
nullptr,
211 "SetCategoricalTestNode: node designated as right child already " 213 CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
214 "SetCategoricalTestNode: the root node cannot be a child");
215 node->status = _Node::_Status::kCategoricalTest;
216 node->left_child = nodes[left_child_key].get();
217 node->left_child->parent = node;
218 node->right_child = nodes[right_child_key].get();
219 node->right_child->parent = node;
220 node->feature_id = feature_id;
221 node->default_left = default_left;
222 node->left_categories = left_categories;
227 TreeBuilder::SetLeafNode(
int node_key,
tl_float leaf_value) {
228 auto& tree = pimpl->tree;
229 auto& nodes = tree.nodes;
230 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
231 "SetLeafNode: no node found with node_key");
232 _Node* node = nodes[node_key].get();
233 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
234 "SetLeafNode: cannot modify a non-empty node");
235 node->status = _Node::_Status::kLeaf;
236 node->info.leaf_value = leaf_value;
241 TreeBuilder::SetLeafVectorNode(
int node_key,
242 const std::vector<tl_float>& leaf_vector) {
243 auto& tree = pimpl->tree;
244 auto& nodes = tree.nodes;
245 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
246 "SetLeafVectorNode: no node found with node_key");
247 _Node* node = nodes[node_key].get();
248 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
249 "SetLeafVectorNode: cannot modify a non-empty node");
250 node->status = _Node::_Status::kLeaf;
251 node->leaf_vector = leaf_vector;
255 ModelBuilder::ModelBuilder(
int num_feature,
int num_output_group,
256 bool random_forest_flag)
259 random_forest_flag)) {}
260 ModelBuilder::~ModelBuilder() {}
264 pimpl->cfg.emplace_back(name, value);
269 if (tree_builder ==
nullptr) {
270 const char* msg =
"InsertTree: not a valid tree builder";
275 if (tree_builder->ensemble_id !=
nullptr) {
276 const char* msg =
"InsertTree: tree is already part of another ensemble";
283 for (
const auto& kv : tree_builder->pimpl->tree.nodes) {
284 const _Node::_Status status = kv.second->status;
285 if (status == _Node::_Status::kNumericalTest ||
286 status == _Node::_Status::kCategoricalTest) {
287 const int fid =
static_cast<int>(kv.second->feature_id);
288 if (fid < 0 || fid >= pimpl->num_feature) {
289 std::ostringstream oss;
290 oss <<
"InsertTree: tree has an invalid split at node " 291 << kv.first <<
": feature id " << kv.second->feature_id
292 <<
" is out of bound";
293 const std::string str = oss.str();
294 const char* msg = str.c_str();
303 auto& trees = pimpl->trees;
305 trees.push_back(std::move(*tree_builder));
306 tree_builder->ensemble_id =
static_cast<void*
>(
this);
307 return static_cast<int>(trees.size());
309 if (static_cast<size_t>(index) <= trees.size()) {
310 trees.insert(trees.begin() + index, std::move(*tree_builder));
311 tree_builder->ensemble_id =
static_cast<void*
>(
this);
314 LOG(INFO) <<
"CreateTree: index out of bound";
322 return pimpl->trees[index];
327 return pimpl->trees[index];
332 auto& trees = pimpl->trees;
333 CHECK_EARLY_RETURN(static_cast<size_t>(index) < trees.size(),
334 "DeleteTree: index out of bound");
335 trees.erase(trees.begin() + index);
346 InitParamAndCheck(&model.
param, pimpl->cfg);
352 int8_t flag_leaf_vector = -1;
354 for (
const auto& _tree_builder : pimpl->trees) {
355 const auto& _tree = _tree_builder.pimpl->tree;
356 CHECK_EARLY_RETURN(_tree.root !=
nullptr,
357 "CommitModel: a tree has no root node");
358 model.
trees.emplace_back();
364 std::queue<std::pair<const _Node*, int>> Q;
365 Q.push({_tree.root, 0});
369 std::tie(node, nid) = Q.front(); Q.pop();
370 CHECK_EARLY_RETURN(node->status != _Node::_Status::kEmpty,
371 "CommitModel: encountered an empty node in the middle of a tree");
372 if (node->status == _Node::_Status::kNumericalTest) {
373 CHECK_EARLY_RETURN(node->left_child !=
nullptr,
374 "CommitModel: a test node lacks a left child");
375 CHECK_EARLY_RETURN(node->right_child !=
nullptr,
376 "CommitModel: a test node lacks a right child");
377 CHECK_EARLY_RETURN(node->left_child->parent == node,
378 "CommitModel: left child has wrong parent");
379 CHECK_EARLY_RETURN(node->right_child->parent == node,
380 "CommitModel: right child has wrong parent");
382 tree[nid].set_numerical_split(node->feature_id, node->info.threshold,
383 node->default_left, node->op);
384 Q.push({node->left_child, tree[nid].cleft()});
385 Q.push({node->right_child, tree[nid].cright()});
386 }
else if (node->status == _Node::_Status::kCategoricalTest) {
387 CHECK_EARLY_RETURN(node->left_child !=
nullptr,
388 "CommitModel: a test node lacks a left child");
389 CHECK_EARLY_RETURN(node->right_child !=
nullptr,
390 "CommitModel: a test node lacks a right child");
391 CHECK_EARLY_RETURN(node->left_child->parent == node,
392 "CommitModel: left child has wrong parent");
393 CHECK_EARLY_RETURN(node->right_child->parent == node,
394 "CommitModel: right child has wrong parent");
396 tree[nid].set_categorical_split(node->feature_id, node->default_left,
397 node->left_categories);
398 Q.push({node->left_child, tree[nid].cleft()});
399 Q.push({node->right_child, tree[nid].cright()});
401 CHECK_EARLY_RETURN(node->left_child ==
nullptr 402 && node->right_child ==
nullptr,
403 "CommitModel: a leaf node cannot have children");
404 if (!node->leaf_vector.empty()) {
405 CHECK_EARLY_RETURN(flag_leaf_vector != 0,
406 "CommitModel: Inconsistent use of leaf vector: " 407 "if one leaf node uses a leaf vector, " 408 "*every* leaf node must use a leaf vector");
409 flag_leaf_vector = 1;
411 "CommitModel: The length of leaf vector must be " 412 "identical to the number of output groups");
413 tree[nid].set_leaf_vector(node->leaf_vector);
415 CHECK_EARLY_RETURN(flag_leaf_vector != 1,
416 "CommitModel: Inconsistent use of leaf vector: " 417 "if one leaf node does not use a leaf vector, " 418 "*no other* leaf node can use a leaf vector");
419 flag_leaf_vector = 0;
420 tree[nid].set_leaf(node->info.leaf_value);
425 if (flag_leaf_vector == 0) {
429 "To use a random forest for multi-class classification, each leaf " 430 "node must output a leaf vector specifying a probability " 433 "For multi-class classifiers with gradient boosted trees, the number " 434 "of trees must be evenly divisible by the number of output groups");
436 }
else if (flag_leaf_vector == 1) {
439 "In multi-class classifiers with gradient boosted trees, each leaf " 440 "node must output a single floating-point value.");
442 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
444 *out_model = std::move(model);
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else.
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
ModelParam param
extra parameters
void SetModelParam(const char *name, const char *value)
Set a model parameter.
in-memory representation of a decision tree
bool DeleteTree(int index)
Remove a tree from the ensemble.
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees.
void TreeliteAPISetLastError(const char *msg)
Set the last error message needed by C API.
void AddChilds(int nid)
add child nodes to node
bool CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and num_feature - 1.
Operator
comparison operators