Treelite
code_folding_util.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
8 #define TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
9 
10 #include <fmt/format.h>
11 #include <treelite/logging.h>
12 #include <unordered_map>
13 #include <queue>
14 #include <set>
15 #include <string>
16 #include <vector>
17 #include "../ast/ast.h"
18 #include "./format_util.h"
19 #include "./categorical_bitmap.h"
20 
21 using namespace fmt::literals;
22 
23 namespace treelite {
24 namespace compiler {
25 namespace common_util {
26 
27 
28 template <typename ThresholdType, typename LeafOutputType, typename OutputFormatFunc>
29 inline void
30 RenderCodeFolderArrays(const CodeFolderNode* node,
31  bool quantize,
32  bool use_boolean_literal,
33  const char* node_entry_template,
34  OutputFormatFunc RenderOutputStatement,
35  std::string* array_nodes,
36  std::string* array_cat_bitmap,
37  std::string* array_cat_begin,
38  std::string* output_switch_statements,
39  Operator* common_comp_op) {
40  TREELITE_CHECK_EQ(node->children.size(), 1);
41  const int tree_id = node->children[0]->tree_id;
42  // list of descendants, with newly assigned ID's
43  std::unordered_map<ASTNode*, int> descendants;
44  // list of all OutputNode's among the descendants
45  std::vector<OutputNode<LeafOutputType>*> output_nodes;
46  // two arrays used to store categorical split info
47  std::vector<uint64_t> cat_bitmap;
48  std::vector<size_t> cat_begin{0};
49 
50  // 1. Assign new continuous node ID's (0, 1, 2, ...) by traversing the
51  // subtree breadth-first
52  {
53  std::queue<ASTNode*> Q;
54  std::set<treelite::Operator> ops;
55  int new_node_id = 0;
56  int new_leaf_id = -1;
57  Q.push(node->children[0]);
58  while (!Q.empty()) {
59  ASTNode* e = Q.front(); Q.pop();
60  // sanity check: all descendants must have same tree_id
61  TREELITE_CHECK_EQ(e->tree_id, tree_id);
62  // sanity check: all descendants must be ConditionNode or OutputNode
63  ConditionNode* t1 = dynamic_cast<ConditionNode*>(e);
64  OutputNode<LeafOutputType>* t2 = dynamic_cast<OutputNode<LeafOutputType>*>(e);
65  NumericalConditionNode<ThresholdType>* t3;
66  TREELITE_CHECK(t1 || t2);
67  if (t2) { // e is OutputNode
68  descendants[e] = new_leaf_id--;
69  } else {
70  if ( (t3 = dynamic_cast<NumericalConditionNode<ThresholdType>*>(t1)) ) {
71  ops.insert(t3->op);
72  }
73  descendants[e] = new_node_id++;
74  }
75  for (ASTNode* child : e->children) {
76  Q.push(child);
77  }
78  }
79  // sanity check: all numerical splits must have identical comparison operators
80  TREELITE_CHECK_LE(ops.size(), 1);
81  *common_comp_op = ops.empty() ? Operator::kLT : *ops.begin();
82  }
83 
84  // 2. Render node_treeXX_nodeXX[] by traversing the subtree once again.
85  // Now we can use the re-assigned node ID's.
86  {
87  ArrayFormatter formatter(80, 2);
88 
89  bool default_left;
90  std::string threshold;
91  int left_child_id, right_child_id;
92  unsigned int split_index;
93  OutputNode<LeafOutputType>* t1;
94  NumericalConditionNode<ThresholdType>* t2;
95  CategoricalConditionNode* t3;
96 
97  std::queue<ASTNode*> Q;
98  Q.push(node->children[0]);
99  while (!Q.empty()) {
100  ASTNode* e = Q.front(); Q.pop();
101  if ( (t1 = dynamic_cast<OutputNode<LeafOutputType>*>(e)) ) {
102  output_nodes.push_back(t1);
103  // don't render OutputNode but save it for later
104  } else {
105  TREELITE_CHECK_EQ(e->children.size(), 2U);
106  left_child_id = descendants[ e->children[0] ];
107  right_child_id = descendants[ e->children[1] ];
108  if ( (t2 = dynamic_cast<NumericalConditionNode<ThresholdType>*>(e)) ) {
109  default_left = t2->default_left;
110  split_index = t2->split_index;
111  threshold
112  = quantize ? std::to_string(t2->threshold.int_val)
113  : ToStringHighPrecision(t2->threshold.float_val);
114  } else {
115  TREELITE_CHECK((t3 = dynamic_cast<CategoricalConditionNode*>(e)));
116  default_left = t3->default_left;
117  split_index = t3->split_index;
118  threshold = "-1"; // dummy value
119  std::vector<uint64_t> bitmap = GetCategoricalBitmap(t3->matching_categories);
120  cat_bitmap.insert(cat_bitmap.end(), bitmap.begin(), bitmap.end());
121  cat_begin.push_back(cat_bitmap.size());
122  }
123  const char* (*BoolWrapper)(bool);
124  if (use_boolean_literal) {
125  BoolWrapper = [](bool x) { return x ? "true" : "false"; };
126  } else {
127  BoolWrapper = [](bool x) { return x ? "1" : "0"; };
128  }
129  formatter << fmt::format(node_entry_template,
130  "default_left"_a = BoolWrapper(default_left),
131  "split_index"_a = split_index,
132  "threshold"_a = threshold,
133  "left_child"_a = left_child_id,
134  "right_child"_a = right_child_id);
135  }
136  for (ASTNode* child : e->children) {
137  Q.push(child);
138  }
139  }
140  *array_nodes = formatter.str();
141  }
142  // 3. render cat_bitmap_treeXX_nodeXX[] and cat_begin_treeXX_nodeXX[]
143  if (cat_bitmap.empty()) { // do not render empty arrays
144  *array_cat_bitmap = "";
145  *array_cat_begin = "";
146  } else {
147  {
148  ArrayFormatter formatter(80, 2);
149  for (uint64_t e : cat_bitmap) {
150  formatter << fmt::format("{:#X}", e);
151  }
152  *array_cat_bitmap = formatter.str();
153  }
154  {
155  ArrayFormatter formatter(80, 2);
156  for (size_t e : cat_begin) {
157  formatter << e;
158  }
159  *array_cat_begin = formatter.str();
160  }
161  }
162  // 4. Render switch statement to associate each node ID with an output
163  *output_switch_statements = "switch (nid) {\n";
164  for (OutputNode<LeafOutputType>* e : output_nodes) {
165  const int node_id = descendants[static_cast<ASTNode*>(e)];
166  *output_switch_statements
167  += fmt::format(" case {node_id}:\n"
168  "{output_statement}"
169  " break;\n",
170  "node_id"_a = node_id,
171  "output_statement"_a = IndentMultiLineString(RenderOutputStatement(e), 2));
172  }
173  *output_switch_statements += "}\n";
174 }
175 
176 } // namespace common_util
177 } // namespace compiler
178 } // namespace treelite
179 
180 #endif // TREELITE_COMPILER_COMMON_CODE_FOLDING_UTIL_H_
std::string IndentMultiLineString(const std::string &str, size_t indent=2)
Apply indentation to a multi-line string by inserting spaces at the beginning of each line.
Definition: format_util.h:26
logging facility for Treelite
Function to generate bitmaps for categorical splits.
Formatting utilities.
std::string ToStringHighPrecision(T value)
Obtain a string representation of a floating-point value, expressed in high precision.
Definition: format_util.h:53