1 #ifndef TREELITE_COMPILER_AST_AST_H_ 2 #define TREELITE_COMPILER_AST_AST_H_ 12 inline std::string OpName(Operator op) {
14 case Operator::kEQ:
return "==";
15 case Operator::kLT:
return "<";
16 case Operator::kLE:
return "<=";
17 case Operator::kGT:
return ">";
18 case Operator::kGE:
return ">=";
24 enum class BranchHint : uint8_t {
30 inline std::string BranchHintName(BranchHint hint) {
32 case BranchHint::kNone:
34 case BranchHint::kLikely:
36 case BranchHint::kUnlikely:
39 LOG(FATAL) <<
"Unrecognized BranchHint value";
45 std::vector<ASTNode*> children;
49 virtual void Dump(
int indent) = 0;
56 : global_bias(global_bias), average_result(average_result),
57 num_tree(num_tree), num_feature(num_feature) {}
58 void Dump(
int indent)
override {
59 std::cerr << std::string(indent,
' ') << std::boolalpha
61 <<
"global_bias: " << this->global_bias <<
", " 62 <<
"average_result: " << this->average_result <<
", " 63 <<
"num_tree: " << this->num_tree <<
", " 64 <<
"num_feature: " << this->num_feature <<
"}" 76 void Dump(
int indent)
override {
77 std::cerr << std::string(indent,
' ')
78 <<
"TranslationUnitNode {" 79 <<
"unit_id: " << unit_id <<
"}" 87 QuantizerNode(
const std::vector<std::vector<tl_float>>& cut_pts,
88 const std::vector<bool>& is_categorical)
89 : cut_pts(cut_pts), is_categorical(is_categorical) {}
91 std::vector<bool>&& is_categorical)
92 : cut_pts(std::move(cut_pts)), is_categorical(std::move(is_categorical)) {}
93 std::vector<std::vector<tl_float>> cut_pts;
94 std::vector<bool> is_categorical;
95 void Dump(
int indent)
override {
96 std::cerr << std::string(indent,
' ')
97 <<
"QuantizerNode = {}" << std::endl;
104 void Dump(
int indent)
override {
105 std::cerr << std::string(indent,
' ')
106 <<
"AccumulatorContextNode = {}" << std::endl;
112 ConditionNode(
unsigned split_index,
bool default_left, BranchHint branch_hint)
113 : split_index(split_index), default_left(default_left),
114 branch_hint(branch_hint) {}
115 unsigned split_index;
117 BranchHint branch_hint;
132 BranchHint branch_hint = BranchHint::kNone)
134 quantized(quantized), op(op), threshold(threshold) {}
135 void Dump(
int indent)
override {
136 std::cerr << std::string(indent,
' ') << std::boolalpha
137 <<
"NumericalConditionNode {" 138 <<
"split_index: " << this->split_index <<
", " 139 <<
"default_left: " << this->default_left <<
", " 140 <<
"quantized: " << this->quantized <<
", " 141 <<
"op: " << OpName(this->op) <<
", " 142 <<
"threshold: " << (quantized ? this->threshold.int_val
143 : this->threshold.float_val) <<
", " 144 <<
"branch_hint: " << BranchHintName(this->branch_hint)
155 const std::vector<uint32_t>& left_categories,
156 BranchHint branch_hint = BranchHint::kNone)
158 left_categories(left_categories) {}
159 void Dump(
int indent)
override {
160 std::ostringstream oss;
161 for (uint32_t e : this->left_categories) {
164 std::cerr << std::string(indent,
' ') << std::boolalpha
165 <<
"CategoricalConditionNode {" 166 <<
"split_index: " << this->split_index <<
", " 167 <<
"default_left: " << this->default_left <<
", " 168 <<
"left_categories: [" << oss.str() <<
"], " 169 <<
"branch_hint: " << BranchHintName(this->branch_hint)
172 std::vector<uint32_t> left_categories;
178 : is_vector(
false), scalar(scalar) {}
179 OutputNode(
const std::vector<tl_float>& vector)
180 : is_vector(
true), vector(vector) {}
181 void Dump(
int indent)
override {
182 if (this->is_vector) {
183 std::ostringstream oss;
187 std::cerr << std::string(indent,
' ')
188 <<
"OutputNode {vector: [" << oss.str() <<
"]}" 191 std::cerr << std::string(indent,
' ')
192 <<
"OutputNode {scalar: " << this->scalar <<
"}" 199 std::vector<tl_float> vector;
205 #endif // TREELITE_COMPILER_AST_AST_H_
float tl_float
float type to be used internally
Operator
comparison operators