treelite
build.cc
Go to the documentation of this file.
1 
6 #include "./builder.h"
7 
8 namespace treelite {
9 namespace compiler {
10 
11 DMLC_REGISTRY_FILE_TAG(build);
12 
13 void ASTBuilder::BuildAST(const Model& model) {
14  this->output_vector_flag
15  = (model.num_output_group > 1 && model.random_forest_flag);
16  this->num_feature = model.num_feature;
17  this->num_output_group = model.num_output_group;
18  this->random_forest_flag = model.random_forest_flag;
19 
20  this->main_node = AddNode<MainNode>(nullptr, model.param.global_bias,
21  model.random_forest_flag,
22  model.trees.size(),
23  model.num_feature);
24  ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
25  this->main_node->children.push_back(ac);
26  for (int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
27  ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, ac);
28  ac->children.push_back(tree_head);
29  }
30  this->model_param = model.param.__DICT__();
31 }
32 
33 ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id,
34  ASTNode* parent) {
35  return BuildASTFromTree(tree, tree_id, 0, parent);
36 }
37 
38 ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id, int nid,
39  ASTNode* parent) {
40  const Tree::Node& node = tree[nid];
41  ASTNode* ast_node = nullptr;
42  if (node.is_leaf()) {
43  if (this->output_vector_flag) {
44  ast_node = AddNode<OutputNode>(parent, node.leaf_vector());
45  } else {
46  ast_node = AddNode<OutputNode>(parent, node.leaf_value());
47  }
48  } else {
49  if (node.split_type() == SplitFeatureType::kNumerical) {
50  ast_node = AddNode<NumericalConditionNode>(parent,
51  node.split_index(),
52  node.default_left(),
53  false,
54  node.comparison_op(),
55  ThresholdVariant(static_cast<tl_float>(node.threshold())));
56  } else {
57  ast_node = AddNode<CategoricalConditionNode>(parent,
58  node.split_index(),
59  node.default_left(),
60  node.left_categories());
61  }
62  if (node.has_gain()) {
63  dynamic_cast<ConditionNode*>(ast_node)->gain = node.gain();
64  }
65  ast_node->children.push_back(BuildASTFromTree(tree, tree_id,
66  node.cleft(), ast_node));
67  ast_node->children.push_back(BuildASTFromTree(tree, tree_id,
68  node.cright(), ast_node));
69  }
70  ast_node->node_id = nid;
71  ast_node->tree_id = tree_id;
72  if (node.has_data_count()) {
73  ast_node->data_count = node.data_count();
74  }
75  if (node.has_sum_hess()) {
76  ast_node->sum_hess = node.sum_hess();
77  }
78 
79  return ast_node;
80 }
81 
82 } // namespace compiler
83 } // namespace treelite
AST Builder class.