Treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <algorithm>
12 #include <fstream>
13 #include <unordered_map>
14 #include <queue>
15 #include <cmath>
16 #include "./pred_transform.h"
17 #include "./ast/builder.h"
18 #include "./native/main_template.h"
22 #include "./common/format_util.h"
25 
26 #if defined(_MSC_VER) || defined(_WIN32)
27 #define DLLEXPORT_KEYWORD "__declspec(dllexport) "
28 #else
29 #define DLLEXPORT_KEYWORD ""
30 #endif
31 
32 using namespace fmt::literals;
33 
34 namespace treelite {
35 namespace compiler {
36 
37 DMLC_REGISTRY_FILE_TAG(ast_native);
38 
39 class ASTNativeCompiler : public Compiler {
40  public:
41  explicit ASTNativeCompiler(const CompilerParam& param)
42  : param(param) {
43  if (param.verbose > 0) {
44  LOG(INFO) << "Using ASTNativeCompiler";
45  }
46  if (param.dump_array_as_elf > 0) {
47  LOG(INFO) << "Warning: 'dump_array_as_elf' parameter is not applicable "
48  "for ASTNativeCompiler";
49  }
50  }
51 
52  CompiledModel Compile(const Model& model) override {
53  CompiledModel cm;
54  cm.backend = "native";
55 
56  num_feature_ = model.num_feature;
57  num_output_group_ = model.num_output_group;
58  pred_transform_ = model.param.pred_transform;
59  sigmoid_alpha_ = model.param.sigmoid_alpha;
60  global_bias_ = model.param.global_bias;
61  pred_tranform_func_ = PredTransformFunction("native", model);
62  files_.clear();
63 
64  ASTBuilder builder;
65  builder.BuildAST(model);
66  if (builder.FoldCode(param.code_folding_req)
67  || param.quantize > 0) {
68  // is_categorical[i] : is i-th feature categorical?
69  array_is_categorical_
70  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
71  }
72  if (param.annotate_in != "NULL") {
73  BranchAnnotator annotator;
74  std::unique_ptr<dmlc::Stream> fi(
75  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
76  annotator.Load(fi.get());
77  const auto annotation = annotator.Get();
78  builder.LoadDataCounts(annotation);
79  LOG(INFO) << "Loading node frequencies from `"
80  << param.annotate_in << "'";
81  }
82  builder.Split(param.parallel_comp);
83  if (param.quantize > 0) {
84  builder.QuantizeThresholds();
85  }
86 
87  {
88  const char* destfile = getenv("TREELITE_DUMP_AST");
89  if (destfile) {
90  std::ofstream os(destfile);
91  os << builder.GetDump() << std::endl;
92  }
93  }
94 
95  WalkAST(builder.GetRootNode(), "main.c", 0);
96  if (files_.count("arrays.c") > 0) {
97  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
98  }
99 
100  {
101  /* write recipe.json */
102  std::vector<std::unordered_map<std::string, std::string>> source_list;
103  for (const auto& kv : files_) {
104  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
105  const size_t line_count
106  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
107  source_list.push_back({ {"name",
108  kv.first.substr(0, kv.first.length() - 2)},
109  {"length", std::to_string(line_count)} });
110  }
111  }
112  std::ostringstream oss;
113  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&oss));
114  writer->BeginObject();
115  writer->WriteObjectKeyValue("target", param.native_lib_name);
116  writer->WriteObjectKeyValue("sources", source_list);
117  writer->EndObject();
118  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
119  }
120  cm.files = std::move(files_);
121  return cm;
122  }
123 
124  private:
125  CompilerParam param;
126  int num_feature_;
127  int num_output_group_;
128  std::string pred_transform_;
129  float sigmoid_alpha_;
130  float global_bias_;
131  std::string pred_tranform_func_;
132  std::string array_is_categorical_;
133  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
134 
135  void WalkAST(const ASTNode* node,
136  const std::string& dest,
137  size_t indent) {
138  const MainNode* t1;
139  const AccumulatorContextNode* t2;
140  const ConditionNode* t3;
141  const OutputNode* t4;
142  const TranslationUnitNode* t5;
143  const QuantizerNode* t6;
144  const CodeFolderNode* t7;
145  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
146  HandleMainNode(t1, dest, indent);
147  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
148  HandleACNode(t2, dest, indent);
149  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
150  HandleCondNode(t3, dest, indent);
151  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
152  HandleOutputNode(t4, dest, indent);
153  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
154  HandleTUNode(t5, dest, indent);
155  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
156  HandleQNode(t6, dest, indent);
157  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
158  HandleCodeFolderNode(t7, dest, indent);
159  } else {
160  LOG(FATAL) << "Unrecognized AST node type";
161  }
162  }
163 
164  // append content to a given buffer, with given level of indentation
165  inline void AppendToBuffer(const std::string& dest,
166  const std::string& content,
167  size_t indent) {
168  files_[dest].content += common_util::IndentMultiLineString(content, indent);
169  }
170 
171  // prepend content to a given buffer, with given level of indentation
172  inline void PrependToBuffer(const std::string& dest,
173  const std::string& content,
174  size_t indent) {
175  files_[dest].content
176  = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
177  }
178 
179  void HandleMainNode(const MainNode* node,
180  const std::string& dest,
181  size_t indent) {
182  const char* get_num_output_group_function_signature
183  = "size_t get_num_output_group(void)";
184  const char* get_num_feature_function_signature
185  = "size_t get_num_feature(void)";
186  const char* get_pred_transform_function_signature
187  = "const char* get_pred_transform(void)";
188  const char* get_sigmoid_alpha_function_signature
189  = "float get_sigmoid_alpha(void)";
190  const char* get_global_bias_function_signature
191  = "float get_global_bias(void)";
192  const char* predict_function_signature
193  = (num_output_group_ > 1) ?
194  "size_t predict_multiclass(union Entry* data, int pred_margin, "
195  "float* result)"
196  : "float predict(union Entry* data, int pred_margin)";
197 
198  if (!array_is_categorical_.empty()) {
199  array_is_categorical_
200  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
201  array_is_categorical_);
202  }
203 
204  AppendToBuffer(dest,
205  fmt::format(native::main_start_template,
206  "array_is_categorical"_a = array_is_categorical_,
207  "get_num_output_group_function_signature"_a
208  = get_num_output_group_function_signature,
209  "get_num_feature_function_signature"_a
210  = get_num_feature_function_signature,
211  "get_pred_transform_function_signature"_a
212  = get_pred_transform_function_signature,
213  "get_sigmoid_alpha_function_signature"_a
214  = get_sigmoid_alpha_function_signature,
215  "get_global_bias_function_signature"_a
216  = get_global_bias_function_signature,
217  "pred_transform_function"_a = pred_tranform_func_,
218  "predict_function_signature"_a = predict_function_signature,
219  "num_output_group"_a = num_output_group_,
220  "num_feature"_a = num_feature_,
221  "pred_transform"_a = pred_transform_,
222  "sigmoid_alpha"_a = sigmoid_alpha_,
223  "global_bias"_a = global_bias_),
224  indent);
225  AppendToBuffer("header.h",
226  fmt::format(native::header_template,
227  "dllexport"_a = DLLEXPORT_KEYWORD,
228  "get_num_output_group_function_signature"_a
229  = get_num_output_group_function_signature,
230  "get_num_feature_function_signature"_a
231  = get_num_feature_function_signature,
232  "get_pred_transform_function_signature"_a
233  = get_pred_transform_function_signature,
234  "get_sigmoid_alpha_function_signature"_a
235  = get_sigmoid_alpha_function_signature,
236  "get_global_bias_function_signature"_a
237  = get_global_bias_function_signature,
238  "predict_function_signature"_a = predict_function_signature,
239  "threshold_type"_a = (param.quantize > 0 ? "int" : "float")),
240  indent);
241 
242  CHECK_EQ(node->children.size(), 1);
243  WalkAST(node->children[0], dest, indent + 2);
244 
245  const std::string optional_average_field
246  = (node->average_result) ? fmt::format(" / {}", node->num_tree)
247  : std::string("");
248  if (num_output_group_ > 1) {
249  AppendToBuffer(dest,
250  fmt::format(native::main_end_multiclass_template,
251  "num_output_group"_a = num_output_group_,
252  "optional_average_field"_a = optional_average_field,
253  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)),
254  indent);
255  } else {
256  AppendToBuffer(dest,
257  fmt::format(native::main_end_template,
258  "optional_average_field"_a = optional_average_field,
259  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)),
260  indent);
261  }
262  }
263 
264  void HandleACNode(const AccumulatorContextNode* node,
265  const std::string& dest,
266  size_t indent) {
267  if (num_output_group_ > 1) {
268  AppendToBuffer(dest,
269  fmt::format("float sum[{num_output_group}] = {{0.0f}};\n"
270  "unsigned int tmp;\n"
271  "int nid, cond, fid; /* used for folded subtrees */\n",
272  "num_output_group"_a = num_output_group_), indent);
273  } else {
274  AppendToBuffer(dest,
275  "float sum = 0.0f;\n"
276  "unsigned int tmp;\n"
277  "int nid, cond, fid; /* used for folded subtrees */\n", indent);
278  }
279  for (ASTNode* child : node->children) {
280  WalkAST(child, dest, indent);
281  }
282  }
283 
284  void HandleCondNode(const ConditionNode* node,
285  const std::string& dest,
286  size_t indent) {
287  const NumericalConditionNode* t;
288  std::string condition, condition_with_na_check;
289  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
290  /* numerical split */
291  condition = ExtractNumericalCondition(t);
292  const char* condition_with_na_check_template
293  = (node->default_left) ?
294  "!(data[{split_index}].missing != -1) || ({condition})"
295  : " (data[{split_index}].missing != -1) && ({condition})";
296  condition_with_na_check
297  = fmt::format(condition_with_na_check_template,
298  "split_index"_a = node->split_index,
299  "condition"_a = condition);
300  } else { /* categorical split */
301  const CategoricalConditionNode* t2
302  = dynamic_cast<const CategoricalConditionNode*>(node);
303  CHECK(t2);
304  condition_with_na_check = ExtractCategoricalCondition(t2);
305  }
306  if (node->children[0]->data_count && node->children[1]->data_count) {
307  const int left_freq = node->children[0]->data_count.value();
308  const int right_freq = node->children[1]->data_count.value();
309  condition_with_na_check
310  = fmt::format(" {keyword}( {condition} ) ",
311  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
312  "condition"_a = condition_with_na_check);
313  }
314  AppendToBuffer(dest,
315  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
316  CHECK_EQ(node->children.size(), 2);
317  WalkAST(node->children[0], dest, indent + 2);
318  AppendToBuffer(dest, "} else {\n", indent);
319  WalkAST(node->children[1], dest, indent + 2);
320  AppendToBuffer(dest, "}\n", indent);
321  }
322 
323  void HandleOutputNode(const OutputNode* node,
324  const std::string& dest,
325  size_t indent) {
326  AppendToBuffer(dest, RenderOutputStatement(node), indent);
327  CHECK_EQ(node->children.size(), 0);
328  }
329 
330  void HandleTUNode(const TranslationUnitNode* node,
331  const std::string& dest,
332  int indent) {
333  const int unit_id = node->unit_id;
334  const std::string new_file = fmt::format("tu{}.c", unit_id);
335 
336  std::string unit_function_name, unit_function_signature,
337  unit_function_call_signature;
338  if (num_output_group_ > 1) {
339  unit_function_name
340  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
341  unit_function_signature
342  = fmt::format("void {}(union Entry* data, float* result)",
343  unit_function_name);
344  unit_function_call_signature
345  = fmt::format("{}(data, sum);\n", unit_function_name);
346  } else {
347  unit_function_name
348  = fmt::format("predict_margin_unit{}", unit_id);
349  unit_function_signature
350  = fmt::format("float {}(union Entry* data)", unit_function_name);
351  unit_function_call_signature
352  = fmt::format("sum += {}(data);\n", unit_function_name);
353  }
354  AppendToBuffer(dest, unit_function_call_signature, indent);
355  AppendToBuffer(new_file,
356  fmt::format("#include \"header.h\"\n"
357  "{} {{\n", unit_function_signature), 0);
358  CHECK_EQ(node->children.size(), 1);
359  WalkAST(node->children[0], new_file, 2);
360  if (num_output_group_ > 1) {
361  AppendToBuffer(new_file,
362  fmt::format(" for (int i = 0; i < {num_output_group}; ++i) {{\n"
363  " result[i] += sum[i];\n"
364  " }}\n"
365  "}}\n",
366  "num_output_group"_a = num_output_group_), 0);
367  } else {
368  AppendToBuffer(new_file, " return sum;\n}\n", 0);
369  }
370  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
371  }
372 
373  void HandleQNode(const QuantizerNode* node,
374  const std::string& dest,
375  size_t indent) {
376  /* render arrays needed to convert feature values into bin indices */
377  std::string array_threshold, array_th_begin, array_th_len;
378  // threshold[] : list of all thresholds that occur at least once in the
379  // ensemble model. For each feature, an ascending list of unique
380  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
381  // of the threshold[] array stores the threshold list for feature i.
382  size_t total_num_threshold;
383  // to hold total number of (distinct) thresholds
384  {
385  common_util::ArrayFormatter formatter(80, 2);
386  for (const auto& e : node->cut_pts) {
387  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
388  // cut_pts[i][k] stores the k-th threshold of feature i.
389  for (tl_float v : e) {
390  formatter << v;
391  }
392  }
393  array_threshold = formatter.str();
394  }
395  {
396  common_util::ArrayFormatter formatter(80, 2);
397  size_t accum = 0; // used to compute cumulative sum over threshold counts
398  for (const auto& e : node->cut_pts) {
399  formatter << accum;
400  accum += e.size(); // e.size() = number of thresholds for each feature
401  }
402  total_num_threshold = accum;
403  array_th_begin = formatter.str();
404  }
405  {
406  common_util::ArrayFormatter formatter(80, 2);
407  for (const auto& e : node->cut_pts) {
408  formatter << e.size();
409  }
410  array_th_len = formatter.str();
411  }
412  if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
413  PrependToBuffer(dest,
414  fmt::format(native::qnode_template,
415  "total_num_threshold"_a = total_num_threshold), 0);
416  AppendToBuffer(dest,
417  fmt::format(native::quantize_loop_template,
418  "num_feature"_a = num_feature_), indent);
419  }
420  if (!array_threshold.empty()) {
421  PrependToBuffer(dest,
422  fmt::format("static const float threshold[] = {{\n"
423  "{array_threshold}\n"
424  "}};\n", "array_threshold"_a = array_threshold), 0);
425  }
426  if (!array_th_begin.empty()) {
427  PrependToBuffer(dest,
428  fmt::format("static const int th_begin[] = {{\n"
429  "{array_th_begin}\n"
430  "}};\n", "array_th_begin"_a = array_th_begin), 0);
431  }
432  if (!array_th_len.empty()) {
433  PrependToBuffer(dest,
434  fmt::format("static const int th_len[] = {{\n"
435  "{array_th_len}\n"
436  "}};\n", "array_th_len"_a = array_th_len), 0);
437  }
438  CHECK_EQ(node->children.size(), 1);
439  WalkAST(node->children[0], dest, indent);
440  }
441 
442  void HandleCodeFolderNode(const CodeFolderNode* node,
443  const std::string& dest,
444  size_t indent) {
445  CHECK_EQ(node->children.size(), 1);
446  const int node_id = node->children[0]->node_id;
447  const int tree_id = node->children[0]->tree_id;
448 
449  /* render arrays needed for folding subtrees */
450  std::string array_nodes, array_cat_bitmap, array_cat_begin;
451  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
452  const std::string node_array_name
453  = fmt::format("node_tree{}_node{}", tree_id, node_id);
454  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
455  // make all categorical splits in a particular
456  // subtree
457  const std::string cat_bitmap_name
458  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
459  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
460  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
461  // belongs to the i-th (categorical) split
462  const std::string cat_begin_name
463  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
464 
465  std::string output_switch_statement;
466  Operator common_comp_op;
467  common_util::RenderCodeFolderArrays(node, param.quantize, false,
468  "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
469  [this](const OutputNode* node) { return RenderOutputStatement(node); },
470  &array_nodes, &array_cat_bitmap, &array_cat_begin,
471  &output_switch_statement, &common_comp_op);
472  if (!array_nodes.empty()) {
473  AppendToBuffer("header.h",
474  fmt::format("extern const struct Node {node_array_name}[];\n",
475  "node_array_name"_a = node_array_name), 0);
476  AppendToBuffer("arrays.c",
477  fmt::format("const struct Node {node_array_name}[] = {{\n"
478  "{array_nodes}\n"
479  "}};\n",
480  "node_array_name"_a = node_array_name,
481  "array_nodes"_a = array_nodes), 0);
482  }
483 
484  if (!array_cat_bitmap.empty()) {
485  AppendToBuffer("header.h",
486  fmt::format("extern const uint64_t {cat_bitmap_name}[];\n",
487  "cat_bitmap_name"_a = cat_bitmap_name), 0);
488  AppendToBuffer("arrays.c",
489  fmt::format("const uint64_t {cat_bitmap_name}[] = {{\n"
490  "{array_cat_bitmap}\n"
491  "}};\n",
492  "cat_bitmap_name"_a = cat_bitmap_name,
493  "array_cat_bitmap"_a = array_cat_bitmap), 0);
494  }
495 
496  if (!array_cat_begin.empty()) {
497  AppendToBuffer("header.h",
498  fmt::format("extern const size_t {cat_begin_name}[];\n",
499  "cat_begin_name"_a = cat_begin_name), 0);
500  AppendToBuffer("arrays.c",
501  fmt::format("const size_t {cat_begin_name}[] = {{\n"
502  "{array_cat_begin}\n"
503  "}};\n",
504  "cat_begin_name"_a = cat_begin_name,
505  "array_cat_begin"_a = array_cat_begin), 0);
506  }
507 
508  if (array_nodes.empty()) {
509  /* folded code consists of a single leaf node */
510  AppendToBuffer(dest,
511  fmt::format("nid = -1;\n"
512  "{output_switch_statement}\n",
513  "output_switch_statement"_a
514  = output_switch_statement), indent);
515  } else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
516  AppendToBuffer(dest,
517  fmt::format(native::eval_loop_template,
518  "node_array_name"_a = node_array_name,
519  "cat_bitmap_name"_a = cat_bitmap_name,
520  "cat_begin_name"_a = cat_begin_name,
521  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
522  "comp_op"_a = OpName(common_comp_op),
523  "output_switch_statement"_a
524  = output_switch_statement), indent);
525  } else {
526  AppendToBuffer(dest,
527  fmt::format(native::eval_loop_template_without_categorical_feature,
528  "node_array_name"_a = node_array_name,
529  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
530  "comp_op"_a = OpName(common_comp_op),
531  "output_switch_statement"_a
532  = output_switch_statement), indent);
533  }
534  }
535 
536  inline std::string
537  ExtractNumericalCondition(const NumericalConditionNode* node) {
538  std::string result;
539  if (node->quantized) { // quantized threshold
540  result = fmt::format("data[{split_index}].qvalue {opname} {threshold}",
541  "split_index"_a = node->split_index,
542  "opname"_a = OpName(node->op),
543  "threshold"_a = node->threshold.int_val);
544  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
545  // According to IEEE 754, the result of comparison [lhs] < infinity
546  // must be identical for all finite [lhs]. Same goes for operator >.
547  result = (CompareWithOp(0.0, node->op, node->threshold.float_val) ? "1" : "0");
548  } else { // finite threshold
549  result = fmt::format("data[{split_index}].fvalue {opname} (float){threshold}",
550  "split_index"_a = node->split_index,
551  "opname"_a = OpName(node->op),
552  "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
553  }
554  return result;
555  }
556 
557  inline std::string
558  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
559  std::string result;
560  std::vector<uint64_t> bitmap
561  = common_util::GetCategoricalBitmap(node->left_categories);
562  CHECK_GE(bitmap.size(), 1);
563  bool all_zeros = true;
564  for (uint64_t e : bitmap) {
565  all_zeros &= (e == 0);
566  }
567  if (all_zeros) {
568  result = "0";
569  } else {
570  std::ostringstream oss;
571  if (node->convert_missing_to_zero) {
572  // All missing values are converted into zeros
573  oss << fmt::format(
574  "((tmp = (data[{0}].missing == -1 ? 0U "
575  ": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
576  } else {
577  if (node->default_left) {
578  oss << fmt::format(
579  "data[{0}].missing == -1 || ("
580  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
581  } else {
582  oss << fmt::format(
583  "data[{0}].missing != -1 && ("
584  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
585  }
586  }
587  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
588  << bitmap[0] << "U >> tmp) & 1) )";
589  for (size_t i = 1; i < bitmap.size(); ++i) {
590  oss << " || (tmp >= " << (i * 64)
591  << " && tmp < " << ((i + 1) * 64)
592  << " && (( (uint64_t)" << bitmap[i]
593  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
594  }
595  oss << ")";
596  result = oss.str();
597  }
598  return result;
599  }
600 
601  inline std::string
602  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
603  common_util::ArrayFormatter formatter(80, 2);
604  for (int fid = 0; fid < num_feature_; ++fid) {
605  formatter << (is_categorical[fid] ? 1 : 0);
606  }
607  return formatter.str();
608  }
609 
610  inline std::string RenderOutputStatement(const OutputNode* node) {
611  std::string output_statement;
612  if (num_output_group_ > 1) {
613  if (node->is_vector) {
614  // multi-class classification with random forest
615  CHECK_EQ(node->vector.size(), static_cast<size_t>(num_output_group_))
616  << "Ill-formed model: leaf vector must be of length [num_output_group]";
617  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
618  output_statement
619  += fmt::format("sum[{group_id}] += (float){output};\n",
620  "group_id"_a = group_id,
621  "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]));
622  }
623  } else {
624  // multi-class classification with gradient boosted trees
625  output_statement
626  = fmt::format("sum[{group_id}] += (float){output};\n",
627  "group_id"_a = node->tree_id % num_output_group_,
628  "output"_a = common_util::ToStringHighPrecision(node->scalar));
629  }
630  } else {
631  output_statement
632  = fmt::format("sum += (float){output};\n",
633  "output"_a = common_util::ToStringHighPrecision(node->scalar));
634  }
635  return output_statement;
636  }
637 };
638 
640 .describe("AST-based compiler that produces C code")
641 .set_body([](const CompilerParam& param) -> Compiler* {
642  return new ASTNativeCompiler(param);
643  });
644 } // namespace compiler
645 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:419
Parameters for tree compiler.
branch annotator class
Definition: annotator.h:17
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
Perform a comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:59
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:40
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:376
ModelParam param
extra parameters
Definition: tree.h:424
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:52
template for header
float global_bias
global bias of the model
Definition: tree.h:383
interface of compiler
Definition: compiler.h:54
Interface of compiler that compiles a tree ensemble model.
template for main function
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:135
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Utilities for code folding.
Definition: compiler.h:27
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:92
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Function to generate bitmaps for categorical splits.
AST Builder class.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:52
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
std::string str()
obtain formatted text containing the rendered array
Definition: format_util.h:99
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:368
Operator
comparison operators
Definition: base.h:24