treelite
build.cc
1 #include "./builder.h"
2 
3 namespace treelite {
4 namespace compiler {
5 
/* NOTE(review): dmlc-core registry file tag named `build`. Presumably this
 * lets a matching DMLC_REGISTRY_LINK_TAG(build) elsewhere force this
 * translation unit to be linked into static builds -- confirm against
 * dmlc/registry.h. */
6 DMLC_REGISTRY_FILE_TAG(build);
7 
8 void ASTBuilder::Build(const Model& model) {
9  this->output_vector_flag
10  = (model.num_output_group > 1 && model.random_forest_flag);
11  this->num_feature = model.num_feature;
12 
13  this->main_node = AddNode<MainNode>(nullptr, model.param.global_bias,
14  model.random_forest_flag,
15  model.trees.size(),
16  model.num_feature);
17  ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
18  this->main_node->children.push_back(ac);
19  for (int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
20  ASTNode* tree_head = WalkTree(model.trees[tree_id], ac);
21  /* store tree ID in descendant nodes */
22  std::function<void(ASTNode*)> func;
23  func = [tree_id, &func](ASTNode* node) -> void {
24  node->tree_id = tree_id;
25  for (ASTNode* child : node->children) {
26  func(child);
27  }
28  };
29  func(tree_head);
30  ac->children.push_back(tree_head);
31  }
32 }
33 
34 ASTNode* ASTBuilder::WalkTree(const Tree& tree, ASTNode* parent) {
35  return WalkTree(tree, 0, parent);
36 }
37 
38 ASTNode* ASTBuilder::WalkTree(const Tree& tree, int nid, ASTNode* parent) {
39  const Tree::Node& node = tree[nid];
40  ASTNode* ast_node = nullptr;
41  if (node.is_leaf()) {
42  if (this->output_vector_flag) {
43  ast_node = AddNode<OutputNode>(parent, node.leaf_vector());
44  } else {
45  ast_node = AddNode<OutputNode>(parent, node.leaf_value());
46  }
47  } else {
48  if (node.split_type() == SplitFeatureType::kNumerical) {
49  ast_node = AddNode<NumericalConditionNode>(parent,
50  node.split_index(),
51  node.default_left(),
52  false,
53  node.comparison_op(),
54  ThresholdVariant(static_cast<tl_float>(node.threshold())));
55  } else {
56  ast_node = AddNode<CategoricalConditionNode>(parent,
57  node.split_index(),
58  node.default_left(),
59  node.left_categories());
60  }
61  ast_node->children.push_back(WalkTree(tree, node.cleft(), ast_node));
62  ast_node->children.push_back(WalkTree(tree, node.cright(), ast_node));
63  }
64  ast_node->node_id = nid;
65 
66  return ast_node;
67 }
68 
69 } // namespace compiler
70 } // namespace treelite