Treelite
ast.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_AST_AST_H_
8 #define TREELITE_COMPILER_AST_AST_H_
9 
#include <cstddef>
#include <cstdint>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <dmlc/optional.h>
#include <fmt/format.h>
#include <treelite/base.h>
17 
18 namespace treelite {
19 namespace compiler {
20 
21 class ASTNode {
22  public:
23  ASTNode* parent;
24  std::vector<ASTNode*> children;
25  int node_id;
26  int tree_id;
27  dmlc::optional<size_t> data_count;
28  dmlc::optional<double> sum_hess;
29  virtual std::string GetDump() const = 0;
30  virtual ~ASTNode() = 0; // force ASTNode to be abstract class
31  protected:
32  ASTNode() : parent(nullptr), node_id(-1), tree_id(-1) {}
33 };
34 inline ASTNode::~ASTNode() {}
35 
36 class MainNode : public ASTNode {
37  public:
38  MainNode(tl_float global_bias, bool average_result, int num_tree,
39  int num_feature)
40  : global_bias(global_bias), average_result(average_result),
41  num_tree(num_tree), num_feature(num_feature) {}
42  tl_float global_bias;
43  bool average_result;
44  int num_tree;
45  int num_feature;
46 
47  std::string GetDump() const override {
48  return fmt::format("MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, "
49  "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
50  }
51 };
52 
53 class TranslationUnitNode : public ASTNode {
54  public:
55  explicit TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
56  int unit_id;
57 
58  std::string GetDump() const override {
59  return fmt::format("TranslationUnitNode {{ unit_id: {} }}", unit_id);
60  }
61 };
62 
63 template <typename ThresholdType>
64 class QuantizerNode : public ASTNode {
65  public:
66  explicit QuantizerNode(const std::vector<std::vector<ThresholdType>>& cut_pts)
67  : cut_pts(cut_pts) {}
68  explicit QuantizerNode(std::vector<std::vector<ThresholdType>>&& cut_pts)
69  : cut_pts(std::move(cut_pts)) {}
70  std::vector<std::vector<ThresholdType>> cut_pts;
71 
72  std::string GetDump() const override {
73  std::ostringstream oss;
74  for (const auto& vec : cut_pts) {
75  oss << "[ ";
76  for (const auto& e : vec) {
77  oss << e << ", ";
78  }
79  oss << "], ";
80  }
81  return fmt::format("QuantizerNode {{ cut_pts: {} }}", oss.str());
82  }
83 };
84 
// NOTE(review): this extraction skips source lines 85 and 87 (numbering goes 84 -> 86 -> 88),
// so the class head and one member line are missing here. By analogy with CodeFolderNode they
// are presumably `class AccumulatorContextNode : public ASTNode {` and
// `AccumulatorContextNode() {}` -- confirm against the upstream ast.h.
86  public:
88 
// Returns the fixed text "AccumulatorContextNode {}" (no field values are included).
89  std::string GetDump() const override {
90  return fmt::format("AccumulatorContextNode {{}}");
91  }
92 };
93 
94 class CodeFolderNode : public ASTNode {
95  public:
96  CodeFolderNode() {}
97 
98  std::string GetDump() const override {
99  return fmt::format("CodeFolderNode {{}}");
100  }
101 };
102 
103 class ConditionNode : public ASTNode {
104  public:
105  ConditionNode(unsigned split_index, bool default_left)
106  : split_index(split_index), default_left(default_left) {}
107  unsigned split_index;
108  bool default_left;
109  dmlc::optional<double> gain;
110 
111  std::string GetDump() const override {
112  if (gain) {
113  return fmt::format("ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
114  split_index, default_left, gain.value());
115  } else {
116  return fmt::format("ConditionNode {{ split_index: {}, default_left: {} }}",
117  split_index, default_left);
118  }
119  }
120 };
121 
// NOTE(review): source line 123 is missing from this extraction (numbering jumps 122 -> 124);
// given the constructor names and the single-member initializers it presumably reads
// `union ThresholdVariant {` -- confirm against the upstream ast.h.
//
// Stores a split threshold either as a ThresholdType value (float_val) or as an int (int_val).
// Each constructor initializes only the member matching its argument, so the reader must know
// which one was set; NumericalConditionNode::GetDump picks it using its `quantized` flag.
// Caveat: instantiating with ThresholdType = int would declare the same constructor twice.
122 template <typename ThresholdType>
124  ThresholdType float_val;
125  int int_val;
126  explicit ThresholdVariant(ThresholdType val) : float_val(val) {}
127  explicit ThresholdVariant(int val) : int_val(val) {}
128 };
129 
// NOTE(review): source lines 131, 135 and 140 are missing from this extraction. Judging from
// the surrounding code they presumably are the class head
// `class NumericalConditionNode : public ConditionNode {`, the last constructor parameter
// (a ThresholdVariant<ThresholdType> named `threshold`) and the matching `threshold` member
// declaration -- confirm against the upstream ast.h.
//
// Condition node that adds a comparison operator `op`, a numeric threshold and quantization
// information on top of the ConditionNode fields.
130 template <typename ThresholdType>
132  public:
// zero_quantized always starts at -1 here; presumably a later pass fills it in -- TODO confirm.
133  NumericalConditionNode(unsigned split_index, bool default_left,
134  bool quantized, Operator op,
136  : ConditionNode(split_index, default_left),
137  quantized(quantized), op(op), threshold(threshold), zero_quantized(-1) {}
// When true, the threshold is read from threshold.int_val; otherwise from
// threshold.float_val (see GetDump below).
138  bool quantized;
139  Operator op;
141  int zero_quantized; // quantized value of 0.0f (useful when convert_missing_to_zero is set)
142 
// Dumps the base ConditionNode fields, then quantized, op (via OpName), the threshold member
// selected by `quantized`, and zero_quantized.
143  std::string GetDump() const override {
144  return fmt::format("NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {}, "
145  "zero_quantized: {} }}",
146  ConditionNode::GetDump(), quantized, OpName(op),
147  (quantized ? fmt::format("{}", threshold.int_val)
148  : fmt::format("{}", threshold.float_val)),
149  zero_quantized);
150  }
151 };
152 
// NOTE(review): source line 153 is missing from this extraction (numbering jumps 152 -> 154);
// it presumably holds the class head `class CategoricalConditionNode : public ConditionNode {`
// -- confirm against the upstream ast.h.
//
// Condition node that adds a list of category values (matching_categories) and the flag
// categories_list_right_child on top of the ConditionNode fields.
// NOTE(review): which branch the listed categories lead to, as encoded by
// categories_list_right_child, is not visible here -- confirm against the code generator.
154  public:
155  CategoricalConditionNode(unsigned split_index, bool default_left,
156  const std::vector<uint32_t>& matching_categories,
157  bool categories_list_right_child)
158  : ConditionNode(split_index, default_left),
159  matching_categories(matching_categories),
160  categories_list_right_child(categories_list_right_child) {}
161  std::vector<uint32_t> matching_categories;
162  bool categories_list_right_child;
163 
// Dumps the base ConditionNode fields, the category list as "[c1, c2, ]" (every element,
// including the last, is followed by ", "), and categories_list_right_child.
164  std::string GetDump() const override {
165  std::ostringstream oss;
166  oss << "[";
167  for (const auto& e : matching_categories) {
168  oss << e << ", ";
169  }
170  oss << "]";
171  return fmt::format("CategoricalConditionNode {{ {}, matching_categories: {}, "
172  "categories_list_right_child: {} }}",
173  ConditionNode::GetDump(), oss.str(), categories_list_right_child);
174  }
175 };
176 
177 template <typename LeafOutputType>
178 class OutputNode : public ASTNode {
179  public:
180  explicit OutputNode(LeafOutputType scalar)
181  : is_vector(false), scalar(scalar) {}
182  explicit OutputNode(const std::vector<LeafOutputType>& vector)
183  : is_vector(true), vector(vector) {}
184  bool is_vector;
185  LeafOutputType scalar;
186  std::vector<LeafOutputType> vector;
187 
188  std::string GetDump() const override {
189  if (is_vector) {
190  std::ostringstream oss;
191  oss << "[";
192  for (const auto& e : vector) {
193  oss << e << ", ";
194  }
195  oss << "]";
196  return fmt::format("OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
197  } else {
198  return fmt::format("OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
199  }
200  }
201 };
202 
203 } // namespace compiler
204 } // namespace treelite
205 
206 #endif // TREELITE_COMPILER_AST_AST_H_
tl_float: alias for float, the floating-point type used internally.
Definition: base.h:20
std::string OpName(Operator op): returns the string representation of a comparison operator.
Definition: base.h:43
defines configuration macros of Treelite
Operator: type representing comparison operators.
Definition: base.h:26