treelite
ast.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_AST_AST_H_
8 #define TREELITE_COMPILER_AST_AST_H_
9 
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <dmlc/optional.h>
#include <treelite/base.h>
14 
15 // forward declaration
16 namespace treelite_ast_protobuf {
17 class ASTNode;
18 } // namespace treelite_ast_protobuf
19 
20 namespace treelite {
21 namespace compiler {
22 
23 class ASTNode {
24  public:
25  ASTNode* parent;
26  std::vector<ASTNode*> children;
27  int node_id;
28  int tree_id;
29  dmlc::optional<size_t> data_count;
30  dmlc::optional<double> sum_hess;
31  virtual ~ASTNode() = 0; // force ASTNode to be abstract class
32  protected:
33  ASTNode() : parent(nullptr), node_id(-1), tree_id(-1) {}
34 };
35 inline ASTNode::~ASTNode() {}
36 
37 class MainNode : public ASTNode {
38  public:
39  MainNode(tl_float global_bias, bool average_result, int num_tree,
40  int num_feature)
41  : global_bias(global_bias), average_result(average_result),
42  num_tree(num_tree), num_feature(num_feature) {}
43  tl_float global_bias;
44  bool average_result;
45  int num_tree;
46  int num_feature;
47 };
48 
49 class TranslationUnitNode : public ASTNode {
50  public:
51  explicit TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
52  int unit_id;
53 };
54 
55 class QuantizerNode : public ASTNode {
56  public:
57  explicit QuantizerNode(const std::vector<std::vector<tl_float>>& cut_pts)
58  : cut_pts(cut_pts) {}
59  explicit QuantizerNode(std::vector<std::vector<tl_float>>&& cut_pts)
60  : cut_pts(std::move(cut_pts)) {}
61  std::vector<std::vector<tl_float>> cut_pts;
62 };
63 
// NOTE(review): Doxygen source lines 64 and 66 are missing from this listing.
// Presumably they are the class header (`class ... : public ASTNode {`) and a
// constructor; in upstream treelite this looks like AccumulatorContextNode.
// Confirm against the original ast.h before relying on this fragment.
// The visible lines declare no members; line 66 is not shown.
65  public:
67 };
68 
69 class CodeFolderNode : public ASTNode {
70  public:
71  CodeFolderNode() {}
72 };
73 
74 class ConditionNode : public ASTNode {
75  public:
76  ConditionNode(unsigned split_index, bool default_left)
77  : split_index(split_index), default_left(default_left) {}
78  unsigned split_index;
79  bool default_left;
80  dmlc::optional<double> gain;
81 };
82 
// NOTE(review): Doxygen source line 83 (this type's opening line) is missing
// from the listing. The constructor names show the type is ThresholdVariant.
// Since each constructor initializes exactly one member, it is presumably a
// `union`; confirm against the original ast.h.
//
// Threshold value holding either a tl_float or an int. Nothing here records
// which member is active, so the reader must know from context.
// Value used when constructed from a tl_float.
84  tl_float float_val;
// Value used when constructed from an int.
85  int int_val;
// Neither constructor is `explicit`, so tl_float and int convert implicitly.
86  ThresholdVariant(tl_float val) : float_val(val) {}
87  ThresholdVariant(int val) : int_val(val) {}
88 };
89 
// NOTE(review): Doxygen source line 90 (the class header) is missing from
// this listing. The constructor name and its `ConditionNode(...)` base
// initializer indicate NumericalConditionNode deriving from ConditionNode.
// Confirm the exact header against the original ast.h.
//
// Condition node described by a comparison operator (`op`) and a threshold.
91  public:
// Forwards split_index/default_left to ConditionNode and stores the rest.
92  NumericalConditionNode(unsigned split_index, bool default_left,
93  bool quantized, Operator op,
94  ThresholdVariant threshold)
95  : ConditionNode(split_index, default_left),
96  quantized(quantized), op(op), threshold(threshold) {}
// Presumably true when `threshold` holds a quantized (int) value rather than
// a raw tl_float. TODO confirm.
97  bool quantized;
// Comparison operator (treelite's Operator type from base.h).
98  Operator op;
// Threshold value; which member is meaningful presumably follows `quantized`.
99  ThresholdVariant threshold;
100 };
101 
// NOTE(review): Doxygen source line 102 (the class header) is missing from
// this listing. The constructor name and its `ConditionNode(...)` base
// initializer indicate CategoricalConditionNode deriving from ConditionNode.
// Confirm the exact header against the original ast.h.
//
// Condition node defined by a list of uint32_t category values
// (`left_categories`), presumably those routed to the left child.
// TODO confirm.
103  public:
// Forwards split_index/default_left to ConditionNode and copies the list.
104  CategoricalConditionNode(unsigned split_index, bool default_left,
105  const std::vector<uint32_t>& left_categories)
106  : ConditionNode(split_index, default_left),
107  left_categories(left_categories) {}
108  std::vector<uint32_t> left_categories;
109 };
110 
111 class OutputNode : public ASTNode {
112  public:
113  explicit OutputNode(tl_float scalar)
114  : is_vector(false), scalar(scalar) {}
115  explicit OutputNode(const std::vector<tl_float>& vector)
116  : is_vector(true), vector(vector) {}
117  bool is_vector;
118  tl_float scalar;
119  std::vector<tl_float> vector;
120 };
121 
122 } // namespace compiler
123 } // namespace treelite
124 
125 #endif // TREELITE_COMPILER_AST_AST_H_
tl_float: alias for double
Floating-point type used internally.
Definition: base.h:17
treelite/base.h: defines configuration macros of treelite.
Operator
Comparison operators.
Definition: base.h:23