7 #ifndef TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_ 8 #define TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_ 14 #include <unordered_map> 15 #include <dmlc/logging.h> 16 #include <treelite/common.h> 17 #include <fmt/format.h> 18 #include "../ast/ast.h" 25 namespace common_util {
// RenderCodeFolderArrays: renders the subtree hanging under a CodeFolderNode
// as source-code text fragments (node array, categorical-split arrays, and a
// leaf-output switch statement).
//
// NOTE(review): this excerpt appears to have lost lines during extraction and
// does not compile as shown. Visible symptoms:
//   - no return type precedes the function name;
//   - `quantize`, `new_node_id`, `new_leaf_id`, `default_left` and the
//     second-pass `t1` are used but never declared in the visible lines;
//   - the BFS loop headers, `else` branches and most closing braces are absent;
//   - stray numeric tokens (old line numbers) prefix many lines.
// Restore the missing lines from upstream before building.
//
// Parameters, as used by the visible code:
//   node                     - must have exactly one child (CHECK_EQ below);
//                              that child is the root of the rendered subtree.
//   use_boolean_literal      - true: default_left is emitted as "true"/"false";
//                              false: emitted as "1"/"0".
//   node_entry_template      - fmt format string for one condition-node entry;
//                              receives the named args default_left,
//                              split_index, threshold, left_child, right_child.
//   RenderOutputStatement    - callable invoked with each OutputNode*; its
//                              result (presumably a std::string - confirm) is
//                              indented by 2 and placed in a switch case.
//   array_nodes              - out: concatenated condition-node entries.
//   array_cat_bitmap         - out: hex-formatted categorical bitmap words,
//                              or "" when there are no categorical splits.
//   array_cat_begin          - out: formatted offsets into the bitmap array,
//                              or "" when there are no categorical splits.
//   output_switch_statements - out: "switch (nid) { ... }" with one case per
//                              output node, keyed by its assigned ID.
//   common_comp_op           - out: the single comparison operator shared by
//                              the numerical splits, Operator::kLT if none.
27 template <
typename OutputFormatFunc>
29 RenderCodeFolderArrays(
const CodeFolderNode* node,
// NOTE(review): `quantize` is read below but no such parameter is visible here.
31 bool use_boolean_literal,
32 const char* node_entry_template,
33 OutputFormatFunc RenderOutputStatement,
34 std::string* array_nodes,
35 std::string* array_cat_bitmap,
36 std::string* array_cat_begin,
37 std::string* output_switch_statements,
38 Operator* common_comp_op) {
// The folder node must wrap exactly one subtree root.
39 CHECK_EQ(node->children.size(), 1);
40 const int tree_id = node->children[0]->tree_id;
// Maps every visited subtree node to its newly assigned integer ID.
42 std::unordered_map<ASTNode*, int> descendants;
// Output (leaf) nodes, collected in visit order during the second pass.
44 std::vector<OutputNode*> output_nodes;
// Categorical splits: concatenated bitmap words, plus offsets where each
// split's words begin. The offsets array is seeded with 0.
46 std::vector<uint64_t> cat_bitmap;
47 std::vector<size_t> cat_begin{0};
// Pass 1: breadth-first walk from the subtree root to assign IDs and to
// gather the comparison operators of the numerical splits.
52 std::queue<ASTNode*> Q;
53 std::set<treelite::Operator> ops;
56 Q.push(node->children[0]);
58 ASTNode* e = Q.front(); Q.pop();
// Every node in the folded subtree must belong to the same tree.
60 CHECK_EQ(e->tree_id, tree_id);
// Classify the node; only ConditionNode / OutputNode are handled here.
62 ConditionNode* t1 =
dynamic_cast<ConditionNode*
>(e);
63 OutputNode* t2 =
dynamic_cast<OutputNode*
>(e);
64 NumericalConditionNode* t3;
// Output nodes receive post-decremented IDs from new_leaf_id (its initial
// value is not visible - presumably negative so leaf IDs cannot collide with
// condition-node IDs; confirm).
67 descendants[e] = new_leaf_id--;
// Numerical condition node: presumably its operator is added to `ops` here
// (the body line is not visible in this excerpt).
69 if ( (t3 = dynamic_cast<NumericalConditionNode*>(t1)) ) {
// Condition nodes receive post-incremented IDs from new_node_id.
72 descendants[e] = new_node_id++;
74 for (ASTNode* child : e->children) {
// All numerical splits must share one comparison operator; fall back to kLT
// when the subtree has no numerical split at all.
79 CHECK_LE(ops.size(), 1);
80 *common_comp_op = ops.empty() ? Operator::kLT : *ops.begin();
// Pass 2: breadth-first walk again (same order as pass 1), emitting one
// formatted entry per condition node. The ArrayFormatter(80, 2) arguments are
// presumably line width and indentation - confirm against common.h.
86 common::ArrayFormatter formatter(80, 2);
89 std::string threshold;
90 int left_child_id, right_child_id;
91 unsigned int split_index;
93 NumericalConditionNode* t2;
94 CategoricalConditionNode* t3;
96 std::queue<ASTNode*> Q;
97 Q.push(node->children[0]);
99 ASTNode* e = Q.front(); Q.pop();
// Output nodes produce no array entry; they are rendered in the switch below.
100 if ( (t1 = dynamic_cast<OutputNode*>(e)) ) {
101 output_nodes.push_back(t1);
// Condition nodes must be binary; children are referenced by their pass-1 IDs.
104 CHECK_EQ(e->children.size(), 2U);
105 left_child_id = descendants[ e->children[0] ];
106 right_child_id = descendants[ e->children[1] ];
107 if ( (t2 = dynamic_cast<NumericalConditionNode*>(e)) ) {
108 default_left = t2->default_left;
109 split_index = t2->split_index;
// Threshold text: the integer value when quantized, otherwise the float value
// printed via ToStringHighPrecision.
111 = quantize ? std::to_string(t2->threshold.int_val)
112 : common::ToStringHighPrecision(t2->threshold.float_val);
// Anything that is not numerical must be a categorical condition node.
114 CHECK((t3 = dynamic_cast<CategoricalConditionNode*>(e)));
115 default_left = t3->default_left;
116 split_index = t3->split_index;
// NOTE(review): no assignment to `threshold` is visible in this branch;
// confirm one exists in the missing lines, otherwise the previous node's
// threshold text would leak into this entry.
// Splits that convert missing values to zero cannot be folded; abort.
118 CHECK(!t3->convert_missing_to_zero)
119 <<
"Code folding not supported, because a categorical split " 120 <<
"is supposed to convert missing values into zeros, and this " 121 <<
"is not possible with current code folding implementation.";
// Append this split's bitmap words, then record the end offset, which is also
// where the next split's words begin.
122 std::vector<uint64_t> bitmap
123 = GetCategoricalBitmap(t3->left_categories);
124 cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
125 cat_begin.push_back(cat_bitmap.size());
// Choose how default_left is spelled in the generated code.
127 const char* (*BoolWrapper)(bool);
128 if (use_boolean_literal) {
129 BoolWrapper = [](
bool x) {
return x ?
"true" :
"false"; };
131 BoolWrapper = [](
bool x) {
return x ?
"1" :
"0"; };
// Emit one entry for this condition node using the caller's template.
133 formatter << fmt::format(node_entry_template,
134 "default_left"_a = BoolWrapper(default_left),
135 "split_index"_a = split_index,
136 "threshold"_a = threshold,
137 "left_child"_a = left_child_id,
138 "right_child"_a = right_child_id);
140 for (ASTNode* child : e->children) {
144 *array_nodes = formatter.str();
// Categorical arrays: emit empty strings when no categorical split was seen.
147 if (cat_bitmap.empty()) {
148 *array_cat_bitmap =
"";
149 *array_cat_begin =
"";
// Each 64-bit bitmap word is printed in hex with a prefix ("{:#X}").
152 common::ArrayFormatter formatter(80, 2);
153 for (uint64_t e : cat_bitmap) {
154 formatter << fmt::format(
"{:#X}", e);
156 *array_cat_bitmap = formatter.str();
159 common::ArrayFormatter formatter(80, 2);
160 for (
size_t e : cat_begin) {
163 *array_cat_begin = formatter.str();
// Leaf outputs: a switch on `nid` with one case per output node, labeled with
// the ID assigned in pass 1 and containing the caller-rendered statement
// indented by 2. Part of the case format string and the name of the
// IndentMultiLineString argument are missing from this excerpt.
167 *output_switch_statements =
"switch (nid) {\n";
168 for (OutputNode* e : output_nodes) {
169 const int node_id = descendants[
static_cast<ASTNode*
>(e)];
170 *output_switch_statements
171 += fmt::format(
" case {node_id}:\n" 174 "node_id"_a = node_id,
176 = common::IndentMultiLineString(RenderOutputStatement(e), 2));
178 *output_switch_statements +=
"}\n";
185 #endif // TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
// Function to generate bitmaps for categorical splits (presumably describes GetCategoricalBitmap, whose definition is not part of this excerpt).