7 #ifndef TREELITE_COMPILER_AST_AST_H_ 8 #define TREELITE_COMPILER_AST_AST_H_ 10 #include <dmlc/optional.h> 12 #include <fmt/format.h> 24 std::vector<ASTNode*> children;
27 dmlc::optional<size_t> data_count;
28 dmlc::optional<double> sum_hess;
29 virtual std::string GetDump()
const = 0;
32 ASTNode() : parent(
nullptr), node_id(-1), tree_id(-1) {}
34 inline ASTNode::~ASTNode() {}
40 : global_bias(global_bias), average_result(average_result),
41 num_tree(num_tree), num_feature(num_feature) {}
47 std::string GetDump()
const override {
48 return fmt::format(
"MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, " 49 "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
58 std::string GetDump()
const override {
59 return fmt::format(
"TranslationUnitNode {{ unit_id: {} }}", unit_id);
63 template <
typename ThresholdType>
66 explicit QuantizerNode(
const std::vector<std::vector<ThresholdType>>& cut_pts)
68 explicit QuantizerNode(std::vector<std::vector<ThresholdType>>&& cut_pts)
69 : cut_pts(std::move(cut_pts)) {}
70 std::vector<std::vector<ThresholdType>> cut_pts;
72 std::string GetDump()
const override {
73 std::ostringstream oss;
74 for (
const auto& vec : cut_pts) {
76 for (
const auto& e : vec) {
81 return fmt::format(
"QuantizerNode {{ cut_pts: {} }}", oss.str());
89 std::string GetDump()
const override {
90 return fmt::format(
"AccumulatorContextNode {{}}");
98 std::string GetDump()
const override {
99 return fmt::format(
"CodeFolderNode {{}}");
106 : split_index(split_index), default_left(default_left) {}
107 unsigned split_index;
109 dmlc::optional<double> gain;
111 std::string GetDump()
const override {
113 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
114 split_index, default_left, gain.value());
116 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {} }}",
117 split_index, default_left);
122 template <
typename ThresholdType>
124 ThresholdType float_val;
130 template <
typename ThresholdType>
137 quantized(quantized), op(op), threshold(threshold), zero_quantized(-1) {}
143 std::string GetDump()
const override {
144 return fmt::format(
"NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {}, " 145 "zero_quantized: {} }}",
146 ConditionNode::GetDump(), quantized,
OpName(op),
147 (quantized ? fmt::format(
"{}", threshold.int_val)
148 : fmt::format(
"{}", threshold.float_val)),
156 const std::vector<uint32_t>& matching_categories,
157 bool categories_list_right_child)
159 matching_categories(matching_categories),
160 categories_list_right_child(categories_list_right_child) {}
161 std::vector<uint32_t> matching_categories;
162 bool categories_list_right_child;
164 std::string GetDump()
const override {
165 std::ostringstream oss;
167 for (
const auto& e : matching_categories) {
171 return fmt::format(
"CategoricalConditionNode {{ {}, matching_categories: {}, " 172 "categories_list_right_child: {} }}",
173 ConditionNode::GetDump(), oss.str(), categories_list_right_child);
177 template <
typename LeafOutputType>
181 : is_vector(
false), scalar(scalar) {}
182 explicit OutputNode(
const std::vector<LeafOutputType>& vector)
183 : is_vector(
true), vector(vector) {}
185 LeafOutputType scalar;
186 std::vector<LeafOutputType> vector;
188 std::string GetDump()
const override {
190 std::ostringstream oss;
192 for (
const auto& e : vector) {
196 return fmt::format(
"OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
198 return fmt::format(
"OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
206 #endif // TREELITE_COMPILER_AST_AST_H_
float tl_float
float type to be used internally
std::string OpName(Operator op)
get string representation of comparison operator
defines configuration macros of Treelite
Operator
comparison operators