7 #ifndef TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_ 8 #define TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_ 10 #include <fmt/format.h> 12 #include <unordered_map> 17 #include "../ast/ast.h" 25 namespace common_util {
// Render the subtree under a CodeFolderNode as flat text arrays for a folded
// (table-driven) code path.
//
// The CodeFolderNode must have exactly one child (checked below); that child
// is the root of the subtree being rendered. Results are written through the
// pointer parameters:
//   array_nodes              - node_entry_template entries for the condition
//                              nodes visited (output nodes are instead
//                              collected for the switch statement below)
//   array_cat_bitmap         - concatenated categorical-split bitmaps written
//                              as hex literals, or "" if there are no
//                              categorical splits
//   array_cat_begin          - offsets into the bitmap array (first entry 0),
//                              or "" if there are no categorical splits
//   output_switch_statements - generated `switch (nid) { ... }` with one case
//                              per output node
//   common_comp_op           - the single comparison operator shared by the
//                              numerical splits, or Operator::kLT if none
//
// node_entry_template is a fmt format string that may reference the named
// fields {default_left}, {split_index}, {threshold}, {left_child} and
// {right_child}. use_boolean_literal selects how default_left is spelled:
// "true"/"false" when set, "1"/"0" otherwise.
// NOTE(review): RenderOutputStatement is presumably invoked per output node to
// render each case body; its call site is not visible in this excerpt - confirm.
28 template <
typename ThresholdType,
typename LeafOutputType,
typename OutputFormatFunc>
30 RenderCodeFolderArrays(
const CodeFolderNode* node,
32 bool use_boolean_literal,
33 const char* node_entry_template,
34 OutputFormatFunc RenderOutputStatement,
35 std::string* array_nodes,
36 std::string* array_cat_bitmap,
37 std::string* array_cat_begin,
38 std::string* output_switch_statements,
39 Operator* common_comp_op) {
// The folder node must wrap exactly one subtree root; all nodes visited below
// must carry that root's tree_id.
40 TREELITE_CHECK_EQ(node->children.size(), 1);
41 const int tree_id = node->children[0]->tree_id;
// Maps every visited node to the new integer ID assigned in pass 1.
43 std::unordered_map<ASTNode*, int> descendants;
// Output nodes, collected in pass 2 for the generated switch statement.
45 std::vector<OutputNode<LeafOutputType>*> output_nodes;
// Categorical split data: each split's bitmap is appended to cat_bitmap, and
// cat_begin receives the running size afterwards. The bitmap of the k-th
// categorical split therefore spans [cat_begin[k], cat_begin[k+1]).
47 std::vector<uint64_t> cat_bitmap;
48 std::vector<size_t> cat_begin{0};
// Pass 1: walk the subtree with a FIFO queue and assign new IDs.
// NOTE(review): the loop header, branch structure and the declarations/initial
// values of new_node_id and new_leaf_id are not visible in this excerpt. From
// the visible lines, some nodes get decreasing IDs (new_leaf_id--) and numerical
// condition nodes get increasing IDs (new_node_id++) - confirm.
53 std::queue<ASTNode*> Q;
54 std::set<treelite::Operator> ops;
57 Q.push(node->children[0]);
59 ASTNode* e = Q.front(); Q.pop();
// Sanity check: every node in the subtree belongs to the same tree.
61 TREELITE_CHECK_EQ(e->tree_id, tree_id);
// Every node must be either a ConditionNode or an OutputNode.
63 ConditionNode* t1 =
dynamic_cast<ConditionNode*
>(e);
64 OutputNode<LeafOutputType>* t2 =
dynamic_cast<OutputNode<LeafOutputType>*
>(e);
65 NumericalConditionNode<ThresholdType>* t3;
66 TREELITE_CHECK(t1 || t2);
68 descendants[e] = new_leaf_id--;
// For numerical splits, ops presumably collects each split's comparison
// operator (the insertion itself is not visible here).
70 if ( (t3 =
dynamic_cast<NumericalConditionNode<ThresholdType>*
>(t1)) ) {
73 descendants[e] = new_node_id++;
75 for (ASTNode* child : e->children) {
// All numerical splits must share at most one comparison operator; default to
// Operator::kLT when there are no numerical splits.
80 TREELITE_CHECK_LE(ops.size(), 1);
81 *common_comp_op = ops.empty() ? Operator::kLT : *ops.begin();
// Pass 2: traverse the subtree again and render one entry per condition node.
// NOTE(review): ArrayFormatter(80, 2) presumably means line width 80 and
// indent 2 - confirm against ArrayFormatter's definition.
87 ArrayFormatter formatter(80, 2);
90 std::string threshold;
91 int left_child_id, right_child_id;
92 unsigned int split_index;
93 OutputNode<LeafOutputType>* t1;
94 NumericalConditionNode<ThresholdType>* t2;
95 CategoricalConditionNode* t3;
97 std::queue<ASTNode*> Q;
98 Q.push(node->children[0]);
100 ASTNode* e = Q.front(); Q.pop();
// Output nodes are not array entries; they are gathered for the switch below.
101 if ( (t1 =
dynamic_cast<OutputNode<LeafOutputType>*
>(e)) ) {
102 output_nodes.push_back(t1);
// Condition nodes must have exactly two children: [0] is left, [1] is right.
105 TREELITE_CHECK_EQ(e->children.size(), 2U);
// Children already have IDs from pass 1.
// NOTE(review): operator[] would silently insert 0 for a node missing from
// descendants; find()/at() would fail loudly instead.
106 left_child_id = descendants[ e->children[0] ];
107 right_child_id = descendants[ e->children[1] ];
// Numerical split: take default direction, feature index and threshold.
108 if ( (t2 =
dynamic_cast<NumericalConditionNode<ThresholdType>*
>(e)) ) {
109 default_left = t2->default_left;
110 split_index = t2->split_index;
// When quantize is set the threshold is rendered from its int_val; the
// non-quantized alternative is not visible in this excerpt.
112 = quantize ? std::to_string(t2->threshold.int_val)
// Any other condition node must be categorical.
115 TREELITE_CHECK((t3 = dynamic_cast<CategoricalConditionNode*>(e)));
116 default_left = t3->default_left;
117 split_index = t3->split_index;
// Append this split's category bitmap and record the new end offset.
119 std::vector<uint64_t> bitmap = GetCategoricalBitmap(t3->matching_categories);
120 cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
121 cat_begin.push_back(cat_bitmap.size());
// Captureless lambdas convert to this plain function pointer; it decides how
// default_left is spelled in the generated code.
123 const char* (*BoolWrapper)(bool);
124 if (use_boolean_literal) {
125 BoolWrapper = [](
bool x) {
return x ?
"true" :
"false"; };
127 BoolWrapper = [](
bool x) {
return x ?
"1" :
"0"; };
// Emit one array entry; node_entry_template may reference these named fields.
129 formatter << fmt::format(node_entry_template,
130 "default_left"_a = BoolWrapper(default_left),
131 "split_index"_a = split_index,
132 "threshold"_a = threshold,
133 "left_child"_a = left_child_id,
134 "right_child"_a = right_child_id);
136 for (ASTNode* child : e->children) {
140 *array_nodes = formatter.str();
// With no categorical splits, both categorical arrays are left empty.
143 if (cat_bitmap.empty()) {
144 *array_cat_bitmap =
"";
145 *array_cat_begin =
"";
// Each 64-bit bitmap word is written as an uppercase hex literal with a 0X
// prefix ("{:#X}").
148 ArrayFormatter formatter(80, 2);
149 for (uint64_t e : cat_bitmap) {
150 formatter << fmt::format(
"{:#X}", e);
152 *array_cat_bitmap = formatter.str();
// Render the cat_begin offsets (the loop body is not visible in this excerpt).
155 ArrayFormatter formatter(80, 2);
156 for (
size_t e : cat_begin) {
159 *array_cat_begin = formatter.str();
// Generated dispatch over output nodes: the emitted code switches on a variable
// named nid, and each case label is the ID assigned to that node in pass 1.
163 *output_switch_statements =
"switch (nid) {\n";
164 for (OutputNode<LeafOutputType>* e : output_nodes) {
165 const int node_id = descendants[
static_cast<ASTNode*
>(e)];
// NOTE(review): the remaining arguments of this fmt::format call (including how
// RenderOutputStatement contributes the case body) are not visible here.
166 *output_switch_statements
167 += fmt::format(
" case {node_id}:\n" 170 "node_id"_a = node_id,
173 *output_switch_statements +=
"}\n";
180 #endif // TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
Logging facility for Treelite.
Function to generate bitmaps for categorical splits.