treelite
ast.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_AST_AST_H_
8 #define TREELITE_COMPILER_AST_AST_H_
9 
#include <dmlc/optional.h>
#include <treelite/base.h>
#include <fmt/format.h>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
17 
// Forward declaration only: the full definition of this type lives elsewhere
// (presumably generated protobuf code, judging by the namespace name -- confirm).
// Declaring it here lets code name treelite_ast_protobuf::ASTNode without
// including its defining header.
namespace treelite_ast_protobuf {
class ASTNode;
}  // namespace treelite_ast_protobuf
22 
23 namespace treelite {
24 namespace compiler {
25 
26 class ASTNode {
27  public:
28  ASTNode* parent;
29  std::vector<ASTNode*> children;
30  int node_id;
31  int tree_id;
32  dmlc::optional<size_t> data_count;
33  dmlc::optional<double> sum_hess;
34  virtual std::string GetDump() const = 0;
35  virtual ~ASTNode() = 0; // force ASTNode to be abstract class
36  protected:
37  ASTNode() : parent(nullptr), node_id(-1), tree_id(-1) {}
38 };
39 inline ASTNode::~ASTNode() {}
40 
41 class MainNode : public ASTNode {
42  public:
43  MainNode(tl_float global_bias, bool average_result, int num_tree,
44  int num_feature)
45  : global_bias(global_bias), average_result(average_result),
46  num_tree(num_tree), num_feature(num_feature) {}
47  tl_float global_bias;
48  bool average_result;
49  int num_tree;
50  int num_feature;
51 
52  std::string GetDump() const override {
53  return fmt::format("MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, "
54  "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
55  }
56 };
57 
58 class TranslationUnitNode : public ASTNode {
59  public:
60  explicit TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
61  int unit_id;
62 
63  std::string GetDump() const override {
64  return fmt::format("TranslationUnitNode {{ unit_id: {} }}", unit_id);
65  }
66 };
67 
68 class QuantizerNode : public ASTNode {
69  public:
70  explicit QuantizerNode(const std::vector<std::vector<tl_float>>& cut_pts)
71  : cut_pts(cut_pts) {}
72  explicit QuantizerNode(std::vector<std::vector<tl_float>>&& cut_pts)
73  : cut_pts(std::move(cut_pts)) {}
74  std::vector<std::vector<tl_float>> cut_pts;
75 
76  std::string GetDump() const override {
77  std::ostringstream oss;
78  for (const auto& vec : cut_pts) {
79  oss << "[ ";
80  for (const auto& e : vec) {
81  oss << e << ", ";
82  }
83  oss << "], ";
84  }
85  return fmt::format("QuantizerNode {{ cut_pts: {} }}", oss.str());
86  }
87 };
88 
90  public:
92 
93  std::string GetDump() const override {
94  return fmt::format("AccumulatorContextNode {{}}");
95  }
96 };
97 
98 class CodeFolderNode : public ASTNode {
99  public:
100  CodeFolderNode() {}
101 
102  std::string GetDump() const override {
103  return fmt::format("CodeFolderNode {{}}");
104  }
105 };
106 
107 class ConditionNode : public ASTNode {
108  public:
109  ConditionNode(unsigned split_index, bool default_left)
110  : split_index(split_index), default_left(default_left) {}
111  unsigned split_index;
112  bool default_left;
113  dmlc::optional<double> gain;
114 
115  std::string GetDump() const override {
116  if (gain) {
117  return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
118  split_index, default_left, gain.value());
119  } else {
120  return fmt::format("ConditionNode {{ split_index: {}, default_left: {} }}",
121  split_index, default_left);
122  }
123  }
124 };
125 
127  tl_float float_val;
128  int int_val;
129  ThresholdVariant(tl_float val) : float_val(val) {}
130  ThresholdVariant(int val) : int_val(val) {}
131 };
132 
134  public:
135  NumericalConditionNode(unsigned split_index, bool default_left,
136  bool quantized, Operator op,
137  ThresholdVariant threshold)
138  : ConditionNode(split_index, default_left),
139  quantized(quantized), op(op), threshold(threshold) {}
140  bool quantized;
141  Operator op;
142  ThresholdVariant threshold;
143 
144  std::string GetDump() const override {
145  return fmt::format("NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {} }}",
146  ConditionNode::GetDump(), quantized, OpName(op),
147  (quantized ? fmt::format("{:d}", threshold.int_val)
148  : fmt::format("{:f}", threshold.float_val)));
149  }
150 };
151 
153  public:
154  CategoricalConditionNode(unsigned split_index, bool default_left,
155  const std::vector<uint32_t>& left_categories,
156  bool convert_missing_to_zero)
157  : ConditionNode(split_index, default_left),
158  left_categories(left_categories),
159  convert_missing_to_zero(convert_missing_to_zero) {}
160  std::vector<uint32_t> left_categories;
161  bool convert_missing_to_zero;
162 
163  std::string GetDump() const override {
164  std::ostringstream oss;
165  oss << "[";
166  for (const auto& e : left_categories) {
167  oss << e << ", ";
168  }
169  oss << "]";
170  return fmt::format("CategoricalConditionNode {{ {}, left_categories: {}, "
171  "convert_missing_to_zero: {} }}",
172  ConditionNode::GetDump(), oss.str(), convert_missing_to_zero);
173  }
174 };
175 
176 class OutputNode : public ASTNode {
177  public:
178  explicit OutputNode(tl_float scalar)
179  : is_vector(false), scalar(scalar) {}
180  explicit OutputNode(const std::vector<tl_float>& vector)
181  : is_vector(true), vector(vector) {}
182  bool is_vector;
183  tl_float scalar;
184  std::vector<tl_float> vector;
185 
186  std::string GetDump() const override {
187  if (is_vector) {
188  std::ostringstream oss;
189  oss << "[";
190  for (const auto& e : vector) {
191  oss << e << ", ";
192  }
193  oss << "]";
194  return fmt::format("OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
195  } else {
196  return fmt::format("OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
197  }
198  }
199 };
200 
201 } // namespace compiler
202 } // namespace treelite
203 
204 #endif // TREELITE_COMPILER_AST_AST_H_
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
Operator
comparison operators
Definition: base.h:23