treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <algorithm>
8 #include <fstream>
9 #include <unordered_map>
10 #include <queue>
11 #include <cmath>
12 #include <treelite/compiler.h>
13 #include <treelite/common.h>
14 #include <treelite/annotator.h>
15 #include <fmt/format.h>
16 #include "./param.h"
17 #include "./pred_transform.h"
18 #include "./ast/builder.h"
19 #include "./native/main_template.h"
25 
// DLLEXPORT_KEYWORD: text substituted into the generated header.h through the
// `dllexport` field of the header template. It is "__declspec(dllexport) " when
// treelite itself is built with MSVC or for Windows, and empty otherwise.
// NOTE(review): this inspects the platform treelite is compiled on, not the
// platform the generated C code will be compiled on -- confirm intended.
#if !defined(_MSC_VER) && !defined(_WIN32)
#define DLLEXPORT_KEYWORD ""
#else
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#endif
31 
32 using namespace fmt::literals;
33 
34 namespace treelite {
35 namespace compiler {
36 
37 DMLC_REGISTRY_FILE_TAG(ast_native);
38 
39 class ASTNativeCompiler : public Compiler {
40  public:
41  explicit ASTNativeCompiler(const CompilerParam& param)
42  : param(param) {
43  if (param.verbose > 0) {
44  LOG(INFO) << "Using ASTNativeCompiler";
45  }
46  if (param.dump_array_as_elf > 0) {
47  LOG(INFO) << "Warning: 'dump_array_as_elf' parameter is not applicable "
48  "for ASTNativeCompiler";
49  }
50  }
51 
52  CompiledModel Compile(const Model& model) override {
53  CompiledModel cm;
54  cm.backend = "native";
55 
56  num_feature_ = model.num_feature;
57  num_output_group_ = model.num_output_group;
58  pred_transform_ = model.param.pred_transform;
59  sigmoid_alpha_ = model.param.sigmoid_alpha;
60  global_bias_ = model.param.global_bias;
61  pred_tranform_func_ = PredTransformFunction("native", model);
62  files_.clear();
63 
64  ASTBuilder builder;
65  builder.BuildAST(model);
66  if (builder.FoldCode(param.code_folding_req)
67  || param.quantize > 0) {
68  // is_categorical[i] : is i-th feature categorical?
69  array_is_categorical_
70  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
71  }
72  if (param.annotate_in != "NULL") {
73  BranchAnnotator annotator;
74  std::unique_ptr<dmlc::Stream> fi(
75  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
76  annotator.Load(fi.get());
77  const auto annotation = annotator.Get();
78  builder.LoadDataCounts(annotation);
79  LOG(INFO) << "Loading node frequencies from `"
80  << param.annotate_in << "'";
81  }
82  builder.Split(param.parallel_comp);
83  if (param.quantize > 0) {
84  builder.QuantizeThresholds();
85  }
86 
87  {
88  const char* destfile = getenv("TREELITE_DUMP_AST");
89  if (destfile) {
90  std::ofstream os(destfile);
91  os << builder.GetDump() << std::endl;
92  }
93  }
94 
95  WalkAST(builder.GetRootNode(), "main.c", 0);
96  if (files_.count("arrays.c") > 0) {
97  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
98  }
99 
100  {
101  /* write recipe.json */
102  std::vector<std::unordered_map<std::string, std::string>> source_list;
103  for (const auto& kv : files_) {
104  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
105  const size_t line_count
106  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
107  source_list.push_back({ {"name",
108  kv.first.substr(0, kv.first.length() - 2)},
109  {"length", std::to_string(line_count)} });
110  }
111  }
112  std::ostringstream oss;
113  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
114  writer->BeginObject();
115  writer->WriteObjectKeyValue("target", param.native_lib_name);
116  writer->WriteObjectKeyValue("sources", source_list);
117  writer->EndObject();
118  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
119  }
120  cm.files = std::move(files_);
121  return cm;
122  }
123 
124  private:
125  CompilerParam param;
126  int num_feature_;
127  int num_output_group_;
128  std::string pred_transform_;
129  float sigmoid_alpha_;
130  float global_bias_;
131  std::string pred_tranform_func_;
132  std::string array_is_categorical_;
133  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
134 
135  void WalkAST(const ASTNode* node,
136  const std::string& dest,
137  size_t indent) {
138  const MainNode* t1;
139  const AccumulatorContextNode* t2;
140  const ConditionNode* t3;
141  const OutputNode* t4;
142  const TranslationUnitNode* t5;
143  const QuantizerNode* t6;
144  const CodeFolderNode* t7;
145  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
146  HandleMainNode(t1, dest, indent);
147  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
148  HandleACNode(t2, dest, indent);
149  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
150  HandleCondNode(t3, dest, indent);
151  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
152  HandleOutputNode(t4, dest, indent);
153  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
154  HandleTUNode(t5, dest, indent);
155  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
156  HandleQNode(t6, dest, indent);
157  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
158  HandleCodeFolderNode(t7, dest, indent);
159  } else {
160  LOG(FATAL) << "Unrecognized AST node type";
161  }
162  }
163 
164  // append content to a given buffer, with given level of indentation
165  inline void AppendToBuffer(const std::string& dest,
166  const std::string& content,
167  size_t indent) {
168  files_[dest].content += common::IndentMultiLineString(content, indent);
169  }
170 
171  // prepend content to a given buffer, with given level of indentation
172  inline void PrependToBuffer(const std::string& dest,
173  const std::string& content,
174  size_t indent) {
175  files_[dest].content = common::IndentMultiLineString(content, indent) + files_[dest].content;
176  }
177 
178  void HandleMainNode(const MainNode* node,
179  const std::string& dest,
180  size_t indent) {
181  const char* get_num_output_group_function_signature
182  = "size_t get_num_output_group(void)";
183  const char* get_num_feature_function_signature
184  = "size_t get_num_feature(void)";
185  const char* get_pred_transform_function_signature
186  = "const char* get_pred_transform(void)";
187  const char* get_sigmoid_alpha_function_signature
188  = "float get_sigmoid_alpha(void)";
189  const char* get_global_bias_function_signature
190  = "float get_global_bias(void)";
191  const char* predict_function_signature
192  = (num_output_group_ > 1) ?
193  "size_t predict_multiclass(union Entry* data, int pred_margin, "
194  "float* result)"
195  : "float predict(union Entry* data, int pred_margin)";
196 
197  if (!array_is_categorical_.empty()) {
198  array_is_categorical_
199  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
200  array_is_categorical_);
201  }
202 
203  AppendToBuffer(dest,
204  fmt::format(native::main_start_template,
205  "array_is_categorical"_a = array_is_categorical_,
206  "get_num_output_group_function_signature"_a
207  = get_num_output_group_function_signature,
208  "get_num_feature_function_signature"_a
209  = get_num_feature_function_signature,
210  "get_pred_transform_function_signature"_a
211  = get_pred_transform_function_signature,
212  "get_sigmoid_alpha_function_signature"_a
213  = get_sigmoid_alpha_function_signature,
214  "get_global_bias_function_signature"_a
215  = get_global_bias_function_signature,
216  "pred_transform_function"_a = pred_tranform_func_,
217  "predict_function_signature"_a = predict_function_signature,
218  "num_output_group"_a = num_output_group_,
219  "num_feature"_a = num_feature_,
220  "pred_transform"_a = pred_transform_,
221  "sigmoid_alpha"_a = sigmoid_alpha_,
222  "global_bias"_a = global_bias_),
223  indent);
224  AppendToBuffer("header.h",
225  fmt::format(native::header_template,
226  "dllexport"_a = DLLEXPORT_KEYWORD,
227  "get_num_output_group_function_signature"_a
228  = get_num_output_group_function_signature,
229  "get_num_feature_function_signature"_a
230  = get_num_feature_function_signature,
231  "get_pred_transform_function_signature"_a
232  = get_pred_transform_function_signature,
233  "get_sigmoid_alpha_function_signature"_a
234  = get_sigmoid_alpha_function_signature,
235  "get_global_bias_function_signature"_a
236  = get_global_bias_function_signature,
237  "predict_function_signature"_a = predict_function_signature,
238  "threshold_type"_a = (param.quantize > 0 ? "int" : "double")),
239  indent);
240 
241  CHECK_EQ(node->children.size(), 1);
242  WalkAST(node->children[0], dest, indent + 2);
243 
244  const std::string optional_average_field
245  = (node->average_result) ? fmt::format(" / {}", node->num_tree)
246  : std::string("");
247  if (num_output_group_ > 1) {
248  AppendToBuffer(dest,
249  fmt::format(native::main_end_multiclass_template,
250  "num_output_group"_a = num_output_group_,
251  "optional_average_field"_a = optional_average_field,
252  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
253  indent);
254  } else {
255  AppendToBuffer(dest,
256  fmt::format(native::main_end_template,
257  "optional_average_field"_a = optional_average_field,
258  "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
259  indent);
260  }
261  }
262 
263  void HandleACNode(const AccumulatorContextNode* node,
264  const std::string& dest,
265  size_t indent) {
266  if (num_output_group_ > 1) {
267  AppendToBuffer(dest,
268  fmt::format("float sum[{num_output_group}] = {{0.0f}};\n"
269  "unsigned int tmp;\n"
270  "int nid, cond, fid; /* used for folded subtrees */\n",
271  "num_output_group"_a = num_output_group_), indent);
272  } else {
273  AppendToBuffer(dest,
274  "float sum = 0.0f;\n"
275  "unsigned int tmp;\n"
276  "int nid, cond, fid; /* used for folded subtrees */\n", indent);
277  }
278  for (ASTNode* child : node->children) {
279  WalkAST(child, dest, indent);
280  }
281  }
282 
283  void HandleCondNode(const ConditionNode* node,
284  const std::string& dest,
285  size_t indent) {
286  const NumericalConditionNode* t;
287  std::string condition, condition_with_na_check;
288  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
289  /* numerical split */
290  condition = ExtractNumericalCondition(t);
291  const char* condition_with_na_check_template
292  = (node->default_left) ?
293  "!(data[{split_index}].missing != -1) || ({condition})"
294  : " (data[{split_index}].missing != -1) && ({condition})";
295  condition_with_na_check
296  = fmt::format(condition_with_na_check_template,
297  "split_index"_a = node->split_index,
298  "condition"_a = condition);
299  } else { /* categorical split */
300  const CategoricalConditionNode* t2
301  = dynamic_cast<const CategoricalConditionNode*>(node);
302  CHECK(t2);
303  condition_with_na_check = ExtractCategoricalCondition(t2);
304  }
305  if (node->children[0]->data_count && node->children[1]->data_count) {
306  const int left_freq = node->children[0]->data_count.value();
307  const int right_freq = node->children[1]->data_count.value();
308  condition_with_na_check
309  = fmt::format(" {keyword}( {condition} ) ",
310  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
311  "condition"_a = condition_with_na_check);
312  }
313  AppendToBuffer(dest,
314  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
315  CHECK_EQ(node->children.size(), 2);
316  WalkAST(node->children[0], dest, indent + 2);
317  AppendToBuffer(dest, "} else {\n", indent);
318  WalkAST(node->children[1], dest, indent + 2);
319  AppendToBuffer(dest, "}\n", indent);
320  }
321 
322  void HandleOutputNode(const OutputNode* node,
323  const std::string& dest,
324  size_t indent) {
325  AppendToBuffer(dest, RenderOutputStatement(node), indent);
326  CHECK_EQ(node->children.size(), 0);
327  }
328 
329  void HandleTUNode(const TranslationUnitNode* node,
330  const std::string& dest,
331  int indent) {
332  const int unit_id = node->unit_id;
333  const std::string new_file = fmt::format("tu{}.c", unit_id);
334 
335  std::string unit_function_name, unit_function_signature,
336  unit_function_call_signature;
337  if (num_output_group_ > 1) {
338  unit_function_name
339  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
340  unit_function_signature
341  = fmt::format("void {}(union Entry* data, float* result)",
342  unit_function_name);
343  unit_function_call_signature
344  = fmt::format("{}(data, sum);\n", unit_function_name);
345  } else {
346  unit_function_name
347  = fmt::format("predict_margin_unit{}", unit_id);
348  unit_function_signature
349  = fmt::format("float {}(union Entry* data)", unit_function_name);
350  unit_function_call_signature
351  = fmt::format("sum += {}(data);\n", unit_function_name);
352  }
353  AppendToBuffer(dest, unit_function_call_signature, indent);
354  AppendToBuffer(new_file,
355  fmt::format("#include \"header.h\"\n"
356  "{} {{\n", unit_function_signature), 0);
357  CHECK_EQ(node->children.size(), 1);
358  WalkAST(node->children[0], new_file, 2);
359  if (num_output_group_ > 1) {
360  AppendToBuffer(new_file,
361  fmt::format(" for (int i = 0; i < {num_output_group}; ++i) {{\n"
362  " result[i] += sum[i];\n"
363  " }}\n"
364  "}}\n",
365  "num_output_group"_a = num_output_group_), 0);
366  } else {
367  AppendToBuffer(new_file, " return sum;\n}\n", 0);
368  }
369  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
370  }
371 
372  void HandleQNode(const QuantizerNode* node,
373  const std::string& dest,
374  size_t indent) {
375  /* render arrays needed to convert feature values into bin indices */
376  std::string array_threshold, array_th_begin, array_th_len;
377  // threshold[] : list of all thresholds that occur at least once in the
378  // ensemble model. For each feature, an ascending list of unique
379  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
380  // of the threshold[] array stores the threshold list for feature i.
381  size_t total_num_threshold;
382  // to hold total number of (distinct) thresholds
383  {
384  common::ArrayFormatter formatter(80, 2);
385  for (const auto& e : node->cut_pts) {
386  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
387  // cut_pts[i][k] stores the k-th threshold of feature i.
388  for (tl_float v : e) {
389  formatter << v;
390  }
391  }
392  array_threshold = formatter.str();
393  }
394  {
395  common::ArrayFormatter formatter(80, 2);
396  size_t accum = 0; // used to compute cumulative sum over threshold counts
397  for (const auto& e : node->cut_pts) {
398  formatter << accum;
399  accum += e.size(); // e.size() = number of thresholds for each feature
400  }
401  total_num_threshold = accum;
402  array_th_begin = formatter.str();
403  }
404  {
405  common::ArrayFormatter formatter(80, 2);
406  for (const auto& e : node->cut_pts) {
407  formatter << e.size();
408  }
409  array_th_len = formatter.str();
410  }
411  if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
412  PrependToBuffer(dest,
413  fmt::format(native::qnode_template,
414  "total_num_threshold"_a = total_num_threshold), 0);
415  AppendToBuffer(dest,
416  fmt::format(native::quantize_loop_template,
417  "num_feature"_a = num_feature_), indent);
418  }
419  if (!array_threshold.empty()) {
420  PrependToBuffer(dest,
421  fmt::format("static const double threshold[] = {{\n"
422  "{array_threshold}\n"
423  "}};\n", "array_threshold"_a = array_threshold), 0);
424  }
425  if (!array_th_begin.empty()) {
426  PrependToBuffer(dest,
427  fmt::format("static const int th_begin[] = {{\n"
428  "{array_th_begin}\n"
429  "}};\n", "array_th_begin"_a = array_th_begin), 0);
430  }
431  if (!array_th_len.empty()) {
432  PrependToBuffer(dest,
433  fmt::format("static const int th_len[] = {{\n"
434  "{array_th_len}\n"
435  "}};\n", "array_th_len"_a = array_th_len), 0);
436  }
437  CHECK_EQ(node->children.size(), 1);
438  WalkAST(node->children[0], dest, indent);
439  }
440 
441  void HandleCodeFolderNode(const CodeFolderNode* node,
442  const std::string& dest,
443  size_t indent) {
444  CHECK_EQ(node->children.size(), 1);
445  const int node_id = node->children[0]->node_id;
446  const int tree_id = node->children[0]->tree_id;
447 
448  /* render arrays needed for folding subtrees */
449  std::string array_nodes, array_cat_bitmap, array_cat_begin;
450  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
451  const std::string node_array_name
452  = fmt::format("node_tree{}_node{}", tree_id, node_id);
453  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
454  // make all categorical splits in a particular
455  // subtree
456  const std::string cat_bitmap_name
457  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
458  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
459  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
460  // belongs to the i-th (categorical) split
461  const std::string cat_begin_name
462  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
463 
464  std::string output_switch_statement;
465  Operator common_comp_op;
466  common_util::RenderCodeFolderArrays(node, param.quantize, false,
467  "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
468  [this](const OutputNode* node) { return RenderOutputStatement(node); },
469  &array_nodes, &array_cat_bitmap, &array_cat_begin,
470  &output_switch_statement, &common_comp_op);
471  if (!array_nodes.empty()) {
472  AppendToBuffer("header.h",
473  fmt::format("extern const struct Node {node_array_name}[];\n",
474  "node_array_name"_a = node_array_name), 0);
475  AppendToBuffer("arrays.c",
476  fmt::format("const struct Node {node_array_name}[] = {{\n"
477  "{array_nodes}\n"
478  "}};\n",
479  "node_array_name"_a = node_array_name,
480  "array_nodes"_a = array_nodes), 0);
481  }
482 
483  if (!array_cat_bitmap.empty()) {
484  AppendToBuffer("header.h",
485  fmt::format("extern const uint64_t {cat_bitmap_name}[];\n",
486  "cat_bitmap_name"_a = cat_bitmap_name), 0);
487  AppendToBuffer("arrays.c",
488  fmt::format("const uint64_t {cat_bitmap_name}[] = {{\n"
489  "{array_cat_bitmap}\n"
490  "}};\n",
491  "cat_bitmap_name"_a = cat_bitmap_name,
492  "array_cat_bitmap"_a = array_cat_bitmap), 0);
493  }
494 
495  if (!array_cat_begin.empty()) {
496  AppendToBuffer("header.h",
497  fmt::format("extern const size_t {cat_begin_name}[];\n",
498  "cat_begin_name"_a = cat_begin_name), 0);
499  AppendToBuffer("arrays.c",
500  fmt::format("const size_t {cat_begin_name}[] = {{\n"
501  "{array_cat_begin}\n"
502  "}};\n",
503  "cat_begin_name"_a = cat_begin_name,
504  "array_cat_begin"_a = array_cat_begin), 0);
505  }
506 
507  if (array_nodes.empty()) {
508  /* folded code consists of a single leaf node */
509  AppendToBuffer(dest,
510  fmt::format("nid = -1;\n"
511  "{output_switch_statement}\n",
512  "output_switch_statement"_a
513  = output_switch_statement), indent);
514  } else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
515  AppendToBuffer(dest,
516  fmt::format(native::eval_loop_template,
517  "node_array_name"_a = node_array_name,
518  "cat_bitmap_name"_a = cat_bitmap_name,
519  "cat_begin_name"_a = cat_begin_name,
520  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
521  "comp_op"_a = OpName(common_comp_op),
522  "output_switch_statement"_a
523  = output_switch_statement), indent);
524  } else {
525  AppendToBuffer(dest,
526  fmt::format(native::eval_loop_template_without_categorical_feature,
527  "node_array_name"_a = node_array_name,
528  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
529  "comp_op"_a = OpName(common_comp_op),
530  "output_switch_statement"_a
531  = output_switch_statement), indent);
532  }
533  }
534 
535  inline std::string
536  ExtractNumericalCondition(const NumericalConditionNode* node) {
537  std::string result;
538  if (node->quantized) { // quantized threshold
539  result = fmt::format("data[{split_index}].qvalue {opname} {threshold}",
540  "split_index"_a = node->split_index,
541  "opname"_a = OpName(node->op),
542  "threshold"_a = node->threshold.int_val);
543  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
544  // According to IEEE 754, the result of comparison [lhs] < infinity
545  // must be identical for all finite [lhs]. Same goes for operator >.
546  result = (common::CompareWithOp(0.0, node->op, node->threshold.float_val)
547  ? "1" : "0");
548  } else { // finite threshold
549  result = fmt::format("data[{split_index}].fvalue {opname} {threshold}",
550  "split_index"_a = node->split_index,
551  "opname"_a = OpName(node->op),
552  "threshold"_a
553  = common::ToStringHighPrecision(node->threshold.float_val));
554  }
555  return result;
556  }
557 
558  inline std::string
559  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
560  std::string result;
561  std::vector<uint64_t> bitmap
562  = common_util::GetCategoricalBitmap(node->left_categories);
563  CHECK_GE(bitmap.size(), 1);
564  bool all_zeros = true;
565  for (uint64_t e : bitmap) {
566  all_zeros &= (e == 0);
567  }
568  if (all_zeros) {
569  result = "0";
570  } else {
571  std::ostringstream oss;
572  if (node->convert_missing_to_zero) {
573  // All missing values are converted into zeros
574  oss << fmt::format(
575  "((tmp = (data[{0}].missing == -1 ? 0U "
576  ": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
577  } else {
578  if (node->default_left) {
579  oss << fmt::format(
580  "data[{0}].missing == -1 || ("
581  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
582  } else {
583  oss << fmt::format(
584  "data[{0}].missing != -1 && ("
585  "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
586  }
587  }
588  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
589  << bitmap[0] << "U >> tmp) & 1) )";
590  for (size_t i = 1; i < bitmap.size(); ++i) {
591  oss << " || (tmp >= " << (i * 64)
592  << " && tmp < " << ((i + 1) * 64)
593  << " && (( (uint64_t)" << bitmap[i]
594  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
595  }
596  oss << ")";
597  result = oss.str();
598  }
599  return result;
600  }
601 
602  inline std::string
603  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
604  common::ArrayFormatter formatter(80, 2);
605  for (int fid = 0; fid < num_feature_; ++fid) {
606  formatter << (is_categorical[fid] ? 1 : 0);
607  }
608  return formatter.str();
609  }
610 
611  inline std::string RenderOutputStatement(const OutputNode* node) {
612  std::string output_statement;
613  if (num_output_group_ > 1) {
614  if (node->is_vector) {
615  // multi-class classification with random forest
616  CHECK_EQ(node->vector.size(), static_cast<size_t>(num_output_group_))
617  << "Ill-formed model: leaf vector must be of length [num_output_group]";
618  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
619  output_statement
620  += fmt::format("sum[{group_id}] += (float){output};\n",
621  "group_id"_a = group_id,
622  "output"_a
623  = common::ToStringHighPrecision(node->vector[group_id]));
624  }
625  } else {
626  // multi-class classification with gradient boosted trees
627  output_statement
628  = fmt::format("sum[{group_id}] += (float){output};\n",
629  "group_id"_a = node->tree_id % num_output_group_,
630  "output"_a = common::ToStringHighPrecision(node->scalar));
631  }
632  } else {
633  output_statement
634  = fmt::format("sum += (float){output};\n",
635  "output"_a = common::ToStringHighPrecision(node->scalar));
636  }
637  return output_statement;
638  }
639 };
640 
642 .describe("AST-based compiler that produces C code")
643 .set_body([](const CompilerParam& param) -> Compiler* {
644  return new ASTNativeCompiler(param);
645  });
646 } // namespace compiler
647 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:438
branch annotator class
Definition: annotator.h:17
thin wrapper for tree ensemble model
Definition: tree.h:428
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
parameters for tree compiler
Definition: param.h:18
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:392
ModelParam param
extra parameters
Definition: tree.h:443
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:52
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:166
template for header
float global_bias
global bias of the model
Definition: tree.h:399
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:55
Interface of compiler that compiles a tree ensemble model.
template for main function
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: common.h:126
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
std::string pred_transform
name of prediction transform function
Definition: tree.h:384
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
Definition: param.h:38
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:26
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Definition: param.h:49
Utilities for code folding.
Definition: compiler.h:28
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:93
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Definition: param.h:44
Function to generate bitmaps for categorical splits.
AST Builder class.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:52
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:34
double tl_float
float type to be used internally
Definition: base.h:17
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:28
int verbose
if >0, produce extra messages
Definition: param.h:36
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435
Operator
comparison operators
Definition: base.h:23