Treelite
build.cc
Go to the documentation of this file.
1 
6 #include <dmlc/registry.h>
7 #include "./builder.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 DMLC_REGISTRY_FILE_TAG(build);
13 
14 template <typename ThresholdType, typename LeafOutputType>
15 void
16 ASTBuilder<ThresholdType, LeafOutputType>::BuildAST(
17  const ModelImpl<ThresholdType, LeafOutputType>& model) {
18  this->output_vector_flag = (model.task_param.leaf_vector_size > 1);
19  this->num_feature = model.num_feature;
20  this->average_output_flag = model.average_tree_output;
21 
22  this->main_node = AddNode<MainNode>(nullptr, model.param.global_bias,
23  model.average_tree_output,
24  static_cast<int>(model.trees.size()),
25  model.num_feature);
26  ASTNode* ac = AddNode<AccumulatorContextNode>(this->main_node);
27  this->main_node->children.push_back(ac);
28  for (size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
29  ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], static_cast<int>(tree_id), 0, ac);
30  ac->children.push_back(tree_head);
31  }
32  this->model_param = model.param.__DICT__();
33 }
34 
35 template <typename ThresholdType, typename LeafOutputType>
36 ASTNode*
37 ASTBuilder<ThresholdType, LeafOutputType>::BuildASTFromTree(
38  const Tree<ThresholdType, LeafOutputType>& tree, int tree_id, int nid, ASTNode* parent) {
39  ASTNode* ast_node = nullptr;
40  if (tree.IsLeaf(nid)) {
41  if (this->output_vector_flag) {
42  ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafVector(nid));
43  } else {
44  ast_node = AddNode<OutputNode<LeafOutputType>>(parent, tree.LeafValue(nid));
45  }
46  } else {
47  if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
48  ast_node = AddNode<NumericalConditionNode<ThresholdType>>(
49  parent,
50  tree.SplitIndex(nid),
51  tree.DefaultLeft(nid),
52  false,
53  tree.ComparisonOp(nid),
54  ThresholdVariant<ThresholdType>(tree.Threshold(nid)));
55  } else {
56  ast_node = AddNode<CategoricalConditionNode>(
57  parent,
58  tree.SplitIndex(nid),
59  tree.DefaultLeft(nid),
60  tree.MatchingCategories(nid),
61  tree.CategoriesListRightChild(nid));
62  }
63  if (tree.HasGain(nid)) {
64  dynamic_cast<ConditionNode*>(ast_node)->gain = tree.Gain(nid);
65  }
66  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.LeftChild(nid), ast_node));
67  ast_node->children.push_back(BuildASTFromTree(tree, tree_id, tree.RightChild(nid), ast_node));
68  }
69  ast_node->node_id = nid;
70  ast_node->tree_id = tree_id;
71  if (tree.HasDataCount(nid)) {
72  ast_node->data_count = tree.DataCount(nid);
73  }
74  if (tree.HasSumHess(nid)) {
75  ast_node->sum_hess = tree.SumHess(nid);
76  }
77 
78  return ast_node;
79 }
80 
81 
82 template void ASTBuilder<float, uint32_t>::BuildAST(const ModelImpl<float, uint32_t>&);
83 template void ASTBuilder<float, float>::BuildAST(const ModelImpl<float, float>&);
84 template void ASTBuilder<double, uint32_t>::BuildAST(const ModelImpl<double, uint32_t>&);
85 template void ASTBuilder<double, double>::BuildAST(const ModelImpl<double, double>&);
86 template ASTNode* ASTBuilder<float, uint32_t>::BuildASTFromTree(
87  const Tree<float, uint32_t>&, int, int, ASTNode*);
88 template ASTNode* ASTBuilder<float, float>::BuildASTFromTree(
89  const Tree<float, float>&, int, int, ASTNode*);
90 template ASTNode* ASTBuilder<double, uint32_t>::BuildASTFromTree(
91  const Tree<double, uint32_t>&, int, int, ASTNode*);
92 template ASTNode* ASTBuilder<double, double>::BuildASTFromTree(
93  const Tree<double, double>&, int, int, ASTNode*);
94 
95 } // namespace compiler
96 } // namespace treelite
AST builder class (`ASTBuilder`).