treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
8 #include <treelite/common.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <algorithm>
12 #include <unordered_map>
13 #include <queue>
14 #include <cmath>
15 #include "./param.h"
16 #include "./pred_transform.h"
17 #include "./ast/builder.h"
18 #include "./native/main_template.h"
24 
25 #if defined(_MSC_VER) || defined(_WIN32)
26 #define DLLEXPORT_KEYWORD "__declspec(dllexport) "
27 #else
28 #define DLLEXPORT_KEYWORD ""
29 #endif
30 
31 using namespace fmt::literals;
32 
33 namespace treelite {
34 namespace compiler {
35 
36 DMLC_REGISTRY_FILE_TAG(ast_native);
37 
38 class ASTNativeCompiler : public Compiler {
39  public:
40  explicit ASTNativeCompiler(const CompilerParam& param)
41  : param(param) {
42  if (param.verbose > 0) {
43  LOG(INFO) << "Using ASTNativeCompiler";
44  }
45  }
46 
47  CompiledModel Compile(const Model& model) override {
48  CompiledModel cm;
49  cm.backend = "native";
50  cm.files["main.c"] = "";
51 
52  num_feature_ = model.num_feature;
53  num_output_group_ = model.num_output_group;
54  pred_tranform_func_ = PredTransformFunction("native", model);
55  files_.clear();
56 
57  ASTBuilder builder;
58  builder.BuildAST(model);
59  if (builder.FoldCode(param.code_folding_req)
60  || param.quantize > 0) {
61  // is_categorical[i] : is i-th feature categorical?
62  array_is_categorical_
63  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
64  }
65  if (param.annotate_in != "NULL") {
66  BranchAnnotator annotator;
67  std::unique_ptr<dmlc::Stream> fi(
68  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
69  annotator.Load(fi.get());
70  const auto annotation = annotator.Get();
71  builder.LoadDataCounts(annotation);
72  LOG(INFO) << "Loading node frequencies from `"
73  << param.annotate_in << "'";
74  }
75  builder.Split(param.parallel_comp);
76  if (param.quantize > 0) {
77  builder.QuantizeThresholds();
78  }
79  WalkAST(builder.GetRootNode(), "main.c", 0);
80  if (files_.count("arrays.c") > 0) {
81  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
82  }
83 
84  {
85  /* write recipe.json */
86  std::vector<std::unordered_map<std::string, std::string>> source_list;
87  for (auto kv : files_) {
88  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
89  const size_t line_count
90  = std::count(kv.second.begin(), kv.second.end(), '\n');
91  source_list.push_back({ {"name",
92  kv.first.substr(0, kv.first.length() - 2)},
93  {"length", std::to_string(line_count)} });
94  }
95  }
96  std::ostringstream oss;
97  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
98  writer->BeginObject();
99  writer->WriteObjectKeyValue("target", param.native_lib_name);
100  writer->WriteObjectKeyValue("sources", source_list);
101  writer->EndObject();
102  files_["recipe.json"] = oss.str();
103  }
104  cm.files = std::move(files_);
105  return cm;
106  }
107 
108  private:
109  CompilerParam param;
110  int num_feature_;
111  int num_output_group_;
112  std::string pred_tranform_func_;
113  std::string array_is_categorical_;
114  std::unordered_map<std::string, std::string> files_;
115 
116  void WalkAST(const ASTNode* node,
117  const std::string& dest,
118  size_t indent) {
119  const MainNode* t1;
120  const AccumulatorContextNode* t2;
121  const ConditionNode* t3;
122  const OutputNode* t4;
123  const TranslationUnitNode* t5;
124  const QuantizerNode* t6;
125  const CodeFolderNode* t7;
126  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
127  HandleMainNode(t1, dest, indent);
128  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
129  HandleACNode(t2, dest, indent);
130  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
131  HandleCondNode(t3, dest, indent);
132  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
133  HandleOutputNode(t4, dest, indent);
134  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
135  HandleTUNode(t5, dest, indent);
136  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
137  HandleQNode(t6, dest, indent);
138  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
139  HandleCodeFolderNode(t7, dest, indent);
140  } else {
141  LOG(FATAL) << "Unrecognized AST node type";
142  }
143  }
144 
145  // append content to a given buffer, with given level of indentation
146  inline void AppendToBuffer(const std::string& dest,
147  const std::string& content,
148  size_t indent) {
149  files_[dest] += common::IndentMultiLineString(content, indent);
150  }
151 
152  // prepend content to a given buffer, with given level of indentation
153  inline void PrependToBuffer(const std::string& dest,
154  const std::string& content,
155  size_t indent) {
156  files_[dest]
157  = common::IndentMultiLineString(content, indent) + files_[dest];
158  }
159 
160  void HandleMainNode(const MainNode* node,
161  const std::string& dest,
162  size_t indent) {
163  const char* get_num_output_group_function_signature
164  = "size_t get_num_output_group(void)";
165  const char* get_num_feature_function_signature
166  = "size_t get_num_feature(void)";
167  const char* predict_function_signature
168  = (num_output_group_ > 1) ?
169  "size_t predict_multiclass(union Entry* data, int pred_margin, "
170  "float* result)"
171  : "float predict(union Entry* data, int pred_margin)";
172 
173  if (!array_is_categorical_.empty()) {
174  array_is_categorical_
175  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
176  array_is_categorical_);
177  }
178 
179  AppendToBuffer(dest,
180  fmt::format(native::main_start_template,
181  "array_is_categorical"_a = array_is_categorical_,
182  "get_num_output_group_function_signature"_a
183  = get_num_output_group_function_signature,
184  "get_num_feature_function_signature"_a
185  = get_num_feature_function_signature,
186  "pred_transform_function"_a = pred_tranform_func_,
187  "predict_function_signature"_a = predict_function_signature,
188  "num_output_group"_a = num_output_group_,
189  "num_feature"_a = node->num_feature),
190  indent);
191  AppendToBuffer("header.h",
192  fmt::format(native::header_template,
193  "dllexport"_a = DLLEXPORT_KEYWORD,
194  "get_num_output_group_function_signature"_a
195  = get_num_output_group_function_signature,
196  "get_num_feature_function_signature"_a
197  = get_num_feature_function_signature,
198  "predict_function_signature"_a = predict_function_signature,
199  "threshold_type"_a = (param.quantize > 0 ? "int" : "double")),
200  indent);
201 
202  CHECK_EQ(node->children.size(), 1);
203  WalkAST(node->children[0], dest, indent + 2);
204 
205  const std::string optional_average_field
206  = (node->average_result) ? fmt::format(" / {}", node->num_tree)
207  : std::string("");
208  if (num_output_group_ > 1) {
209  AppendToBuffer(dest,
210  fmt::format(native::main_end_multiclass_template,
211  "num_output_group"_a = num_output_group_,
212  "optional_average_field"_a = optional_average_field,
213  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
214  indent);
215  } else {
216  AppendToBuffer(dest,
217  fmt::format(native::main_end_template,
218  "optional_average_field"_a = optional_average_field,
219  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
220  indent);
221  }
222  }
223 
224  void HandleACNode(const AccumulatorContextNode* node,
225  const std::string& dest,
226  size_t indent) {
227  if (num_output_group_ > 1) {
228  AppendToBuffer(dest,
229  fmt::format("float sum[{num_output_group}] = {{0.0f}};\n"
230  "unsigned int tmp;\n"
231  "int nid, cond, fid; /* used for folded subtrees */\n",
232  "num_output_group"_a = num_output_group_), indent);
233  } else {
234  AppendToBuffer(dest,
235  "float sum = 0.0f;\n"
236  "unsigned int tmp;\n"
237  "int nid, cond, fid; /* used for folded subtrees */\n", indent);
238  }
239  for (ASTNode* child : node->children) {
240  WalkAST(child, dest, indent);
241  }
242  }
243 
244  void HandleCondNode(const ConditionNode* node,
245  const std::string& dest,
246  size_t indent) {
247  const NumericalConditionNode* t;
248  std::string condition, condition_with_na_check;
249  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
250  /* numerical split */
251  condition = ExtractNumericalCondition(t);
252  const char* condition_with_na_check_template
253  = (node->default_left) ?
254  "!(data[{split_index}].missing != -1) || ({condition})"
255  : " (data[{split_index}].missing != -1) && ({condition})";
256  condition_with_na_check
257  = fmt::format(condition_with_na_check_template,
258  "split_index"_a = node->split_index,
259  "condition"_a = condition);
260  } else { /* categorical split */
261  const CategoricalConditionNode* t2
262  = dynamic_cast<const CategoricalConditionNode*>(node);
263  CHECK(t2);
264  condition_with_na_check = ExtractCategoricalCondition(t2);
265  }
266  if (node->children[0]->data_count && node->children[1]->data_count) {
267  const int left_freq = node->children[0]->data_count.value();
268  const int right_freq = node->children[1]->data_count.value();
269  condition_with_na_check
270  = fmt::format(" {keyword}( {condition} ) ",
271  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
272  "condition"_a = condition_with_na_check);
273  }
274  AppendToBuffer(dest,
275  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
276  CHECK_EQ(node->children.size(), 2);
277  WalkAST(node->children[0], dest, indent + 2);
278  AppendToBuffer(dest, "} else {\n", indent);
279  WalkAST(node->children[1], dest, indent + 2);
280  AppendToBuffer(dest, "}\n", indent);
281  }
282 
283  void HandleOutputNode(const OutputNode* node,
284  const std::string& dest,
285  size_t indent) {
286  AppendToBuffer(dest, RenderOutputStatement(node), indent);
287  CHECK_EQ(node->children.size(), 0);
288  }
289 
290  void HandleTUNode(const TranslationUnitNode* node,
291  const std::string& dest,
292  int indent) {
293  const int unit_id = node->unit_id;
294  const std::string new_file = fmt::format("tu{}.c", unit_id);
295 
296  std::string unit_function_name, unit_function_signature,
297  unit_function_call_signature;
298  if (num_output_group_ > 1) {
299  unit_function_name
300  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
301  unit_function_signature
302  = fmt::format("void {}(union Entry* data, float* result)",
303  unit_function_name);
304  unit_function_call_signature
305  = fmt::format("{}(data, sum);\n", unit_function_name);
306  } else {
307  unit_function_name
308  = fmt::format("predict_margin_unit{}", unit_id);
309  unit_function_signature
310  = fmt::format("float {}(union Entry* data)", unit_function_name);
311  unit_function_call_signature
312  = fmt::format("sum += {}(data);\n", unit_function_name);
313  }
314  AppendToBuffer(dest, unit_function_call_signature, indent);
315  AppendToBuffer(new_file,
316  fmt::format("#include \"header.h\"\n"
317  "{} {{\n", unit_function_signature), 0);
318  CHECK_EQ(node->children.size(), 1);
319  WalkAST(node->children[0], new_file, 2);
320  if (num_output_group_ > 1) {
321  AppendToBuffer(new_file,
322  fmt::format(" for (int i = 0; i < {num_output_group}; ++i) {{\n"
323  " result[i] += sum[i];\n"
324  " }}\n"
325  "}}\n",
326  "num_output_group"_a = num_output_group_), 0);
327  } else {
328  AppendToBuffer(new_file, " return sum;\n}\n", 0);
329  }
330  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
331  }
332 
333  void HandleQNode(const QuantizerNode* node,
334  const std::string& dest,
335  size_t indent) {
336  /* render arrays needed to convert feature values into bin indices */
337  std::string array_threshold, array_th_begin, array_th_len;
338  // threshold[] : list of all thresholds that occur at least once in the
339  // ensemble model. For each feature, an ascending list of unique
340  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
341  // of the threshold[] array stores the threshold list for feature i.
342  size_t total_num_threshold;
343  // to hold total number of (distinct) thresholds
344  {
345  common::ArrayFormatter formatter(80, 2);
346  for (const auto& e : node->cut_pts) {
347  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
348  // cut_pts[i][k] stores the k-th threshold of feature i.
349  for (tl_float v : e) {
350  formatter << v;
351  }
352  }
353  array_threshold = formatter.str();
354  }
355  {
356  common::ArrayFormatter formatter(80, 2);
357  size_t accum = 0; // used to compute cumulative sum over threshold counts
358  for (const auto& e : node->cut_pts) {
359  formatter << accum;
360  accum += e.size(); // e.size() = number of thresholds for each feature
361  }
362  total_num_threshold = accum;
363  array_th_begin = formatter.str();
364  }
365  {
366  common::ArrayFormatter formatter(80, 2);
367  for (const auto& e : node->cut_pts) {
368  formatter << e.size();
369  }
370  array_th_len = formatter.str();
371  }
372  PrependToBuffer(dest,
373  fmt::format(native::qnode_template,
374  "array_threshold"_a = array_threshold,
375  "array_th_begin"_a = array_th_begin,
376  "array_th_len"_a = array_th_len,
377  "total_num_threshold"_a = total_num_threshold), 0);
378  AppendToBuffer(dest,
379  fmt::format(native::quantize_loop_template,
380  "num_feature"_a = num_feature_), indent);
381  CHECK_EQ(node->children.size(), 1);
382  WalkAST(node->children[0], dest, indent);
383  }
384 
385  void HandleCodeFolderNode(const CodeFolderNode* node,
386  const std::string& dest,
387  size_t indent) {
388  CHECK_EQ(node->children.size(), 1);
389  const int node_id = node->children[0]->node_id;
390  const int tree_id = node->children[0]->tree_id;
391 
392  /* render arrays needed for folding subtrees */
393  std::string array_nodes, array_cat_bitmap, array_cat_begin;
394  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
395  const std::string node_array_name
396  = fmt::format("node_tree{}_node{}", tree_id, node_id);
397  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
398  // make all categorical splits in a particular
399  // subtree
400  const std::string cat_bitmap_name
401  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
402  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
403  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
404  // belongs to the i-th (categorical) split
405  const std::string cat_begin_name
406  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
407 
408  std::string output_switch_statement;
409  Operator common_comp_op;
410  common_util::RenderCodeFolderArrays(node, param.quantize, false,
411  "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
412  [this](const OutputNode* node) { return RenderOutputStatement(node); },
413  &array_nodes, &array_cat_bitmap, &array_cat_begin,
414  &output_switch_statement, &common_comp_op);
415 
416  AppendToBuffer("header.h",
417  fmt::format(native::code_folder_arrays_declaration_template,
418  "node_array_name"_a = node_array_name,
419  "cat_bitmap_name"_a = cat_bitmap_name,
420  "cat_begin_name"_a = cat_begin_name), 0);
421  AppendToBuffer("arrays.c",
422  fmt::format(native::code_folder_arrays_template,
423  "node_array_name"_a = node_array_name,
424  "array_nodes"_a = array_nodes,
425  "cat_bitmap_name"_a = cat_bitmap_name,
426  "array_cat_bitmap"_a = array_cat_bitmap,
427  "cat_begin_name"_a = cat_begin_name,
428  "array_cat_begin"_a = array_cat_begin), 0);
429  if (array_nodes.empty()) {
430  /* folded code consists of a single leaf node */
431  AppendToBuffer(dest,
432  fmt::format("nid = -1;\n"
433  "{output_switch_statement}\n",
434  "output_switch_statement"_a
435  = output_switch_statement), indent);
436  } else {
437  AppendToBuffer(dest,
438  fmt::format(native::eval_loop_template,
439  "node_array_name"_a = node_array_name,
440  "cat_bitmap_name"_a = cat_bitmap_name,
441  "cat_begin_name"_a = cat_begin_name,
442  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
443  "comp_op"_a = OpName(common_comp_op),
444  "output_switch_statement"_a
445  = output_switch_statement), indent);
446  }
447  }
448 
449  inline std::string
450  ExtractNumericalCondition(const NumericalConditionNode* node) {
451  std::string result;
452  if (node->quantized) { // quantized threshold
453  result = fmt::format("data[{split_index}].qvalue {opname} {threshold}",
454  "split_index"_a = node->split_index,
455  "opname"_a = OpName(node->op),
456  "threshold"_a = node->threshold.int_val);
457  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
458  // According to IEEE 754, the result of comparison [lhs] < infinity
459  // must be identical for all finite [lhs]. Same goes for operator >.
460  result = (common::CompareWithOp(0.0, node->op, node->threshold.float_val)
461  ? "1" : "0");
462  } else { // finite threshold
463  result = fmt::format("data[{split_index}].fvalue {opname} {threshold}",
464  "split_index"_a = node->split_index,
465  "opname"_a = OpName(node->op),
466  "threshold"_a
467  = common::ToStringHighPrecision(node->threshold.float_val));
468  }
469  return result;
470  }
471 
472  inline std::string
473  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
474  std::string result;
475  std::vector<uint64_t> bitmap
476  = common_util::GetCategoricalBitmap(node->left_categories);
477  CHECK_GE(bitmap.size(), 1);
478  bool all_zeros = true;
479  for (uint64_t e : bitmap) {
480  all_zeros &= (e == 0);
481  }
482  if (all_zeros) {
483  result = "0";
484  } else {
485  std::ostringstream oss;
486  if (node->convert_missing_to_zero) {
487  // All missing values are converted into zeros
488  oss << fmt::format(
489  "((tmp = (data[{0}].missing == -1 ? 0U "
490  ": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
491  } else {
492  if (node->default_left) {
493  oss << fmt::format(
494  "data[{0}].missing == -1 || ("
495  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
496  } else {
497  oss << fmt::format(
498  "data[{0}].missing != -1 && ("
499  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
500  }
501  }
502  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
503  << bitmap[0] << "U >> tmp) & 1) )";
504  for (size_t i = 1; i < bitmap.size(); ++i) {
505  oss << " || (tmp >= " << (i * 64)
506  << " && tmp < " << ((i + 1) * 64)
507  << " && (( (uint64_t)" << bitmap[i]
508  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
509  }
510  oss << ")";
511  result = oss.str();
512  }
513  return result;
514  }
515 
516  inline std::string
517  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
518  common::ArrayFormatter formatter(80, 2);
519  for (int fid = 0; fid < num_feature_; ++fid) {
520  formatter << (is_categorical[fid] ? 1 : 0);
521  }
522  return formatter.str();
523  }
524 
525  inline std::string RenderOutputStatement(const OutputNode* node) {
526  std::string output_statement;
527  if (num_output_group_ > 1) {
528  if (node->is_vector) {
529  // multi-class classification with random forest
530  CHECK_EQ(node->vector.size(), static_cast<size_t>(num_output_group_))
531  << "Ill-formed model: leaf vector must be of length [num_output_group]";
532  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
533  output_statement
534  += fmt::format("sum[{group_id}] += (float){output};\n",
535  "group_id"_a = group_id,
536  "output"_a
537  = common::ToStringHighPrecision(node->vector[group_id]));
538  }
539  } else {
540  // multi-class classification with gradient boosted trees
541  output_statement
542  = fmt::format("sum[{group_id}] += (float){output};\n",
543  "group_id"_a = node->tree_id % num_output_group_,
544  "output"_a = common::ToStringHighPrecision(node->scalar));
545  }
546  } else {
547  output_statement
548  = fmt::format("sum += (float){output};\n",
549  "output"_a = common::ToStringHighPrecision(node->scalar));
550  }
551  return output_statement;
552  }
553 };
554 
556 .describe("AST-based compiler that produces C code")
557 .set_body([](const CompilerParam& param) -> Compiler* {
558  return new ASTNativeCompiler(param);
559  });
560 } // namespace compiler
561 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:437
branch annotator class
Definition: annotator.h:17
thin wrapper for tree ensemble model
Definition: tree.h:427
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
parameters for tree compiler
Definition: param.h:18
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:52
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:164
template for header
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:33
Interface of compiler that compiles a tree ensemble model.
template for main function
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: common.h:124
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
Definition: param.h:38
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:26
Utilities for code folding.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:71
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Definition: param.h:44
Function to generate bitmaps for categorical splits.
AST Builder class.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:47
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:34
double tl_float
float type to be used internally
Definition: base.h:17
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:28
int verbose
if >0, produce extra messages
Definition: param.h:36
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:434
Operator
comparison operators
Definition: base.h:23