treelite
code_folding_util.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
8 #define TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
9 
10 #include <dmlc/logging.h>
11 #include <treelite/common.h>
12 #include <fmt/format.h>
13 #include <queue>
14 #include <set>
15 #include <string>
16 #include <vector>
17 #include <unordered_map>
18 #include "../ast/ast.h"
19 #include "./categorical_bitmap.h"
20 
21 using namespace fmt::literals;
22 
23 namespace treelite {
24 namespace compiler {
25 namespace common_util {
26 
/*!
 * \brief Render the lookup arrays and the output switch statement for the
 *        subtree held by a CodeFolderNode.
 *
 * The subtree is traversed breadth-first twice. The first pass assigns new
 * IDs to every descendant (0, 1, 2, ... for condition nodes; -1, -2, ... for
 * output nodes) and collects the comparison operators of numerical splits.
 * The second pass formats one array entry per condition node using those IDs
 * and gathers categorical bitmaps. Both passes start from node->children[0]
 * and enqueue children in the same order.
 *
 * \param node code folder node; must have exactly one child (checked)
 * \param quantize if true, numerical thresholds are rendered from
 *                 threshold.int_val; otherwise from threshold.float_val
 * \param use_boolean_literal if true, default_left is rendered as
 *                            "true"/"false"; otherwise as "1"/"0"
 * \param node_entry_template fmt template for one condition node entry; it
 *                            receives the named arguments default_left,
 *                            split_index, threshold, left_child, right_child
 * \param RenderOutputStatement callable invoked with each OutputNode*; its
 *                              result is passed through
 *                              common::IndentMultiLineString and placed
 *                              inside the matching case of the switch
 * \param array_nodes [out] formatted entries of all condition nodes
 * \param array_cat_bitmap [out] formatted categorical bitmap words; set to
 *                         the empty string when no bitmap words were collected
 * \param array_cat_begin [out] formatted offsets into the bitmap array; set
 *                        to the empty string when no bitmap words were
 *                        collected
 * \param output_switch_statements [out] "switch (nid)" statement with one
 *                                 case per output node
 * \param common_comp_op [out] comparison operator shared by all numerical
 *                       splits; Operator::kLT when the subtree contains no
 *                       numerical split
 */
27 template <typename OutputFormatFunc>
28 inline void
29 RenderCodeFolderArrays(const CodeFolderNode* node,
30  bool quantize,
31  bool use_boolean_literal,
32  const char* node_entry_template,
33  OutputFormatFunc RenderOutputStatement,
34  std::string* array_nodes,
35  std::string* array_cat_bitmap,
36  std::string* array_cat_begin,
37  std::string* output_switch_statements,
38  Operator* common_comp_op) {
39  CHECK_EQ(node->children.size(), 1);
40  const int tree_id = node->children[0]->tree_id;
41  // list of descendants, with newly assigned ID's
42  std::unordered_map<ASTNode*, int> descendants;
43  // list of all OutputNode's among the descendants
44  std::vector<OutputNode*> output_nodes;
45  // two arrays used to store categorical split info
    // cat_bitmap: bitmap words of all categorical splits, concatenated.
    // cat_begin: starts at {0} and receives the running size of cat_bitmap
    // after each categorical split, so consecutive entries delimit the words
    // of one categorical split (in the order the second pass visits them).
46  std::vector<uint64_t> cat_bitmap;
47  std::vector<size_t> cat_begin{0};
48 
49  // 1. Assign new continuous node ID's (0, 1, 2, ...) by traversing the
50  // subtree breadth-first
51  {
52  std::queue<ASTNode*> Q;
53  std::set<treelite::Operator> ops;
    // Condition nodes count up from 0 and output nodes count down from -1,
    // so output nodes are exactly the descendants with a negative ID.
54  int new_node_id = 0;
55  int new_leaf_id = -1;
56  Q.push(node->children[0]);
57  while (!Q.empty()) {
58  ASTNode* e = Q.front(); Q.pop();
59  // sanity check: all descendants must have same tree_id
60  CHECK_EQ(e->tree_id, tree_id);
61  // sanity check: all descendants must be ConditionNode or OutputNode
62  ConditionNode* t1 = dynamic_cast<ConditionNode*>(e);
63  OutputNode* t2 = dynamic_cast<OutputNode*>(e);
64  NumericalConditionNode* t3;
65  CHECK(t1 || t2);
66  if (t2) { // e is OutputNode
67  descendants[e] = new_leaf_id--;
68  } else {
    // Only numerical splits contribute an operator; other condition nodes
    // are skipped here.
69  if ( (t3 = dynamic_cast<NumericalConditionNode*>(t1)) ) {
70  ops.insert(t3->op);
71  }
72  descendants[e] = new_node_id++;
73  }
74  for (ASTNode* child : e->children) {
75  Q.push(child);
76  }
77  }
78  // sanity check: all numerical splits must have identical comparison operators
79  CHECK_LE(ops.size(), 1);
    // With no numerical split in the subtree, kLT is reported as the operator.
80  *common_comp_op = ops.empty() ? Operator::kLT : *ops.begin();
81  }
82 
83  // 2. Render node_treeXX_nodeXX[] by traversing the subtree once again.
84  // Now we can use the re-assigned node ID's.
85  {
    // NOTE(review): the arguments (80, 2) are presumably text width and
    // indent -- confirm against common::ArrayFormatter.
86  common::ArrayFormatter formatter(80, 2);
87 
88  bool default_left;
89  std::string threshold;
90  int left_child_id, right_child_id;
91  unsigned int split_index;
92  OutputNode* t1;
93  NumericalConditionNode* t2;
94  CategoricalConditionNode* t3;
95 
96  std::queue<ASTNode*> Q;
97  Q.push(node->children[0]);
98  while (!Q.empty()) {
99  ASTNode* e = Q.front(); Q.pop();
100  if ( (t1 = dynamic_cast<OutputNode*>(e)) ) {
101  output_nodes.push_back(t1);
102  // don't render OutputNode but save it for later
103  } else {
104  CHECK_EQ(e->children.size(), 2U);
    // Children are referenced by the IDs assigned in step 1; a negative ID
    // refers to an output node.
105  left_child_id = descendants[ e->children[0] ];
106  right_child_id = descendants[ e->children[1] ];
107  if ( (t2 = dynamic_cast<NumericalConditionNode*>(e)) ) {
108  default_left = t2->default_left;
109  split_index = t2->split_index;
110  threshold
111  = quantize ? std::to_string(t2->threshold.int_val)
112  : common::ToStringHighPrecision(t2->threshold.float_val);
113  } else {
    // NOTE(review): t3 is assigned inside the CHECK expression, so this
    // relies on CHECK evaluating its argument in every build mode -- confirm.
114  CHECK((t3 = dynamic_cast<CategoricalConditionNode*>(e)));
115  default_left = t3->default_left;
116  split_index = t3->split_index;
117  threshold = "-1"; // dummy value
118  CHECK(!t3->convert_missing_to_zero)
119  << "Code folding not supported, because a categorical split "
120  << "is supposed to convert missing values into zeros, and this "
121  << "is not possible with current code folding implementation.";
    // Append this split's bitmap words and record where they end; the
    // previous cat_begin entry marks where they start.
122  std::vector<uint64_t> bitmap
123  = GetCategoricalBitmap(t3->left_categories);
124  cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
125  cat_begin.push_back(cat_bitmap.size());
126  }
    // Pick how default_left is spelled in the rendered entry. The lambdas
    // capture nothing, so they convert to a plain function pointer.
127  const char* (*BoolWrapper)(bool);
128  if (use_boolean_literal) {
129  BoolWrapper = [](bool x) { return x ? "true" : "false"; };
130  } else {
131  BoolWrapper = [](bool x) { return x ? "1" : "0"; };
132  }
133  formatter << fmt::format(node_entry_template,
134  "default_left"_a = BoolWrapper(default_left),
135  "split_index"_a = split_index,
136  "threshold"_a = threshold,
137  "left_child"_a = left_child_id,
138  "right_child"_a = right_child_id);
139  }
140  for (ASTNode* child : e->children) {
141  Q.push(child);
142  }
143  }
144  *array_nodes = formatter.str();
145  }
146  // 3. render cat_bitmap_treeXX_nodeXX[] and cat_begin_treeXX_nodeXX[]
147  if (cat_bitmap.empty()) { // do not render empty arrays
148  *array_cat_bitmap = "";
149  *array_cat_begin = "";
150  } else {
151  {
152  common::ArrayFormatter formatter(80, 2);
153  for (uint64_t e : cat_bitmap) {
    // Each bitmap word is written in hexadecimal using fmt's alternate
    // ('#') form with upper-case digits.
154  formatter << fmt::format("{:#X}", e);
155  }
156  *array_cat_bitmap = formatter.str();
157  }
158  {
159  common::ArrayFormatter formatter(80, 2);
160  for (size_t e : cat_begin) {
161  formatter << e;
162  }
163  *array_cat_begin = formatter.str();
164  }
165  }
166  // 4. Render switch statement to associate each node ID with an output
167  *output_switch_statements = "switch (nid) {\n";
168  for (OutputNode* e : output_nodes) {
    // The case label is the (negative) ID given to this output node in
    // step 1.
169  const int node_id = descendants[static_cast<ASTNode*>(e)];
170  *output_switch_statements
171  += fmt::format(" case {node_id}:\n"
172  "{output_statement}"
173  " break;\n",
174  "node_id"_a = node_id,
175  "output_statement"_a
176  = common::IndentMultiLineString(RenderOutputStatement(e), 2));
177  }
178  *output_switch_statements += "}\n";
179 }
180 
181 } // namespace common_util
182 } // namespace compiler
183 } // namespace treelite
184 
185 #endif // TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
Function to generate bitmaps for categorical splits.