treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
8 #include <treelite/common.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <algorithm>
12 #include <unordered_map>
13 #include <queue>
14 #include <cmath>
15 #include "./param.h"
16 #include "./pred_transform.h"
17 #include "./ast/builder.h"
18 #include "./native/main_template.h"
24 
25 using namespace fmt::literals;
26 
27 namespace treelite {
28 namespace compiler {
29 
30 DMLC_REGISTRY_FILE_TAG(ast_native);
31 
32 class ASTNativeCompiler : public Compiler {
33  public:
34  explicit ASTNativeCompiler(const CompilerParam& param)
35  : param(param) {
36  if (param.verbose > 0) {
37  LOG(INFO) << "Using ASTNativeCompiler";
38  }
39  }
40 
41  CompiledModel Compile(const Model& model) override {
42  CompiledModel cm;
43  cm.backend = "native";
44  cm.files["main.c"] = "";
45 
46  num_feature_ = model.num_feature;
47  num_output_group_ = model.num_output_group;
48  pred_tranform_func_ = PredTransformFunction("native", model);
49  files_.clear();
50 
51  ASTBuilder builder;
52  builder.BuildAST(model);
53  if (builder.FoldCode(param.code_folding_req)
54  || param.quantize > 0) {
55  // is_categorical[i] : is i-th feature categorical?
56  array_is_categorical_
57  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
58  }
59  if (param.annotate_in != "NULL") {
60  BranchAnnotator annotator;
61  std::unique_ptr<dmlc::Stream> fi(
62  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
63  annotator.Load(fi.get());
64  const auto annotation = annotator.Get();
65  builder.LoadDataCounts(annotation);
66  LOG(INFO) << "Loading node frequencies from `"
67  << param.annotate_in << "'";
68  }
69  builder.Split(param.parallel_comp);
70  if (param.quantize > 0) {
71  builder.QuantizeThresholds();
72  }
73  WalkAST(builder.GetRootNode(), "main.c", 0);
74  if (files_.count("arrays.c") > 0) {
75  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
76  }
77 
78  {
79  /* write recipe.json */
80  std::vector<std::unordered_map<std::string, std::string>> source_list;
81  for (auto kv : files_) {
82  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
83  const size_t line_count
84  = std::count(kv.second.begin(), kv.second.end(), '\n');
85  source_list.push_back({ {"name",
86  kv.first.substr(0, kv.first.length() - 2)},
87  {"length", std::to_string(line_count)} });
88  }
89  }
90  std::ostringstream oss;
91  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
92  writer->BeginObject();
93  writer->WriteObjectKeyValue("target", param.native_lib_name);
94  writer->WriteObjectKeyValue("sources", source_list);
95  writer->EndObject();
96  files_["recipe.json"] = oss.str();
97  }
98  cm.files = std::move(files_);
99  return cm;
100  }
101 
102  private:
103  CompilerParam param;
104  int num_feature_;
105  int num_output_group_;
106  std::string pred_tranform_func_;
107  std::string array_is_categorical_;
108  std::unordered_map<std::string, std::string> files_;
109 
110  void WalkAST(const ASTNode* node,
111  const std::string& dest,
112  size_t indent) {
113  const MainNode* t1;
114  const AccumulatorContextNode* t2;
115  const ConditionNode* t3;
116  const OutputNode* t4;
117  const TranslationUnitNode* t5;
118  const QuantizerNode* t6;
119  const CodeFolderNode* t7;
120  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
121  HandleMainNode(t1, dest, indent);
122  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
123  HandleACNode(t2, dest, indent);
124  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
125  HandleCondNode(t3, dest, indent);
126  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
127  HandleOutputNode(t4, dest, indent);
128  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
129  HandleTUNode(t5, dest, indent);
130  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
131  HandleQNode(t6, dest, indent);
132  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
133  HandleCodeFolderNode(t7, dest, indent);
134  } else {
135  LOG(FATAL) << "Unrecognized AST node type";
136  }
137  }
138 
139  // append content to a given buffer, with given level of indentation
140  inline void AppendToBuffer(const std::string& dest,
141  const std::string& content,
142  size_t indent) {
143  files_[dest] += common::IndentMultiLineString(content, indent);
144  }
145 
146  // prepend content to a given buffer, with given level of indentation
147  inline void PrependToBuffer(const std::string& dest,
148  const std::string& content,
149  size_t indent) {
150  files_[dest]
151  = common::IndentMultiLineString(content, indent) + files_[dest];
152  }
153 
154  void HandleMainNode(const MainNode* node,
155  const std::string& dest,
156  size_t indent) {
157  const char* get_num_output_group_function_signature
158  = "size_t get_num_output_group(void)";
159  const char* get_num_feature_function_signature
160  = "size_t get_num_feature(void)";
161  const char* predict_function_signature
162  = (num_output_group_ > 1) ?
163  "size_t predict_multiclass(union Entry* data, int pred_margin, "
164  "float* result)"
165  : "float predict(union Entry* data, int pred_margin)";
166 
167  AppendToBuffer(dest,
168  fmt::format(native::main_start_template,
169  "array_is_categorical"_a = array_is_categorical_,
170  "get_num_output_group_function_signature"_a
171  = get_num_output_group_function_signature,
172  "get_num_feature_function_signature"_a
173  = get_num_feature_function_signature,
174  "pred_transform_function"_a = pred_tranform_func_,
175  "predict_function_signature"_a = predict_function_signature,
176  "num_output_group"_a = num_output_group_,
177  "num_feature"_a = node->num_feature),
178  indent);
179  AppendToBuffer("header.h",
180  fmt::format(native::header_template,
181  "get_num_output_group_function_signature"_a
182  = get_num_output_group_function_signature,
183  "get_num_feature_function_signature"_a
184  = get_num_feature_function_signature,
185  "predict_function_signature"_a = predict_function_signature,
186  "threshold_type"_a = (param.quantize > 0 ? "int" : "double")),
187  indent);
188 
189  CHECK_EQ(node->children.size(), 1);
190  WalkAST(node->children[0], dest, indent + 2);
191 
192  const std::string optional_average_field
193  = (node->average_result) ? fmt::format(" / {}", node->num_tree)
194  : std::string("");
195  if (num_output_group_ > 1) {
196  AppendToBuffer(dest,
197  fmt::format(native::main_end_multiclass_template,
198  "num_output_group"_a = num_output_group_,
199  "optional_average_field"_a = optional_average_field,
200  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
201  indent);
202  } else {
203  AppendToBuffer(dest,
204  fmt::format(native::main_end_template,
205  "optional_average_field"_a = optional_average_field,
206  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
207  indent);
208  }
209  }
210 
211  void HandleACNode(const AccumulatorContextNode* node,
212  const std::string& dest,
213  size_t indent) {
214  if (num_output_group_ > 1) {
215  AppendToBuffer(dest,
216  fmt::format("float sum[{num_output_group}] = {{0.0f}};\n"
217  "unsigned int tmp;\n"
218  "int nid, cond, fid; /* used for folded subtrees */\n",
219  "num_output_group"_a = num_output_group_), indent);
220  } else {
221  AppendToBuffer(dest,
222  "float sum = 0.0f;\n"
223  "unsigned int tmp;\n"
224  "int nid, cond, fid; /* used for folded subtrees */\n", indent);
225  }
226  for (ASTNode* child : node->children) {
227  WalkAST(child, dest, indent);
228  }
229  }
230 
231  void HandleCondNode(const ConditionNode* node,
232  const std::string& dest,
233  size_t indent) {
234  const NumericalConditionNode* t;
235  std::string condition;
236  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
237  /* numerical split */
238  condition = ExtractNumericalCondition(t);
239  } else { /* categorical split */
240  const CategoricalConditionNode* t2
241  = dynamic_cast<const CategoricalConditionNode*>(node);
242  CHECK(t2);
243  condition = ExtractCategoricalCondition(t2);
244  }
245  const char* condition_with_na_check_template
246  = (node->default_left) ?
247  "!(data[{split_index}].missing != -1) || ({condition})"
248  : " (data[{split_index}].missing != -1) && ({condition})";
249  std::string condition_with_na_check
250  = fmt::format(condition_with_na_check_template,
251  "split_index"_a = node->split_index,
252  "condition"_a = condition);
253  if (node->children[0]->data_count && node->children[1]->data_count) {
254  const int left_freq = node->children[0]->data_count.value();
255  const int right_freq = node->children[1]->data_count.value();
256  condition_with_na_check
257  = fmt::format(" {keyword}( {condition} ) ",
258  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
259  "condition"_a = condition_with_na_check);
260  }
261  AppendToBuffer(dest,
262  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
263  CHECK_EQ(node->children.size(), 2);
264  WalkAST(node->children[0], dest, indent + 2);
265  AppendToBuffer(dest, "} else {\n", indent);
266  WalkAST(node->children[1], dest, indent + 2);
267  AppendToBuffer(dest, "}\n", indent);
268  }
269 
270  void HandleOutputNode(const OutputNode* node,
271  const std::string& dest,
272  size_t indent) {
273  AppendToBuffer(dest, RenderOutputStatement(node), indent);
274  CHECK_EQ(node->children.size(), 0);
275  }
276 
277  void HandleTUNode(const TranslationUnitNode* node,
278  const std::string& dest,
279  int indent) {
280  const int unit_id = node->unit_id;
281  const std::string new_file = fmt::format("tu{}.c", unit_id);
282 
283  std::string unit_function_name, unit_function_signature,
284  unit_function_call_signature;
285  if (num_output_group_ > 1) {
286  unit_function_name
287  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
288  unit_function_signature
289  = fmt::format("void {}(union Entry* data, float* result)",
290  unit_function_name);
291  unit_function_call_signature
292  = fmt::format("{}(data, sum);\n", unit_function_name);
293  } else {
294  unit_function_name
295  = fmt::format("predict_margin_unit{}", unit_id);
296  unit_function_signature
297  = fmt::format("float {}(union Entry* data)", unit_function_name);
298  unit_function_call_signature
299  = fmt::format("sum += {}(data);\n", unit_function_name);
300  }
301  AppendToBuffer(dest, unit_function_call_signature, indent);
302  AppendToBuffer(new_file,
303  fmt::format("#include \"header.h\"\n"
304  "{} {{\n", unit_function_signature), 0);
305  CHECK_EQ(node->children.size(), 1);
306  WalkAST(node->children[0], new_file, 2);
307  if (num_output_group_ > 1) {
308  AppendToBuffer(new_file,
309  fmt::format(" for (int i = 0; i < {num_output_group}; ++i) {{\n"
310  " result[i] += sum[i];\n"
311  " }}\n"
312  "}}\n",
313  "num_output_group"_a = num_output_group_), 0);
314  } else {
315  AppendToBuffer(new_file, " return sum;\n}\n", 0);
316  }
317  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
318  }
319 
320  void HandleQNode(const QuantizerNode* node,
321  const std::string& dest,
322  size_t indent) {
323  /* render arrays needed to convert feature values into bin indices */
324  std::string array_threshold, array_th_begin, array_th_len;
325  // threshold[] : list of all thresholds that occur at least once in the
326  // ensemble model. For each feature, an ascending list of unique
327  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
328  // of the threshold[] array stores the threshold list for feature i.
329  size_t total_num_threshold;
330  // to hold total number of (distinct) thresholds
331  {
332  common::ArrayFormatter formatter(80, 2);
333  for (const auto& e : node->cut_pts) {
334  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
335  // cut_pts[i][k] stores the k-th threshold of feature i.
336  for (tl_float v : e) {
337  formatter << v;
338  }
339  }
340  array_threshold = formatter.str();
341  }
342  {
343  common::ArrayFormatter formatter(80, 2);
344  size_t accum = 0; // used to compute cumulative sum over threshold counts
345  for (const auto& e : node->cut_pts) {
346  formatter << accum;
347  accum += e.size(); // e.size() = number of thresholds for each feature
348  }
349  total_num_threshold = accum;
350  array_th_begin = formatter.str();
351  }
352  {
353  common::ArrayFormatter formatter(80, 2);
354  for (const auto& e : node->cut_pts) {
355  formatter << e.size();
356  }
357  array_th_len = formatter.str();
358  }
359  PrependToBuffer(dest,
360  fmt::format(native::qnode_template,
361  "array_threshold"_a = array_threshold,
362  "array_th_begin"_a = array_th_begin,
363  "array_th_len"_a = array_th_len,
364  "total_num_threshold"_a = total_num_threshold), 0);
365  AppendToBuffer(dest,
366  fmt::format(native::quantize_loop_template,
367  "num_feature"_a = num_feature_), indent);
368  CHECK_EQ(node->children.size(), 1);
369  WalkAST(node->children[0], dest, indent);
370  }
371 
372  void HandleCodeFolderNode(const CodeFolderNode* node,
373  const std::string& dest,
374  size_t indent) {
375  CHECK_EQ(node->children.size(), 1);
376  const int node_id = node->children[0]->node_id;
377  const int tree_id = node->children[0]->tree_id;
378 
379  /* render arrays needed for folding subtrees */
380  std::string array_nodes, array_cat_bitmap, array_cat_begin;
381  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
382  const std::string node_array_name
383  = fmt::format("node_tree{}_node{}", tree_id, node_id);
384  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
385  // make all categorical splits in a particular
386  // subtree
387  const std::string cat_bitmap_name
388  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
389  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
390  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
391  // belongs to the i-th (categorical) split
392  const std::string cat_begin_name
393  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
394 
395  std::string output_switch_statement;
396  Operator common_comp_op;
397  common_util::RenderCodeFolderArrays(node, param.quantize, false,
398  "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
399  [this](const OutputNode* node) { return RenderOutputStatement(node); },
400  &array_nodes, &array_cat_bitmap, &array_cat_begin,
401  &output_switch_statement, &common_comp_op);
402 
403  AppendToBuffer("header.h",
404  fmt::format(native::code_folder_arrays_declaration_template,
405  "node_array_name"_a = node_array_name,
406  "cat_bitmap_name"_a = cat_bitmap_name,
407  "cat_begin_name"_a = cat_begin_name), 0);
408  AppendToBuffer("arrays.c",
409  fmt::format(native::code_folder_arrays_template,
410  "node_array_name"_a = node_array_name,
411  "array_nodes"_a = array_nodes,
412  "cat_bitmap_name"_a = cat_bitmap_name,
413  "array_cat_bitmap"_a = array_cat_bitmap,
414  "cat_begin_name"_a = cat_begin_name,
415  "array_cat_begin"_a = array_cat_begin), 0);
416  if (array_nodes.empty()) {
417  /* folded code consists of a single leaf node */
418  AppendToBuffer(dest,
419  fmt::format("nid = -1;\n"
420  "{output_switch_statement}\n",
421  "output_switch_statement"_a
422  = output_switch_statement), indent);
423  } else {
424  AppendToBuffer(dest,
425  fmt::format(native::eval_loop_template,
426  "node_array_name"_a = node_array_name,
427  "cat_bitmap_name"_a = cat_bitmap_name,
428  "cat_begin_name"_a = cat_begin_name,
429  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
430  "comp_op"_a = OpName(common_comp_op),
431  "output_switch_statement"_a
432  = output_switch_statement), indent);
433  }
434  }
435 
436  inline std::string
437  ExtractNumericalCondition(const NumericalConditionNode* node) {
438  std::string result;
439  if (node->quantized) { // quantized threshold
440  result = fmt::format("data[{split_index}].qvalue {opname} {threshold}",
441  "split_index"_a = node->split_index,
442  "opname"_a = OpName(node->op),
443  "threshold"_a = node->threshold.int_val);
444  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
445  // According to IEEE 754, the result of comparison [lhs] < infinity
446  // must be identical for all finite [lhs]. Same goes for operator >.
447  result = (common::CompareWithOp(0.0, node->op, node->threshold.float_val)
448  ? "1" : "0");
449  } else { // finite threshold
450  result = fmt::format("data[{split_index}].fvalue {opname} {threshold}",
451  "split_index"_a = node->split_index,
452  "opname"_a = OpName(node->op),
453  "threshold"_a
454  = common::ToStringHighPrecision(node->threshold.float_val));
455  }
456  return result;
457  }
458 
459  inline std::string
460  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
461  std::string result;
462  std::vector<uint64_t> bitmap
463  = common_util::GetCategoricalBitmap(node->left_categories);
464  CHECK_GE(bitmap.size(), 1);
465  bool all_zeros = true;
466  for (uint64_t e : bitmap) {
467  all_zeros &= (e == 0);
468  }
469  if (all_zeros) {
470  result = "0";
471  } else {
472  std::ostringstream oss;
473  oss << "(tmp = (unsigned int)(data[" << node->split_index << "].fvalue) ), "
474  << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
475  << bitmap[0] << "U >> tmp) & 1) )";
476  for (size_t i = 1; i < bitmap.size(); ++i) {
477  oss << " || (tmp >= " << (i * 64)
478  << " && tmp < " << ((i + 1) * 64)
479  << " && (( (uint64_t)" << bitmap[i]
480  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
481  }
482  result = oss.str();
483  }
484  return result;
485  }
486 
487  inline std::string
488  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
489  common::ArrayFormatter formatter(80, 2);
490  for (int fid = 0; fid < num_feature_; ++fid) {
491  formatter << (is_categorical[fid] ? 1 : 0);
492  }
493  return formatter.str();
494  }
495 
496  inline std::string RenderOutputStatement(const OutputNode* node) {
497  std::string output_statement;
498  if (num_output_group_ > 1) {
499  if (node->is_vector) {
500  // multi-class classification with random forest
501  CHECK_EQ(node->vector.size(), static_cast<size_t>(num_output_group_))
502  << "Ill-formed model: leaf vector must be of length [num_output_group]";
503  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
504  output_statement
505  += fmt::format("sum[{group_id}] += (float){output};\n",
506  "group_id"_a = group_id,
507  "output"_a
508  = common::ToStringHighPrecision(node->vector[group_id]));
509  }
510  } else {
511  // multi-class classification with gradient boosted trees
512  output_statement
513  = fmt::format("sum[{group_id}] += (float){output};\n",
514  "group_id"_a = node->tree_id % num_output_group_,
515  "output"_a = common::ToStringHighPrecision(node->scalar));
516  }
517  } else {
518  output_statement
519  = fmt::format("sum += (float){output};\n",
520  "output"_a = common::ToStringHighPrecision(node->scalar));
521  }
522  return output_statement;
523  }
524 };
525 
527 .describe("AST-based compiler that produces C code")
528 .set_body([](const CompilerParam& param) -> Compiler* {
529  return new ASTNativeCompiler(param);
530  });
531 } // namespace compiler
532 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:425
branch annotator class
Definition: annotator.h:17
thin wrapper for tree ensemble model
Definition: tree.h:415
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
parameters for tree compiler
Definition: param.h:18
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:164
template for header
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:33
Interface of compiler that compiles a tree ensemble model.
template for main function
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: common.h:124
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:52
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
Definition: param.h:38
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:26
Utilities for code folding.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:71
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Definition: param.h:44
Function to generate bitmaps for categorical splits.
AST Builder class.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:41
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:34
double tl_float
float type to be used internally
Definition: base.h:17
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:28
int verbose
if >0, produce extra messages
Definition: param.h:36
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:422
Operator
comparison operators
Definition: base.h:23