7 #ifndef TREELITE_COMPILER_AST_AST_H_ 8 #define TREELITE_COMPILER_AST_AST_H_ 14 #include <dmlc/optional.h> 16 #include <fmt/format.h> 29 std::vector<ASTNode*> children;
32 dmlc::optional<size_t> data_count;
33 dmlc::optional<double> sum_hess;
34 virtual std::string GetDump()
const = 0;
37 ASTNode() : parent(
nullptr), node_id(-1), tree_id(-1) {}
39 inline ASTNode::~ASTNode() {}
45 : global_bias(global_bias), average_result(average_result),
46 num_tree(num_tree), num_feature(num_feature) {}
52 std::string GetDump()
const override {
53 return fmt::format(
"MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, " 54 "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
63 std::string GetDump()
const override {
64 return fmt::format(
"TranslationUnitNode {{ unit_id: {} }}", unit_id);
70 explicit QuantizerNode(
const std::vector<std::vector<tl_float>>& cut_pts)
72 explicit QuantizerNode(std::vector<std::vector<tl_float>>&& cut_pts)
73 : cut_pts(std::move(cut_pts)) {}
74 std::vector<std::vector<tl_float>> cut_pts;
76 std::string GetDump()
const override {
77 std::ostringstream oss;
78 for (
const auto& vec : cut_pts) {
80 for (
const auto& e : vec) {
85 return fmt::format(
"QuantizerNode {{ cut_pts: {} }}", oss.str());
93 std::string GetDump()
const override {
94 return fmt::format(
"AccumulatorContextNode {{}}");
102 std::string GetDump()
const override {
103 return fmt::format(
"CodeFolderNode {{}}");
110 : split_index(split_index), default_left(default_left) {}
111 unsigned split_index;
113 dmlc::optional<double> gain;
115 std::string GetDump()
const override {
117 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
118 split_index, default_left, gain.value());
120 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {} }}",
121 split_index, default_left);
139 quantized(quantized), op(op), threshold(threshold) {}
144 std::string GetDump()
const override {
145 return fmt::format(
"NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {} }}",
146 ConditionNode::GetDump(), quantized,
OpName(op),
147 (quantized ? fmt::format(
"{:d}", threshold.int_val)
148 : fmt::format(
"{:f}", threshold.float_val)));
155 const std::vector<uint32_t>& left_categories,
156 bool convert_missing_to_zero)
158 left_categories(left_categories),
159 convert_missing_to_zero(convert_missing_to_zero) {}
160 std::vector<uint32_t> left_categories;
161 bool convert_missing_to_zero;
163 std::string GetDump()
const override {
164 std::ostringstream oss;
166 for (
const auto& e : left_categories) {
170 return fmt::format(
"CategoricalConditionNode {{ {}, left_categories: {}, " 171 "convert_missing_to_zero: {} }}",
172 ConditionNode::GetDump(), oss.str(), convert_missing_to_zero);
179 : is_vector(
false), scalar(scalar) {}
180 explicit OutputNode(
const std::vector<tl_float>& vector)
181 : is_vector(
true), vector(vector) {}
184 std::vector<tl_float> vector;
186 std::string GetDump()
const override {
188 std::ostringstream oss;
190 for (
const auto& e : vector) {
194 return fmt::format(
"OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
196 return fmt::format(
"OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
204 #endif // TREELITE_COMPILER_AST_AST_H_
std::string OpName(Operator op)
get the string representation of a comparison operator
double tl_float
float type to be used internally
defines configuration macros of treelite
Operator
comparison operators