Treelite
ast.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_AST_AST_H_
8 #define TREELITE_COMPILER_AST_AST_H_
9 
10 #include <treelite/optional.h>
11 #include <treelite/base.h>
12 #include <fmt/format.h>
13 #include <limits>
14 #include <string>
15 #include <vector>
16 #include <utility>
17 #include <cstdint>
18 
19 namespace treelite {
20 namespace compiler {
21 
22 class ASTNode {
23  public:
24  ASTNode* parent;
25  std::vector<ASTNode*> children;
26  int node_id;
27  int tree_id;
28  optional<uint64_t> data_count;
29  optional<double> sum_hess;
30  virtual std::string GetDump() const = 0;
31  virtual ~ASTNode() = 0; // force ASTNode to be abstract class
32  protected:
33  ASTNode() : parent(nullptr), node_id(-1), tree_id(-1) {}
34 };
35 inline ASTNode::~ASTNode() {}
36 
37 class MainNode : public ASTNode {
38  public:
39  MainNode(tl_float global_bias, bool average_result, int num_tree,
40  int num_feature)
41  : global_bias(global_bias), average_result(average_result),
42  num_tree(num_tree), num_feature(num_feature) {}
43  tl_float global_bias;
44  bool average_result;
45  int num_tree;
46  int num_feature;
47 
48  std::string GetDump() const override {
49  return fmt::format("MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, "
50  "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
51  }
52 };
53 
54 class TranslationUnitNode : public ASTNode {
55  public:
56  explicit TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
57  int unit_id;
58 
59  std::string GetDump() const override {
60  return fmt::format("TranslationUnitNode {{ unit_id: {} }}", unit_id);
61  }
62 };
63 
64 template <typename ThresholdType>
65 class QuantizerNode : public ASTNode {
66  public:
67  explicit QuantizerNode(const std::vector<std::vector<ThresholdType>>& cut_pts)
68  : cut_pts(cut_pts) {}
69  explicit QuantizerNode(std::vector<std::vector<ThresholdType>>&& cut_pts)
70  : cut_pts(std::move(cut_pts)) {}
71  std::vector<std::vector<ThresholdType>> cut_pts;
72 
73  std::string GetDump() const override {
74  std::ostringstream oss;
75  for (const auto& vec : cut_pts) {
76  oss << "[ ";
77  for (const auto& e : vec) {
78  oss << e << ", ";
79  }
80  oss << "], ";
81  }
82  return fmt::format("QuantizerNode {{ cut_pts: {} }}", oss.str());
83  }
84 };
85 
87  public:
89 
90  std::string GetDump() const override {
91  return fmt::format("AccumulatorContextNode {{}}");
92  }
93 };
94 
95 class CodeFolderNode : public ASTNode {
96  public:
97  CodeFolderNode() {}
98 
99  std::string GetDump() const override {
100  return fmt::format("CodeFolderNode {{}}");
101  }
102 };
103 
104 class ConditionNode : public ASTNode {
105  public:
106  ConditionNode(unsigned split_index, bool default_left)
107  : split_index(split_index), default_left(default_left) {}
108  unsigned split_index;
109  bool default_left;
110  optional<double> gain;
111 
112  std::string GetDump() const override {
113  if (gain) {
114  return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
115  split_index, default_left, *gain);
116  } else {
117  return fmt::format("ConditionNode {{ split_index: {}, default_left: {} }}",
118  split_index, default_left);
119  }
120  }
121 };
122 
/*!
 * \brief Split threshold stored either as a ThresholdType value or as an int
 *
 * Exactly one member is meaningful, namely the one selected by the constructor that was
 * called. The union does not record which one; callers must track it themselves.
 * \tparam ThresholdType type of the raw threshold value. It must not be int, or the two
 *         constructors would collide.
 */
template <typename ThresholdType>
union ThresholdVariant {
  ThresholdType float_val;  // active after ThresholdVariant(ThresholdType)
  int int_val;              // active after ThresholdVariant(int)

  explicit ThresholdVariant(ThresholdType val) : float_val{val} {}
  explicit ThresholdVariant(int val) : int_val{val} {}
};
130 
131 template <typename ThresholdType>
133  public:
134  NumericalConditionNode(unsigned split_index, bool default_left,
135  bool quantized, Operator op,
137  : ConditionNode(split_index, default_left),
138  quantized(quantized), op(op), threshold(threshold), zero_quantized(-1) {}
139  bool quantized;
140  Operator op;
142  int zero_quantized; // quantized value of 0.0f (useful when convert_missing_to_zero is set)
143 
144  std::string GetDump() const override {
145  return fmt::format("NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {}, "
146  "zero_quantized: {} }}",
147  ConditionNode::GetDump(), quantized, OpName(op),
148  (quantized ? fmt::format("{}", threshold.int_val)
149  : fmt::format("{}", threshold.float_val)),
150  zero_quantized);
151  }
152 };
153 
155  public:
156  CategoricalConditionNode(unsigned split_index, bool default_left,
157  const std::vector<uint32_t>& matching_categories,
158  bool categories_list_right_child)
159  : ConditionNode(split_index, default_left),
160  matching_categories(matching_categories),
161  categories_list_right_child(categories_list_right_child) {}
162  std::vector<uint32_t> matching_categories;
163  bool categories_list_right_child;
164 
165  std::string GetDump() const override {
166  std::ostringstream oss;
167  oss << "[";
168  for (const auto& e : matching_categories) {
169  oss << e << ", ";
170  }
171  oss << "]";
172  return fmt::format("CategoricalConditionNode {{ {}, matching_categories: {}, "
173  "categories_list_right_child: {} }}",
174  ConditionNode::GetDump(), oss.str(), categories_list_right_child);
175  }
176 };
177 
178 template <typename LeafOutputType>
179 class OutputNode : public ASTNode {
180  public:
181  explicit OutputNode(LeafOutputType scalar)
182  : is_vector(false), scalar(scalar) {}
183  explicit OutputNode(const std::vector<LeafOutputType>& vector)
184  : is_vector(true), vector(vector) {}
185  bool is_vector;
186  LeafOutputType scalar;
187  std::vector<LeafOutputType> vector;
188 
189  std::string GetDump() const override {
190  if (is_vector) {
191  std::ostringstream oss;
192  oss << "[";
193  for (const auto& e : vector) {
194  oss << e << ", ";
195  }
196  oss << "]";
197  return fmt::format("OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
198  } else {
199  return fmt::format("OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
200  }
201  }
202 };
203 
204 } // namespace compiler
205 } // namespace treelite
206 
207 #endif // TREELITE_COMPILER_AST_AST_H_
Backport of std::optional from C++17.
float tl_float
float type to be used internally
Definition: base.h:20
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:43
defines configuration macros of Treelite
Operator
comparison operators
Definition: base.h:26