Treelite
build.cc
Go to the documentation of this file.
1 
6 #include <dmlc/registry.h>
7 #include "./builder.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 DMLC_REGISTRY_FILE_TAG(build);
13 
14 void ASTBuilder::BuildAST(const Model& model) {
15  this->output_vector_flag
16  = (model.num_output_group > 1 && model.random_forest_flag);
17  this->num_feature = model.num_feature;
18  this->num_output_group = model.num_output_group;
19  this->random_forest_flag = model.random_forest_flag;
20 
21  this->main_node = AddNode<MainNode>(nullptr, model.param.global_bias,
22  model.random_forest_flag,
23  model.trees.size(),
24  model.num_feature);
25  ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
26  this->main_node->children.push_back(ac);
27  for (int tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
28  ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, ac);
29  ac->children.push_back(tree_head);
30  }
31  this->model_param = model.param.__DICT__();
32 }
33 
34 ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id,
35  ASTNode* parent) {
36  return BuildASTFromTree(tree, tree_id, 0, parent);
37 }
38 
39 ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id, int nid,
40  ASTNode* parent) {
41  ASTNode* ast_node = nullptr;
42  if (tree.IsLeaf(nid)) {
43  if (this->output_vector_flag) {
44  ast_node = AddNode<OutputNode>(parent, tree.LeafVector(nid));
45  } else {
46  ast_node = AddNode<OutputNode>(parent, tree.LeafValue(nid));
47  }
48  } else {
49  if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
50  ast_node = AddNode<NumericalConditionNode>(parent,
51  tree.SplitIndex(nid),
52  tree.DefaultLeft(nid),
53  false,
54  tree.ComparisonOp(nid),
55  ThresholdVariant(static_cast<tl_float>(
56  tree.Threshold(nid))));
57  } else {
58  ast_node = AddNode<CategoricalConditionNode>(parent,
59  tree.SplitIndex(nid),
60  tree.DefaultLeft(nid),
61  tree.LeftCategories(nid),
62  tree.MissingCategoryToZero(nid));
63  }
64  if (tree.HasGain(nid)) {
65  dynamic_cast<ConditionNode*>(ast_node)->gain = tree.Gain(nid);
66  }
67  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
68  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
69  }
70  ast_node->node_id = nid;
71  ast_node->tree_id = tree_id;
72  if (tree.HasDataCount(nid)) {
73  ast_node->data_count = tree.DataCount(nid);
74  }
75  if (tree.HasSumHess(nid)) {
76  ast_node->sum_hess = tree.SumHess(nid);
77  }
78 
79  return ast_node;
80 }
81 
82 } // namespace compiler
83 } // namespace treelite
AST Builder class.