treelite
ast.h
1 #ifndef TREELITE_COMPILER_AST_AST_H_
2 #define TREELITE_COMPILER_AST_AST_H_
3 
4 namespace treelite {
5 namespace compiler {
6 
12 inline std::string OpName(Operator op) {
13  switch(op) {
14  case Operator::kEQ: return "==";
15  case Operator::kLT: return "<";
16  case Operator::kLE: return "<=";
17  case Operator::kGT: return ">";
18  case Operator::kGE: return ">=";
19  default: return "";
20  }
21 }
22 
/*!
 * \brief branch annotation stored in condition nodes; one byte wide.
 *        Enumerators are numbered consecutively starting from zero.
 */
enum class BranchHint : uint8_t {
  kNone,     // == 0
  kLikely,   // == 1
  kUnlikely  // == 2
};
29 
30 inline std::string BranchHintName(BranchHint hint) {
31  switch (hint) {
32  case BranchHint::kNone:
33  return "kNone";
34  case BranchHint::kLikely:
35  return "kLikely";
36  case BranchHint::kUnlikely:
37  return "kUnlikely";
38  }
39  LOG(FATAL) << "Unrecognized BranchHint value";
40 }
41 
/*!
 * \brief abstract base class of all AST nodes.
 *
 * Holds non-owning links to the parent and children plus bookkeeping
 * integers. Concrete node types must implement Dump().
 */
class ASTNode {
 public:
  ASTNode() = default;
  ASTNode(const ASTNode&) = default;
  ASTNode(ASTNode&&) = default;
  ASTNode& operator=(const ASTNode&) = default;
  ASTNode& operator=(ASTNode&&) = default;
  /*!
   * \brief virtual destructor, so that deleting a derived node through an
   *        ASTNode pointer runs the derived destructor instead of being
   *        undefined behaviour
   */
  virtual ~ASTNode() = default;

  /*! \brief parent node; nullptr until a parent is assigned */
  ASTNode* parent = nullptr;
  /*! \brief child nodes (raw pointers; this class never deletes them) */
  std::vector<ASTNode*> children;
  /*! \brief node identifier; -1 until assigned */
  int node_id = -1;
  /*! \brief tree identifier; -1 until assigned */
  int tree_id = -1;
  /*! \brief descendant count; 0 until computed */
  int num_descendant = 0;

  /*!
   * \brief print a one-line description of this node
   * \param indent number of spaces to print before the description
   */
  virtual void Dump(int indent) = 0;
};
51 
52 class MainNode : public ASTNode {
53  public:
54  MainNode(tl_float global_bias, bool average_result, int num_tree,
55  int num_feature)
56  : global_bias(global_bias), average_result(average_result),
57  num_tree(num_tree), num_feature(num_feature) {}
58  void Dump(int indent) override {
59  std::cerr << std::string(indent, ' ') << std::boolalpha
60  << "MainNode {"
61  << "global_bias: " << this->global_bias << ", "
62  << "average_result: " << this->average_result << ", "
63  << "num_tree: " << this->num_tree << ", "
64  << "num_feature: " << this->num_feature << "}"
65  << std::endl;
66  }
67  tl_float global_bias;
68  bool average_result;
69  int num_tree;
70  int num_feature;
71 };
72 
73 class TranslationUnitNode : public ASTNode {
74  public:
75  TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
76  void Dump(int indent) override {
77  std::cerr << std::string(indent, ' ')
78  << "TranslationUnitNode {"
79  << "unit_id: " << unit_id << "}"
80  << std::endl;
81  }
82  int unit_id;
83 };
84 
85 class QuantizerNode : public ASTNode {
86  public:
87  QuantizerNode(const std::vector<std::vector<tl_float>>& cut_pts,
88  const std::vector<bool>& is_categorical)
89  : cut_pts(cut_pts), is_categorical(is_categorical) {}
90  QuantizerNode(std::vector<std::vector<tl_float>>&& cut_pts,
91  std::vector<bool>&& is_categorical)
92  : cut_pts(std::move(cut_pts)), is_categorical(std::move(is_categorical)) {}
93  std::vector<std::vector<tl_float>> cut_pts;
94  std::vector<bool> is_categorical;
95  void Dump(int indent) override {
96  std::cerr << std::string(indent, ' ')
97  << "QuantizerNode = {}" << std::endl;
98  }
99 };
100 
102  public:
104  void Dump(int indent) override {
105  std::cerr << std::string(indent, ' ')
106  << "AccumulatorContextNode = {}" << std::endl;
107  }
108 };
109 
110 class ConditionNode : public ASTNode {
111  public:
112  ConditionNode(unsigned split_index, bool default_left, BranchHint branch_hint)
113  : split_index(split_index), default_left(default_left),
114  branch_hint(branch_hint) {}
115  unsigned split_index;
116  bool default_left;
117  BranchHint branch_hint;
118 };
119 
121  tl_float float_val;
122  int int_val;
123  ThresholdVariant(tl_float val) : float_val(val) {}
124  ThresholdVariant(int val) : int_val(val) {}
125 };
126 
128  public:
129  NumericalConditionNode(unsigned split_index, bool default_left,
130  bool quantized, Operator op,
131  ThresholdVariant threshold,
132  BranchHint branch_hint = BranchHint::kNone)
133  : ConditionNode(split_index, default_left, branch_hint),
134  quantized(quantized), op(op), threshold(threshold) {}
135  void Dump(int indent) override {
136  std::cerr << std::string(indent, ' ') << std::boolalpha
137  << "NumericalConditionNode {"
138  << "split_index: " << this->split_index << ", "
139  << "default_left: " << this->default_left << ", "
140  << "quantized: " << this->quantized << ", "
141  << "op: " << OpName(this->op) << ", "
142  << "threshold: " << (quantized ? this->threshold.int_val
143  : this->threshold.float_val) << ", "
144  << "branch_hint: " << BranchHintName(this->branch_hint)
145  << "}" << std::endl;
146  }
147  bool quantized;
148  Operator op;
149  ThresholdVariant threshold;
150 };
151 
153  public:
154  CategoricalConditionNode(unsigned split_index, bool default_left,
155  const std::vector<uint32_t>& left_categories,
156  BranchHint branch_hint = BranchHint::kNone)
157  : ConditionNode(split_index, default_left, branch_hint),
158  left_categories(left_categories) {}
159  void Dump(int indent) override {
160  std::ostringstream oss;
161  for (uint32_t e : this->left_categories) {
162  oss << e << ", ";
163  }
164  std::cerr << std::string(indent, ' ') << std::boolalpha
165  << "CategoricalConditionNode {"
166  << "split_index: " << this->split_index << ", "
167  << "default_left: " << this->default_left << ", "
168  << "left_categories: [" << oss.str() << "], "
169  << "branch_hint: " << BranchHintName(this->branch_hint)
170  << "}" << std::endl;
171  }
172  std::vector<uint32_t> left_categories;
173 };
174 
175 class OutputNode : public ASTNode {
176  public:
177  OutputNode(tl_float scalar)
178  : is_vector(false), scalar(scalar) {}
179  OutputNode(const std::vector<tl_float>& vector)
180  : is_vector(true), vector(vector) {}
181  void Dump(int indent) override {
182  if (this->is_vector) {
183  std::ostringstream oss;
184  for (tl_float e : this->vector) {
185  oss << e << ", ";
186  }
187  std::cerr << std::string(indent, ' ')
188  << "OutputNode {vector: [" << oss.str() << "]}"
189  << std::endl;
190  } else {
191  std::cerr << std::string(indent, ' ')
192  << "OutputNode {scalar: " << this->scalar << "}"
193  << std::endl;
194  }
195  }
196 
197  bool is_vector;
198  tl_float scalar;
199  std::vector<tl_float> vector;
200 };
201 
202 } // namespace compiler
203 } // namespace treelite
204 
205 #endif // TREELITE_COMPILER_AST_AST_H_
float tl_float
float type to be used internally
Definition: base.h:17
Operator
comparison operators
Definition: base.h:23