treelite
ast.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_AST_AST_H_
8 #define TREELITE_COMPILER_AST_AST_H_
9 
10 #include <dmlc/optional.h>
11 #include <treelite/base.h>
12 #include <string>
13 #include <vector>
14 #include <utility>
15 
16 // forward declaration
17 namespace treelite_ast_protobuf {
18 class ASTNode;
19 } // namespace treelite_ast_protobuf
20 
21 namespace treelite {
22 namespace compiler {
23 
24 class ASTNode {
25  public:
26  ASTNode* parent;
27  std::vector<ASTNode*> children;
28  int node_id;
29  int tree_id;
30  dmlc::optional<size_t> data_count;
31  dmlc::optional<double> sum_hess;
32  virtual ~ASTNode() = 0; // force ASTNode to be abstract class
33  protected:
34  ASTNode() : parent(nullptr), node_id(-1), tree_id(-1) {}
35 };
36 inline ASTNode::~ASTNode() {}
37 
38 class MainNode : public ASTNode {
39  public:
40  MainNode(tl_float global_bias, bool average_result, int num_tree,
41  int num_feature)
42  : global_bias(global_bias), average_result(average_result),
43  num_tree(num_tree), num_feature(num_feature) {}
44  tl_float global_bias;
45  bool average_result;
46  int num_tree;
47  int num_feature;
48 };
49 
50 class TranslationUnitNode : public ASTNode {
51  public:
52  explicit TranslationUnitNode(int unit_id) : unit_id(unit_id) {}
53  int unit_id;
54 };
55 
56 class QuantizerNode : public ASTNode {
57  public:
58  explicit QuantizerNode(const std::vector<std::vector<tl_float>>& cut_pts)
59  : cut_pts(cut_pts) {}
60  explicit QuantizerNode(std::vector<std::vector<tl_float>>&& cut_pts)
61  : cut_pts(std::move(cut_pts)) {}
62  std::vector<std::vector<tl_float>> cut_pts;
63 };
64 
66  public:
68 };
69 
70 class CodeFolderNode : public ASTNode {
71  public:
72  CodeFolderNode() {}
73 };
74 
75 class ConditionNode : public ASTNode {
76  public:
77  ConditionNode(unsigned split_index, bool default_left)
78  : split_index(split_index), default_left(default_left) {}
79  unsigned split_index;
80  bool default_left;
81  dmlc::optional<double> gain;
82 };
83 
85  tl_float float_val;
86  int int_val;
87  ThresholdVariant(tl_float val) : float_val(val) {}
88  ThresholdVariant(int val) : int_val(val) {}
89 };
90 
92  public:
93  NumericalConditionNode(unsigned split_index, bool default_left,
94  bool quantized, Operator op,
95  ThresholdVariant threshold)
96  : ConditionNode(split_index, default_left),
97  quantized(quantized), op(op), threshold(threshold) {}
98  bool quantized;
99  Operator op;
100  ThresholdVariant threshold;
101 };
102 
104  public:
105  CategoricalConditionNode(unsigned split_index, bool default_left,
106  const std::vector<uint32_t>& left_categories,
107  bool convert_missing_to_zero)
108  : ConditionNode(split_index, default_left),
109  left_categories(left_categories),
110  convert_missing_to_zero(convert_missing_to_zero) {}
111  std::vector<uint32_t> left_categories;
112  bool convert_missing_to_zero;
113 };
114 
115 class OutputNode : public ASTNode {
116  public:
117  explicit OutputNode(tl_float scalar)
118  : is_vector(false), scalar(scalar) {}
119  explicit OutputNode(const std::vector<tl_float>& vector)
120  : is_vector(true), vector(vector) {}
121  bool is_vector;
122  tl_float scalar;
123  std::vector<tl_float> vector;
124 };
125 
126 } // namespace compiler
127 } // namespace treelite
128 
129 #endif // TREELITE_COMPILER_AST_AST_H_
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
Operator
comparison operators
Definition: base.h:23