10 #include <dmlc/registry.h> 13 #include "../c_api/c_api_error.h" 15 #define CHECK_EARLY_RETURN(x, msg) \ 17 TreeliteAPISetLastError(msg); \ 18 dmlc::LogMessage(__FILE__, __LINE__).stream() << msg; \ 27 enum class _Status : int8_t {
// Node lifecycle: CreateNode makes a kEmpty node; exactly one of SetNumericalTestNode,
// SetCategoricalTestNode, SetLeafNode or SetLeafVectorNode then changes its status,
// and each of those refuses a node whose status is not kEmpty.
28 kEmpty, kNumericalTest, kCategoricalTest, kLeaf
// leaf_vector is filled only by SetLeafVectorNode. CommitModel treats a leaf with a
// non-empty leaf_vector as a vector leaf and otherwise reads info.leaf_value.
37 std::vector<treelite::tl_float> leaf_vector;
// left_categories is set by SetCategoricalTestNode and handed to set_categorical_split
// in CommitModel (presumably the categories routed to the left child — confirm).
59 std::vector<uint32_t> left_categories;
// NOTE(review): as far as this listing shows, the constructor initializes only status and
// the parent/child links; feature_id, default_left and info stay unset until a Set*Node call.
62 : status(_Status::kEmpty),
63 parent(nullptr), left_child(nullptr), right_child(nullptr) {}
// nodes owns every _Node through a shared_ptr, keyed by the caller-chosen node_key.
// The parent/left_child/right_child links are raw, non-owning pointers taken with .get(),
// so they dangle once the owning map entry is erased.
68 std::unordered_map<int, std::shared_ptr<_Node>> nodes;
69 inline _Tree() : root(nullptr), nodes() {}
// Tags this translation unit for dmlc's registry linking (presumably paired with a
// DMLC_REGISTRY_LINK_TAG(builder) elsewhere — confirm).
77 DMLC_REGISTRY_FILE_TAG(builder);
// Trees of the ensemble, in order: InsertTree moves TreeBuilders in, DeleteTree erases
// them, and CommitModel converts each one into a Tree.
85 std::vector<TreeBuilder> trees;
88 bool random_forest_flag;
// (name, value) pairs appended by SetModelParam. They are only parsed and validated
// when CommitModel passes them to InitParamAndCheck.
89 std::vector<std::pair<std::string, std::string>> cfg;
91 bool random_forest_flag)
92 : trees(), num_feature(num_feature),
93 num_output_group(num_output_group),
94 random_forest_flag(random_forest_flag), cfg() {
// Non-positive sizes are rejected with dmlc's fatal CHECK_GT, not with the
// CHECK_EARLY_RETURN error path used by the builder methods.
95 CHECK_GT(num_feature, 0) <<
"ModelBuilder: num_feature must be positive";
96 CHECK_GT(num_output_group, 0)
97 <<
"ModelBuilder: num_output_group must be positive";
// A fresh TreeBuilder owns an empty TreeBuilderImpl and belongs to no ensemble yet
// (ensemble_id == nullptr); ModelBuilder::InsertTree rejects builders whose
// ensemble_id is already set.
101 TreeBuilder::TreeBuilder()
102 : pimpl(common::make_unique<TreeBuilderImpl>()), ensemble_id(
nullptr) {}
// Out-of-line destructor, presumably so pimpl can destroy TreeBuilderImpl where its
// definition is visible — confirm against the header.
103 TreeBuilder::~TreeBuilder() {}
// CreateNode: adds a new, empty (kEmpty) node under the caller-chosen node_key.
// A key that is already in use is rejected through CHECK_EARLY_RETURN, which records
// the message with TreeliteAPISetLastError and returns early.
106 TreeBuilder::CreateNode(
int node_key) {
107 auto& nodes = pimpl->tree.nodes;
108 CHECK_EARLY_RETURN(nodes.count(node_key) == 0,
109 "CreateNode: nodes with duplicate keys are not allowed");
// The unique_ptr converts into the map's shared_ptr<_Node>, which becomes the sole owner.
110 nodes[node_key] = common::make_unique<_Node>();
// DeleteNode: removes the node stored under node_key. Its children get their parent
// link cleared so they can be attached to another test node later.
115 TreeBuilder::DeleteNode(
int node_key) {
116 auto& tree = pimpl->tree;
117 auto& nodes = tree.nodes;
118 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
119 "DeleteNode: no node found with node_key");
120 _Node* node = nodes[node_key].get();
121 if (tree.root == node) {
124 if (node->left_child !=
nullptr) {
125 node->left_child->parent =
nullptr;
127 if (node->right_child !=
nullptr) {
128 node->right_child->parent =
nullptr;
// NOTE(review): only the children's back-links are cleared. If the deleted node has a
// parent, that parent's left_child/right_child still points at this _Node. In the code
// shown, the map entry is the only owning shared_ptr, so the erase below destroys it.
// CommitModel later dereferences node->left_child->parent / node->right_child->parent,
// which makes this a use-after-free. The parent cannot be repaired either, because the
// Set*TestNode methods refuse non-empty nodes. Consider rejecting deletion of a node
// whose parent != nullptr, or nulling the parent's matching child pointer and resetting
// the parent to kEmpty.
130 nodes.erase(node_key);
// SetRootNode: designates an existing node as this tree's root. A node that is already
// someone's child is refused. The Set*TestNode methods in turn refuse to take the root
// as a child, so the root never acquires a parent.
135 TreeBuilder::SetRootNode(
int node_key) {
136 auto& tree = pimpl->tree;
137 auto& nodes = tree.nodes;
138 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
139 "SetRootNode: no node found with node_key");
140 _Node* node = nodes[node_key].get();
141 CHECK_EARLY_RETURN(node->parent ==
nullptr,
142 "SetRootNode: a root node cannot have a parent");
// SetNumericalTestNode: turns an existing empty node into a numerical split. It records
// feature_id, threshold and default_left, and links the nodes under left_child_key and
// right_child_key as its children. All three keys must exist, the node must be kEmpty,
// each child must be parentless, and neither child may be the root.
148 TreeBuilder::SetNumericalTestNode(
int node_key,
151 bool default_left,
int left_child_key,
152 int right_child_key) {
153 auto& tree = pimpl->tree;
154 auto& nodes = tree.nodes;
155 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
156 "SetNumericalTestNode: no node found with node_key");
157 CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
158 "SetNumericalTestNode: no node found with left_child_key");
159 CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
160 "SetNumericalTestNode: no node found with right_child_key");
161 _Node* node = nodes[node_key].get();
162 _Node* left_child = nodes[left_child_key].get();
163 _Node* right_child = nodes[right_child_key].get();
164 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
165 "SetNumericalTestNode: cannot modify a non-empty node");
166 CHECK_EARLY_RETURN(left_child->parent ==
nullptr,
167 "SetNumericalTestNode: node designated as left child already has " 169 CHECK_EARLY_RETURN(right_child->parent ==
nullptr,
170 "SetNumericalTestNode: node designated as right child already has " 172 CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
173 "SetNumericalTestNode: the root node cannot be a child");
// NOTE(review): the checks above do not reject left_child_key == right_child_key. Both
// parent checks pass for the same parentless node, which then becomes both children, and
// CommitModel emits its subtree twice. node_key equal to a child key is not rejected
// either. Also, the nodes[...] lookups below repeat the ones above; the left_child and
// right_child locals could be reused.
174 node->status = _Node::_Status::kNumericalTest;
175 node->left_child = nodes[left_child_key].get();
176 node->left_child->parent = node;
177 node->right_child = nodes[right_child_key].get();
178 node->right_child->parent = node;
179 node->feature_id = feature_id;
180 node->default_left = default_left;
181 node->info.threshold = threshold;
// SetCategoricalTestNode: turns an existing empty node into a categorical split. It
// records feature_id, default_left and left_categories, and links the nodes under
// left_child_key and right_child_key as its children. The preconditions are the same as
// for SetNumericalTestNode: all keys exist, the node is kEmpty, the children are
// parentless, and neither child is the root.
187 TreeBuilder::SetCategoricalTestNode(
int node_key,
189 const std::vector<uint32_t>& left_categories,
190 bool default_left,
int left_child_key,
191 int right_child_key) {
192 auto& tree = pimpl->tree;
193 auto& nodes = tree.nodes;
194 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
195 "SetCategoricalTestNode: no node found with node_key");
196 CHECK_EARLY_RETURN(nodes.count(left_child_key) > 0,
197 "SetCategoricalTestNode: no node found with left_child_key");
198 CHECK_EARLY_RETURN(nodes.count(right_child_key) > 0,
199 "SetCategoricalTestNode: no node found with right_child_key");
200 _Node* node = nodes[node_key].get();
201 _Node* left_child = nodes[left_child_key].get();
202 _Node* right_child = nodes[right_child_key].get();
203 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
204 "SetCategoricalTestNode: cannot modify a non-empty node");
205 CHECK_EARLY_RETURN(left_child->parent ==
nullptr,
206 "SetCategoricalTestNode: node designated as left child already " 208 CHECK_EARLY_RETURN(right_child->parent ==
nullptr,
209 "SetCategoricalTestNode: node designated as right child already " 211 CHECK_EARLY_RETURN(left_child != tree.root && right_child != tree.root,
212 "SetCategoricalTestNode: the root node cannot be a child");
// NOTE(review): as in SetNumericalTestNode, left_child_key == right_child_key (and
// node_key equal to a child key) is not rejected; with identical child keys, CommitModel
// emits the shared subtree twice.
213 node->status = _Node::_Status::kCategoricalTest;
214 node->left_child = nodes[left_child_key].get();
215 node->left_child->parent = node;
216 node->right_child = nodes[right_child_key].get();
217 node->right_child->parent = node;
218 node->feature_id = feature_id;
219 node->default_left = default_left;
220 node->left_categories = left_categories;
// SetLeafNode: turns an existing empty node into a leaf that outputs the scalar
// leaf_value. CommitModel later writes it with set_leaf(info.leaf_value).
225 TreeBuilder::SetLeafNode(
int node_key,
tl_float leaf_value) {
226 auto& tree = pimpl->tree;
227 auto& nodes = tree.nodes;
228 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
229 "SetLeafNode: no node found with node_key");
230 _Node* node = nodes[node_key].get();
231 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
232 "SetLeafNode: cannot modify a non-empty node");
233 node->status = _Node::_Status::kLeaf;
234 node->info.leaf_value = leaf_value;
// SetLeafVectorNode: turns an existing empty node into a leaf that outputs a vector
// (copied from leaf_vector). CommitModel requires its length to equal the number of
// output groups, and requires that either every leaf or no leaf uses a vector.
239 TreeBuilder::SetLeafVectorNode(
int node_key,
240 const std::vector<tl_float>& leaf_vector) {
241 auto& tree = pimpl->tree;
242 auto& nodes = tree.nodes;
243 CHECK_EARLY_RETURN(nodes.count(node_key) > 0,
244 "SetLeafVectorNode: no node found with node_key");
245 _Node* node = nodes[node_key].get();
246 CHECK_EARLY_RETURN(node->status == _Node::_Status::kEmpty,
247 "SetLeafVectorNode: cannot modify a non-empty node");
// NOTE(review): an empty leaf_vector is accepted. CommitModel decides the leaf kind with
// leaf_vector.empty(), so such a node is committed as a scalar leaf via
// set_leaf(info.leaf_value), but this path never assigns info.leaf_value. Consider
// CHECK_EARLY_RETURN(!leaf_vector.empty(), ...).
248 node->status = _Node::_Status::kLeaf;
249 node->leaf_vector = leaf_vector;
// ModelBuilder: presumably forwards the ensemble-wide settings to ModelBuilderImpl's
// constructor (part of the initializer is not shown), which validates num_feature and
// num_output_group with CHECK_GT.
253 ModelBuilder::ModelBuilder(
int num_feature,
int num_output_group,
254 bool random_forest_flag)
257 random_forest_flag)) {}
258 ModelBuilder::~ModelBuilder() {}
// SetModelParam: the (name, value) pair is only recorded here. It is parsed and validated
// later, when CommitModel calls InitParamAndCheck; repeated names are appended, not
// replaced (how duplicates resolve depends on InitParamAndCheck, not shown).
262 pimpl->cfg.emplace_back(name, value);
// InsertTree: rejects a null builder, a builder that is already part of an ensemble, and
// any test node whose feature_id lies outside [0, num_feature). Each rejection is
// reported through TreeliteAPISetLastError. Otherwise it moves *tree_builder into
// pimpl->trees, either appended or inserted before position index.
267 if (tree_builder ==
nullptr) {
268 const char* msg =
"InsertTree: not a valid tree builder";
270 TreeliteAPISetLastError(msg);
273 if (tree_builder->ensemble_id !=
nullptr) {
274 const char* msg =
"InsertTree: tree is already part of another ensemble";
276 TreeliteAPISetLastError(msg);
// Validate every split's feature id against the ensemble's num_feature.
281 for (
const auto& kv : tree_builder->pimpl->tree.nodes) {
282 const _Node::_Status status = kv.second->status;
283 if (status == _Node::_Status::kNumericalTest ||
284 status == _Node::_Status::kCategoricalTest) {
285 const int fid =
static_cast<int>(kv.second->feature_id);
286 if (fid < 0 || fid >= pimpl->num_feature) {
287 std::ostringstream oss;
288 oss <<
"InsertTree: tree has an invalid split at node " 289 << kv.first <<
": feature id " << kv.second->feature_id
290 <<
" is out of bound";
291 const std::string str = oss.str();
292 const char* msg = str.c_str();
294 TreeliteAPISetLastError(msg);
301 auto& trees = pimpl->trees;
303 trees.push_back(std::move(*tree_builder));
// The caller's (moved-from) builder is tagged, so passing the same handle to InsertTree
// again is rejected by the ensemble_id check above.
304 tree_builder->ensemble_id =
static_cast<void*
>(
this);
// NOTE(review): this returns trees.size(), which is one past the index of the tree just
// appended, while the insert path presumably returns the insertion index. If callers use
// the result as an index (e.g. with GetTree/DeleteTree), this is off by one. Confirm the
// intended contract.
305 return static_cast<int>(trees.size());
307 if (static_cast<size_t>(index) <= trees.size()) {
308 trees.insert(trees.begin() + index, std::move(*tree_builder));
309 tree_builder->ensemble_id =
static_cast<void*
>(
this);
// NOTE(review): unlike the failure paths above, this one logs at INFO with the prefix
// "CreateTree" (not "InsertTree"). In the visible lines it also does not call
// TreeliteAPISetLastError, so C-API callers get no error message.
312 LOG(INFO) <<
"CreateTree: index out of bound";
// GetTree (both visible bodies, presumably the const and non-const overloads): index is
// not validated. std::vector::operator[] with an out-of-range or negative index is
// undefined behavior, unlike DeleteTree, which bounds-checks first.
320 return pimpl->trees[index];
325 return pimpl->trees[index];
// DeleteTree: removes the tree at position index.
330 auto& trees = pimpl->trees;
// A negative index converts to a huge size_t, so this single comparison rejects both
// negative and too-large indices.
331 CHECK_EARLY_RETURN(static_cast<size_t>(index) < trees.size(),
332 "DeleteTree: index out of bound");
333 trees.erase(trees.begin() + index);
// CommitModel: applies the recorded parameters, then converts every TreeBuilder into a
// Tree with a breadth-first walk from its root, validating structure along the way.
// *out_model is assigned only at the very end, so a check that exits early through
// CHECK_EARLY_RETURN leaves it untouched.
344 InitParamAndCheck(&model.
param, pimpl->cfg);
// Tri-state over all leaves of all trees: -1 = no leaf seen yet, 0 = scalar leaves,
// 1 = leaf vectors. Mixing the two kinds is rejected below.
350 int8_t flag_leaf_vector = -1;
352 for (
const auto& _tree_builder : pimpl->trees) {
353 const auto& _tree = _tree_builder.pimpl->tree;
354 CHECK_EARLY_RETURN(_tree.root !=
nullptr,
355 "CommitModel: a tree has no root node");
356 model.
trees.emplace_back();
// BFS queue of (builder node, id of the matching node in the output Tree); children
// receive the ids returned by tree[nid].cleft() / tree[nid].cright().
362 std::queue<std::pair<const _Node*, int>> Q;
363 Q.push({_tree.root, 0});
367 std::tie(node, nid) = Q.front(); Q.pop();
368 CHECK_EARLY_RETURN(node->status != _Node::_Status::kEmpty,
369 "CommitModel: encountered an empty node in the middle of a tree");
370 if (node->status == _Node::_Status::kNumericalTest) {
371 CHECK_EARLY_RETURN(node->left_child !=
nullptr,
372 "CommitModel: a test node lacks a left child");
373 CHECK_EARLY_RETURN(node->right_child !=
nullptr,
374 "CommitModel: a test node lacks a right child");
// NOTE(review): if a child was removed with DeleteNode, these pointers dangle (DeleteNode
// does not clear the parent's child link), so dereferencing ->parent here is undefined
// behavior rather than a clean error.
375 CHECK_EARLY_RETURN(node->left_child->parent == node,
376 "CommitModel: left child has wrong parent");
377 CHECK_EARLY_RETURN(node->right_child->parent == node,
378 "CommitModel: right child has wrong parent");
380 tree[nid].set_numerical_split(node->feature_id, node->info.threshold,
381 node->default_left, node->op);
382 Q.push({node->left_child, tree[nid].cleft()});
383 Q.push({node->right_child, tree[nid].cright()});
384 }
else if (node->status == _Node::_Status::kCategoricalTest) {
385 CHECK_EARLY_RETURN(node->left_child !=
nullptr,
386 "CommitModel: a test node lacks a left child");
387 CHECK_EARLY_RETURN(node->right_child !=
nullptr,
388 "CommitModel: a test node lacks a right child");
389 CHECK_EARLY_RETURN(node->left_child->parent == node,
390 "CommitModel: left child has wrong parent");
391 CHECK_EARLY_RETURN(node->right_child->parent == node,
392 "CommitModel: right child has wrong parent");
394 tree[nid].set_categorical_split(node->feature_id, node->default_left,
395 false, node->left_categories);
396 Q.push({node->left_child, tree[nid].cleft()});
397 Q.push({node->right_child, tree[nid].cright()});
399 CHECK_EARLY_RETURN(node->left_child ==
nullptr 400 && node->right_child ==
nullptr,
401 "CommitModel: a leaf node cannot have children");
402 if (!node->leaf_vector.empty()) {
403 CHECK_EARLY_RETURN(flag_leaf_vector != 0,
404 "CommitModel: Inconsistent use of leaf vector: " 405 "if one leaf node uses a leaf vector, " 406 "*every* leaf node must use a leaf vector");
407 flag_leaf_vector = 1;
409 "CommitModel: The length of leaf vector must be " 410 "identical to the number of output groups");
411 tree[nid].set_leaf_vector(node->leaf_vector);
413 CHECK_EARLY_RETURN(flag_leaf_vector != 1,
414 "CommitModel: Inconsistent use of leaf vector: " 415 "if one leaf node does not use a leaf vector, " 416 "*no other* leaf node can use a leaf vector");
417 flag_leaf_vector = 0;
// A leaf made by SetLeafVectorNode with an empty vector also lands here, reading an
// info.leaf_value that was never assigned.
418 tree[nid].set_leaf(node->info.leaf_value);
423 if (flag_leaf_vector == 0) {
427 "To use a random forest for multi-class classification, each leaf " 428 "node must output a leaf vector specifying a probability " 431 "For multi-class classifiers with gradient boosted trees, the number " 432 "of trees must be evenly divisible by the number of output groups");
434 }
else if (flag_leaf_vector == 1) {
437 "In multi-class classifiers with gradient boosted trees, each leaf " 438 "node must output a single floating-point value.");
// NOTE(review): flag_leaf_vector is still -1 only when no leaf was visited. Every tree
// that gets through the loop has a root and at least one leaf, so this happens exactly
// when pimpl->trees is empty. An empty ensemble therefore ends in LOG(FATAL) (abort or
// dmlc::Error, depending on dmlc configuration) instead of a recoverable error; consider
// an explicit early check on pimpl->trees.empty().
440 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
442 *out_model = std::move(model);
int num_output_group
number of output groups for multi-class classification; set to 1 for everything else.
Collection of front-end methods to load or construct an ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for a tree ensemble model
std::vector< Tree > trees
member trees
ModelParam param
extra parameters
void SetModelParam(const char *name, const char *value)
Set a model parameter.
in-memory representation of a decision tree
bool DeleteTree(int index)
Remove a tree from the ensemble.
int InsertTree(TreeBuilder *tree_builder, int index=-1)
Insert a tree at specified location.
bool random_forest_flag
flag for random forests: true for random forests, false for gradient boosted trees
double tl_float
float type to be used internally
void AddChilds(int nid)
add child nodes to the given node
bool CommitModel(Model *out_model)
finalize the model and produce the in-memory representation
TreeBuilder & GetTree(int index)
Get a reference to a tree in the ensemble.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1 (inclusive).
Operator
comparison operators