7 #ifndef TREELITE_COMPILER_AST_AST_H_ 8 #define TREELITE_COMPILER_AST_AST_H_ 12 #include <fmt/format.h> 25 std::vector<ASTNode*> children;
30 virtual std::string GetDump()
const = 0;
33 ASTNode() : parent(
nullptr), node_id(-1), tree_id(-1) {}
35 inline ASTNode::~ASTNode() {}
41 : global_bias(global_bias), average_result(average_result),
42 num_tree(num_tree), num_feature(num_feature) {}
48 std::string GetDump()
const override {
49 return fmt::format(
"MainNode {{ global_bias: {}, average_result: {}, num_tree: {}, " 50 "num_feature: {} }}", global_bias, average_result, num_tree, num_feature);
59 std::string GetDump()
const override {
60 return fmt::format(
"TranslationUnitNode {{ unit_id: {} }}", unit_id);
64 template <
typename ThresholdType>
67 explicit QuantizerNode(
const std::vector<std::vector<ThresholdType>>& cut_pts)
69 explicit QuantizerNode(std::vector<std::vector<ThresholdType>>&& cut_pts)
70 : cut_pts(std::move(cut_pts)) {}
71 std::vector<std::vector<ThresholdType>> cut_pts;
73 std::string GetDump()
const override {
74 std::ostringstream oss;
75 for (
const auto& vec : cut_pts) {
77 for (
const auto& e : vec) {
82 return fmt::format(
"QuantizerNode {{ cut_pts: {} }}", oss.str());
90 std::string GetDump()
const override {
91 return fmt::format(
"AccumulatorContextNode {{}}");
99 std::string GetDump()
const override {
100 return fmt::format(
"CodeFolderNode {{}}");
107 : split_index(split_index), default_left(default_left) {}
108 unsigned split_index;
112 std::string GetDump()
const override {
114 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {}, gain: {} }}",
115 split_index, default_left, *gain);
117 return fmt::format(
"ConditionNode {{ split_index: {}, default_left: {} }}",
118 split_index, default_left);
123 template <
typename ThresholdType>
125 ThresholdType float_val;
131 template <
typename ThresholdType>
138 quantized(quantized), op(op), threshold(threshold), zero_quantized(-1) {}
144 std::string GetDump()
const override {
145 return fmt::format(
"NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {}, " 146 "zero_quantized: {} }}",
147 ConditionNode::GetDump(), quantized,
OpName(op),
148 (quantized ? fmt::format(
"{}", threshold.int_val)
149 : fmt::format(
"{}", threshold.float_val)),
157 const std::vector<uint32_t>& matching_categories,
158 bool categories_list_right_child)
160 matching_categories(matching_categories),
161 categories_list_right_child(categories_list_right_child) {}
162 std::vector<uint32_t> matching_categories;
163 bool categories_list_right_child;
165 std::string GetDump()
const override {
166 std::ostringstream oss;
168 for (
const auto& e : matching_categories) {
172 return fmt::format(
"CategoricalConditionNode {{ {}, matching_categories: {}, " 173 "categories_list_right_child: {} }}",
174 ConditionNode::GetDump(), oss.str(), categories_list_right_child);
178 template <
typename LeafOutputType>
182 : is_vector(
false), scalar(scalar) {}
183 explicit OutputNode(
const std::vector<LeafOutputType>& vector)
184 : is_vector(
true), vector(vector) {}
186 LeafOutputType scalar;
187 std::vector<LeafOutputType> vector;
189 std::string GetDump()
const override {
191 std::ostringstream oss;
193 for (
const auto& e : vector) {
197 return fmt::format(
"OutputNode {{ is_vector: {}, vector {} }}", is_vector, oss.str());
199 return fmt::format(
"OutputNode {{ is_vector: {}, scalar: {} }}", is_vector, scalar);
207 #endif // TREELITE_COMPILER_AST_AST_H_
Backport of std::optional from C++17.
float tl_float
float type to be used internally
std::string OpName(Operator op)
get string representation of comparison operator
defines configuration macros of Treelite
Operator
comparison operators