treelite
code_folding_util.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
8 #define TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
9 
#include <dmlc/logging.h>
#include <treelite/common.h>
#include <fmt/format.h>
#include <cstddef>
#include <cstdint>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "../ast/ast.h"
#include "./categorical_bitmap.h"
19 
20 using namespace fmt::literals;
21 
22 namespace treelite {
23 namespace compiler {
24 namespace common_util {
25 
26 template <typename OutputFormatFunc>
27 inline void
28 RenderCodeFolderArrays(const CodeFolderNode* node,
29  bool quantize,
30  bool use_boolean_literal,
31  const char* node_entry_template,
32  OutputFormatFunc RenderOutputStatement,
33  std::string* array_nodes,
34  std::string* array_cat_bitmap,
35  std::string* array_cat_begin,
36  std::string* output_switch_statements,
37  Operator* common_comp_op) {
38  CHECK_EQ(node->children.size(), 1);
39  const int tree_id = node->children[0]->tree_id;
40  // list of descendants, with newly assigned ID's
41  std::unordered_map<ASTNode*, int> descendants;
42  // list of all OutputNode's among the descendants
43  std::vector<OutputNode*> output_nodes;
44  // two arrays used to store categorical split info
45  std::vector<uint64_t> cat_bitmap;
46  std::vector<size_t> cat_begin{0};
47 
48  // 1. Assign new continuous node ID's (0, 1, 2, ...) by traversing the
49  // subtree breadth-first
50  {
51  std::queue<ASTNode*> Q;
52  std::set<treelite::Operator> ops;
53  int new_node_id = 0;
54  int new_leaf_id = -1;
55  Q.push(node->children[0]);
56  while (!Q.empty()) {
57  ASTNode* e = Q.front(); Q.pop();
58  // sanity check: all descendants must have same tree_id
59  CHECK_EQ(e->tree_id, tree_id);
60  // sanity check: all descendants must be ConditionNode or OutputNode
61  ConditionNode* t1 = dynamic_cast<ConditionNode*>(e);
62  OutputNode* t2 = dynamic_cast<OutputNode*>(e);
63  NumericalConditionNode* t3;
64  CHECK(t1 || t2);
65  if (t2) { // e is OutputNode
66  descendants[e] = new_leaf_id--;
67  } else {
68  if ( (t3 = dynamic_cast<NumericalConditionNode*>(t1)) ) {
69  ops.insert(t3->op);
70  }
71  descendants[e] = new_node_id++;
72  }
73  for (ASTNode* child : e->children) {
74  Q.push(child);
75  }
76  }
77  // sanity check: all numerical splits must have identical comparison operators
78  CHECK_LE(ops.size(), 1);
79  *common_comp_op = ops.empty() ? Operator::kLT : *ops.begin();
80  }
81 
82  // 2. Render node_treeXX_nodeXX[] by traversing the subtree once again.
83  // Now we can use the re-assigned node ID's.
84  {
85  common::ArrayFormatter formatter(80, 2);
86 
87  bool default_left;
88  std::string threshold;
89  int left_child_id, right_child_id;
90  unsigned int split_index;
91  OutputNode* t1;
92  NumericalConditionNode* t2;
93  CategoricalConditionNode* t3;
94 
95  std::queue<ASTNode*> Q;
96  Q.push(node->children[0]);
97  while (!Q.empty()) {
98  ASTNode* e = Q.front(); Q.pop();
99  if ( (t1 = dynamic_cast<OutputNode*>(e)) ) {
100  output_nodes.push_back(t1);
101  // don't render OutputNode but save it for later
102  } else {
103  CHECK_EQ(e->children.size(), 2U);
104  left_child_id = descendants[ e->children[0] ];
105  right_child_id = descendants[ e->children[1] ];
106  if ( (t2 = dynamic_cast<NumericalConditionNode*>(e)) ) {
107  default_left = t2->default_left;
108  split_index = t2->split_index;
109  threshold
110  = quantize ? std::to_string(t2->threshold.int_val)
111  : common::ToStringHighPrecision(t2->threshold.float_val);
112  } else {
113  CHECK((t3 = dynamic_cast<CategoricalConditionNode*>(e)));
114  default_left = t3->default_left;
115  split_index = t3->split_index;
116  threshold = "-1"; // dummy value
117  std::vector<uint64_t> bitmap
118  = GetCategoricalBitmap(t3->left_categories);
119  cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
120  cat_begin.push_back(cat_bitmap.size());
121  }
122  auto BoolWrapper =
123  use_boolean_literal ? [](bool x) { return x ? "true" : "false"; }
124  : [](bool x) { return x ? "1" : "0"; };
125  formatter << fmt::format(node_entry_template,
126  "default_left"_a = BoolWrapper(default_left),
127  "split_index"_a = split_index,
128  "threshold"_a = threshold,
129  "left_child"_a = left_child_id,
130  "right_child"_a = right_child_id);
131  }
132  for (ASTNode* child : e->children) {
133  Q.push(child);
134  }
135  }
136  *array_nodes = formatter.str();
137  }
138  // 3. render cat_bitmap_treeXX_nodeXX[] and cat_begin_treeXX_nodeXX[]
139  if (cat_bitmap.empty()) { // do not render empty arrays
140  *array_cat_bitmap = "";
141  *array_cat_begin = "";
142  } else {
143  {
144  common::ArrayFormatter formatter(80, 2);
145  for (uint64_t e : cat_bitmap) {
146  formatter << fmt::format("{:#X}", e);
147  }
148  *array_cat_bitmap = formatter.str();
149  }
150  {
151  common::ArrayFormatter formatter(80, 2);
152  for (size_t e : cat_begin) {
153  formatter << e;
154  }
155  *array_cat_begin = formatter.str();
156  }
157  }
158  // 4. Render switch statement to associate each node ID with an output
159  *output_switch_statements = "switch (nid) {\n";
160  for (OutputNode* e : output_nodes) {
161  const int node_id = descendants[static_cast<ASTNode*>(e)];
162  *output_switch_statements
163  += fmt::format(" case {node_id}:\n"
164  "{output_statement}"
165  " break;\n",
166  "node_id"_a = node_id,
167  "output_statement"_a
168  = common::IndentMultiLineString(RenderOutputStatement(e), 2));
169  }
170  *output_switch_statements += "}\n";
171 }
172 
173 } // namespace common_util
174 } // namespace compiler
175 } // namespace treelite
176 
177 #endif // TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
Function to generate bitmaps for categorical splits.