Treelite
build.cc
Go to the documentation of this file.
1 
6 #include "./builder.h"
7 
8 namespace treelite {
9 namespace compiler {
10 
11 template <typename ThresholdType, typename LeafOutputType>
12 void
13 ASTBuilder<ThresholdType, LeafOutputType>::BuildAST(
14  const ModelImpl<ThresholdType, LeafOutputType>& model) {
15  this->output_vector_flag = (model.task_param.leaf_vector_size > 1);
16  this->num_feature = model.num_feature;
17  this->average_output_flag = model.average_tree_output;
18 
19  this->main_node = AddNode<MainNode>(nullptr, model.param.global_bias,
20  model.average_tree_output,
21  static_cast<int>(model.trees.size()),
22  model.num_feature);
23  ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
24  this->main_node->children.push_back(ac);
25  for (size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
26  ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], static_cast<int>(tree_id), 0, ac);
27  ac->children.push_back(tree_head);
28  }
29  this->model_param = model.param.__DICT__();
30 }
31 
32 template <typename ThresholdType, typename LeafOutputType>
33 ASTNode*
34 ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
35  const Tree<ThresholdType, LeafOutputType>& tree, int tree_id, int nid, ASTNode* parent) {
36  ASTNode* ast_node = nullptr;
37  if (tree.IsLeaf(nid)) {
38  if (this->output_vector_flag) {
39  ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafVector(nid));
40  } else {
41  ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafValue(nid));
42  }
43  } else {
44  if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
45  ast_node = AddNode<NumericalConditionNode<ThresholdType>>(
46  parent,
47  tree.SplitIndex(nid),
48  tree.DefaultLeft(nid),
49  false,
50  tree.ComparisonOp(nid),
51  ThresholdVariant<ThresholdType>(tree.Threshold(nid)));
52  } else {
53  ast_node = AddNode<CategoricalConditionNode>(
54  parent,
55  tree.SplitIndex(nid),
56  tree.DefaultLeft(nid),
57  tree.MatchingCategories(nid),
58  tree.CategoriesListRightChild(nid));
59  }
60  if (tree.HasGain(nid)) {
61  dynamic_cast<ConditionNode*>(ast_node)->gain = tree.Gain(nid);
62  }
63  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
64  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
65  }
66  ast_node->node_id = nid;
67  ast_node->tree_id = tree_id;
68  if (tree.HasDataCount(nid)) {
69  ast_node->data_count = tree.DataCount(nid);
70  }
71  if (tree.HasSumHess(nid)) {
72  ast_node->sum_hess = tree.SumHess(nid);
73  }
74 
75  return ast_node;
76 }
77 
78 
79 template void ASTBuilder<float, uint32_t>::BuildAST(const ModelImpl<float, uint32_t>&);
80 template void ASTBuilder<float, float>::BuildAST(const ModelImpl<float, float>&);
81 template void ASTBuilder<double, uint32_t>::BuildAST(const ModelImpl<double, uint32_t>&);
82 template void ASTBuilder<double, double>::BuildAST(const ModelImpl<double, double>&);
83 template ASTNode* ASTBuilder<float, uint32_t>::BuildASTFromTree(
84  const Tree<float, uint32_t>&, int, int, ASTNode*);
85 template ASTNode* ASTBuilder<float, float>::BuildASTFromTree(
86  const Tree<float, float>&, int, int, ASTNode*);
87 template ASTNode* ASTBuilder<double, uint32_t>::BuildASTFromTree(
88  const Tree<double, uint32_t>&, int, int, ASTNode*);
89 template ASTNode* ASTBuilder<double, double>::BuildASTFromTree(
90  const Tree<double, double>&, int, int, ASTNode*);
91 
92 } // namespace compiler
93 } // namespace treelite
AST Builder class.