Treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <algorithm>
12 #include <fstream>
13 #include <unordered_map>
14 #include <queue>
15 #include <cmath>
16 #include "./pred_transform.h"
17 #include "./ast/builder.h"
18 #include "./native/main_template.h"
23 #include "./common/format_util.h"
26 
// DLLEXPORT_KEYWORD: prefix handed to the header / query-function prototype templates
// (as the `dllexport` named argument). Windows toolchains need __declspec(dllexport)
// for the generated functions to be exported from the built DLL; every other
// toolchain needs no prefix at all.
#if !defined(_MSC_VER) && !defined(_WIN32)
#define DLLEXPORT_KEYWORD ""
#else
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#endif
32 
33 using namespace fmt::literals;
34 
35 namespace treelite {
36 namespace compiler {
37 
38 DMLC_REGISTRY_FILE_TAG(ast_native);
39 
40 class ASTNativeCompiler : public Compiler {
41  public:
42  explicit ASTNativeCompiler(const CompilerParam& param)
43  : param(param) {
44  if (param.verbose > 0) {
45  LOG(INFO) << "Using ASTNativeCompiler";
46  }
47  if (param.dump_array_as_elf > 0) {
48  LOG(INFO) << "Warning: 'dump_array_as_elf' parameter is not applicable "
49  "for ASTNativeCompiler";
50  }
51  }
52 
// CompileImpl: compile one concrete model (threshold / leaf output types fixed by the
// template arguments) into a CompiledModel whose backend is "native". Generated files
// are accumulated in files_ (main.c, header.h and whatever arrays.c / tu*.c files the
// AST handlers create); recipe.json is added last and the whole map is moved into the
// returned CompiledModel.
// NOTE(review): this rendered listing omits source line 54 (the function signature)
// and line 71 (the declaration of `builder`); both exist in the real source file.
53  template <typename ThresholdType, typename LeafOutputType>
55  CompiledModel cm;
56  cm.backend = "native";
57 
// Reject configurations this backend cannot emit code for.
58  CHECK(model.task_type != TaskType::kMultiClfCategLeaf)
59  << "Model task type unsupported by ASTNativeCompiler";
60  CHECK(model.task_param.output_type == TaskParameter::OutputType::kFloat)
61  << "ASTNativeCompiler only supports models with float output";
62 
// Cache model metadata in members so the Handle*/Render* helpers can read it.
63  num_feature_ = model.num_feature;
64  task_type_ = model.task_type;
65  task_param_ = model.task_param;
66  pred_transform_ = model.param.pred_transform;
67  sigmoid_alpha_ = model.param.sigmoid_alpha;
68  global_bias_ = model.param.global_bias;
// Discard files left from a previous compilation (files_ is moved out at the end).
69  files_.clear();
70 
72  builder.BuildAST(model);
// FoldCode is the left operand on purpose: it must always run (it transforms the AST),
// so do not swap the operands or short-circuiting would skip it when quantize > 0.
73  if (builder.FoldCode(param.code_folding_req) || param.quantize > 0) {
74  // is_categorical[i] : is i-th feature categorical?
75  array_is_categorical_
76  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
77  }
// Optional branch annotation; the literal string "NULL" means "no annotation file".
// The loaded data counts are what HandleCondNode later uses for LIKELY/UNLIKELY hints.
78  if (param.annotate_in != "NULL") {
79  BranchAnnotator annotator;
80  std::unique_ptr<dmlc::Stream> fi(
81  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
82  annotator.Load(fi.get());
83  const auto annotation = annotator.Get();
84  builder.LoadDataCounts(annotation);
// Note: this message is logged after the counts have already been loaded.
85  LOG(INFO) << "Loading node frequencies from `"
86  << param.annotate_in << "'";
87  }
// Presumably distributes the trees over `parallel_comp` translation units (see
// HandleTUNode) when nonzero -- TODO confirm against ASTBuilder::Split.
88  builder.Split(param.parallel_comp);
89  if (param.quantize > 0) {
90  builder.QuantizeThresholds();
91  }
92 
// Debugging aid: if TREELITE_DUMP_AST names a file, write the AST dump there.
93  {
94  const char* destfile = getenv("TREELITE_DUMP_AST");
95  if (destfile) {
96  std::ofstream os(destfile);
97  os << builder.GetDump() << std::endl;
98  }
99  }
100 
// Emit code, starting from the root node into main.c at indentation 0.
101  WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(), "main.c", 0);
// arrays.c only exists when folded subtrees produced constant arrays; it needs the
// shared header for the struct/extern declarations.
102  if (files_.count("arrays.c") > 0) {
103  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
104  }
105 
106  {
107  /* write recipe.json */
// For every generated .c file record its base name (extension stripped) and its
// number of newline characters.
// NOTE(review): compare(length() - 2, ...) throws std::out_of_range for names shorter
// than 2 characters; every file name generated above is longer than that.
108  std::vector<std::unordered_map<std::string, std::string>> source_list;
109  for (const auto& kv : files_) {
110  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
111  const size_t line_count
112  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
113  source_list.push_back({ {"name",
114  kv.first.substr(0, kv.first.length() - 2)},
115  {"length", std::to_string(line_count)} });
116  }
117  }
118  std::ostringstream oss;
119  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&oss));
120  writer->BeginObject();
121  writer->WriteObjectKeyValue("target", param.native_lib_name);
122  writer->WriteObjectKeyValue("sources", source_list);
123  writer->EndObject();
124  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
125  }
// Hand the files over; files_ is left moved-from until the next clear() above.
126  cm.files = std::move(files_);
127  return cm;
128  }
129 
130  CompiledModel Compile(const Model& model) override {
131  CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
132  << "Integer leaf outputs not yet supported";
133  this->pred_tranform_func_ = PredTransformFunction("native", model);
134  return model.Dispatch([this](const auto& model_handle) {
135  return this->CompileImpl(model_handle);
136  });
137  }
138 
139  private:
140  CompilerParam param;
141  int num_feature_;
142  TaskType task_type_;
143  TaskParameter task_param_;
144  std::string pred_transform_;
145  float sigmoid_alpha_;
146  float global_bias_;
147  std::string pred_tranform_func_;
148  std::string array_is_categorical_;
149  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
150 
151  template <typename ThresholdType, typename LeafOutputType>
152  void WalkAST(const ASTNode* node,
153  const std::string& dest,
154  size_t indent) {
155  const MainNode* t1;
156  const AccumulatorContextNode* t2;
157  const ConditionNode* t3;
158  const OutputNode<LeafOutputType>* t4;
159  const TranslationUnitNode* t5;
161  const CodeFolderNode* t7;
162  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
163  HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
164  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
165  HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
166  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
167  HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
168  } else if ( (t4 = dynamic_cast<const OutputNode<LeafOutputType>*>(node)) ) {
169  HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
170  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
171  HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
172  } else if ( (t6 = dynamic_cast<const QuantizerNode<ThresholdType>*>(node)) ) {
173  HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
174  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
175  HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
176  } else {
177  LOG(FATAL) << "Unrecognized AST node type";
178  }
179  }
180 
181  // append content to a given buffer, with given level of indentation
182  inline void AppendToBuffer(const std::string& dest,
183  const std::string& content,
184  size_t indent) {
185  files_[dest].content += common_util::IndentMultiLineString(content, indent);
186  }
187 
188  // prepend content to a given buffer, with given level of indentation
189  inline void PrependToBuffer(const std::string& dest,
190  const std::string& content,
191  size_t indent) {
192  files_[dest].content
193  = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
194  }
195 
196  template <typename ThresholdType, typename LeafOutputType>
197  void HandleMainNode(const MainNode* node,
198  const std::string& dest,
199  size_t indent) {
200  const std::string threshold_type
201  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
202  const std::string leaf_output_type
203  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
204  const std::string predict_function_signature
205  = (task_param_.num_class > 1) ?
206  fmt::format("size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
207  leaf_output_type)
208  : fmt::format("{} predict(union Entry* data, int pred_margin)",
209  leaf_output_type);
210 
211  if (!array_is_categorical_.empty()) {
212  array_is_categorical_
213  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
214  array_is_categorical_);
215  }
216 
217  const std::string query_functions_definition
218  = fmt::format(native::query_functions_definition_template,
219  "num_class"_a = task_param_.num_class,
220  "num_feature"_a = num_feature_,
221  "pred_transform"_a = pred_transform_,
222  "sigmoid_alpha"_a = sigmoid_alpha_,
223  "global_bias"_a = global_bias_,
224  "threshold_type_str"_a = TypeInfoToString(TypeToInfo<ThresholdType>()),
225  "leaf_output_type_str"_a = TypeInfoToString(TypeToInfo<LeafOutputType>()));
226 
227  AppendToBuffer(dest,
228  fmt::format(native::main_start_template,
229  "array_is_categorical"_a = array_is_categorical_,
230  "query_functions_definition"_a = query_functions_definition,
231  "pred_transform_function"_a = pred_tranform_func_,
232  "predict_function_signature"_a = predict_function_signature),
233  indent);
234  const std::string query_functions_prototype
235  = fmt::format(native::query_functions_prototype_template,
236  "dllexport"_a = DLLEXPORT_KEYWORD);
237  AppendToBuffer("header.h",
238  fmt::format(native::header_template,
239  "dllexport"_a = DLLEXPORT_KEYWORD,
240  "predict_function_signature"_a = predict_function_signature,
241  "query_functions_prototype"_a = query_functions_prototype,
242  "threshold_type"_a = threshold_type,
243  "threshold_type_Node"_a = (param.quantize > 0 ? std::string("int") : threshold_type)),
244  indent);
245 
246  CHECK_EQ(node->children.size(), 1);
247  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
248 
249  std::string optional_average_field;
250  if (node->average_result) {
251  if (task_type_ == TaskType::kMultiClfGrovePerClass) {
252  CHECK(task_param_.grove_per_class);
253  CHECK_EQ(task_param_.leaf_vector_size, 1);
254  CHECK_GT(task_param_.num_class, 1);
255  CHECK_EQ(node->num_tree % task_param_.num_class, 0)
256  << "Expected the number of trees to be divisible by the number of classes";
257  int num_boosting_round = node->num_tree / static_cast<int>(task_param_.num_class);
258  optional_average_field = fmt::format(" / {}", num_boosting_round);
259  } else {
260  CHECK(task_type_ == TaskType::kBinaryClfRegr
261  || task_type_ == TaskType::kMultiClfProbDistLeaf);
262  CHECK_EQ(task_param_.num_class, task_param_.leaf_vector_size);
263  CHECK(!task_param_.grove_per_class);
264  optional_average_field = fmt::format(" / {}", node->num_tree);
265  }
266  }
267  if (task_param_.num_class > 1) {
268  AppendToBuffer(dest,
269  fmt::format(native::main_end_multiclass_template,
270  "num_class"_a = task_param_.num_class,
271  "optional_average_field"_a = optional_average_field,
272  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
273  "leaf_output_type"_a = leaf_output_type),
274  indent);
275  } else {
276  AppendToBuffer(dest,
277  fmt::format(native::main_end_template,
278  "optional_average_field"_a = optional_average_field,
279  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
280  "leaf_output_type"_a = leaf_output_type),
281  indent);
282  }
283  }
284 
285  template <typename ThresholdType, typename LeafOutputType>
286  void HandleACNode(const AccumulatorContextNode* node,
287  const std::string& dest,
288  size_t indent) {
289  const std::string leaf_output_type
290  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
291  if (task_param_.num_class > 1) {
292  AppendToBuffer(dest,
293  fmt::format("{leaf_output_type} sum[{num_class}] = {{0}};\n"
294  "unsigned int tmp;\n"
295  "int nid, cond, fid; /* used for folded subtrees */\n",
296  "num_class"_a = task_param_.num_class,
297  "leaf_output_type"_a = leaf_output_type), indent);
298  } else {
299  AppendToBuffer(dest,
300  fmt::format("{leaf_output_type} sum = ({leaf_output_type})0;\n"
301  "unsigned int tmp;\n"
302  "int nid, cond, fid; /* used for folded subtrees */\n",
303  "leaf_output_type"_a = leaf_output_type),
304  indent);
305  }
306  for (ASTNode* child : node->children) {
307  WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
308  }
309  }
310 
311  template <typename ThresholdType, typename LeafOutputType>
312  void HandleCondNode(const ConditionNode* node,
313  const std::string& dest,
314  size_t indent) {
316  std::string condition_with_na_check;
317  if ( (t = dynamic_cast<const NumericalConditionNode<ThresholdType>*>(node)) ) {
318  /* numerical split */
319  std::string condition = ExtractNumericalCondition(t);
320  const char* condition_with_na_check_template
321  = (node->default_left) ?
322  "!(data[{split_index}].missing != -1) || ({condition})"
323  : " (data[{split_index}].missing != -1) && ({condition})";
324  condition_with_na_check
325  = fmt::format(condition_with_na_check_template,
326  "split_index"_a = node->split_index,
327  "condition"_a = condition);
328  } else { /* categorical split */
329  const CategoricalConditionNode* t2 = dynamic_cast<const CategoricalConditionNode*>(node);
330  CHECK(t2);
331  condition_with_na_check = ExtractCategoricalCondition(t2);
332  }
333  if (node->children[0]->data_count && node->children[1]->data_count) {
334  const size_t left_freq = node->children[0]->data_count.value();
335  const size_t right_freq = node->children[1]->data_count.value();
336  condition_with_na_check
337  = fmt::format(" {keyword}( {condition} ) ",
338  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
339  "condition"_a = condition_with_na_check);
340  }
341  AppendToBuffer(dest,
342  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
343  CHECK_EQ(node->children.size(), 2);
344  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
345  AppendToBuffer(dest, "} else {\n", indent);
346  WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
347  AppendToBuffer(dest, "}\n", indent);
348  }
349 
350  template <typename ThresholdType, typename LeafOutputType>
351  void HandleOutputNode(const OutputNode<LeafOutputType>* node,
352  const std::string& dest,
353  size_t indent) {
354  AppendToBuffer(dest, RenderOutputStatement(node), indent);
355  CHECK_EQ(node->children.size(), 0);
356  }
357 
358  template <typename ThresholdType, typename LeafOutputType>
359  void HandleTUNode(const TranslationUnitNode* node,
360  const std::string& dest,
361  size_t indent) {
362  const int unit_id = node->unit_id;
363  const std::string new_file = fmt::format("tu{}.c", unit_id);
364  const std::string leaf_output_type
365  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
366 
367  std::string unit_function_name, unit_function_signature,
368  unit_function_call_signature;
369  if (task_param_.num_class > 1) {
370  unit_function_name
371  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
372  unit_function_signature
373  = fmt::format("void {function_name}(union Entry* data, {leaf_output_type}* result)",
374  "function_name"_a = unit_function_name,
375  "leaf_output_type"_a = leaf_output_type);
376  unit_function_call_signature
377  = fmt::format("{}(data, sum);\n", unit_function_name);
378  } else {
379  unit_function_name
380  = fmt::format("predict_margin_unit{}", unit_id);
381  unit_function_signature
382  = fmt::format("{leaf_output_type} {function_name}(union Entry* data)",
383  "function_name"_a = unit_function_name,
384  "leaf_output_type"_a = leaf_output_type);
385  unit_function_call_signature
386  = fmt::format("sum += {}(data);\n", unit_function_name);
387  }
388  AppendToBuffer(dest, unit_function_call_signature, indent);
389  AppendToBuffer(new_file,
390  fmt::format("#include \"header.h\"\n"
391  "{} {{\n", unit_function_signature), 0);
392  CHECK_EQ(node->children.size(), 1);
393  WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
394  if (task_param_.num_class > 1) {
395  AppendToBuffer(new_file,
396  fmt::format(" for (int i = 0; i < {num_class}; ++i) {{\n"
397  " result[i] += sum[i];\n"
398  " }}\n"
399  "}}\n",
400  "num_class"_a = task_param_.num_class), 0);
401  } else {
402  AppendToBuffer(new_file, " return sum;\n}\n", 0);
403  }
404  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
405  }
406 
407  template <typename ThresholdType, typename LeafOutputType>
408  void HandleQNode(const QuantizerNode<ThresholdType>* node,
409  const std::string& dest,
410  size_t indent) {
411  const std::string threshold_type
412  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
413  /* render arrays needed to convert feature values into bin indices */
414  std::string array_threshold, array_th_begin, array_th_len;
415  // threshold[] : list of all thresholds that occur at least once in the
416  // ensemble model. For each feature, an ascending list of unique
417  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
418  // of the threshold[] array stores the threshold list for feature i.
419  size_t total_num_threshold;
420  // to hold total number of (distinct) thresholds
421  {
422  common_util::ArrayFormatter formatter(80, 2);
423  for (const auto& e : node->cut_pts) {
424  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
425  // cut_pts[i][k] stores the k-th threshold of feature i.
426  for (auto v : e) {
427  formatter << v;
428  }
429  }
430  array_threshold = formatter.str();
431  }
432  {
433  common_util::ArrayFormatter formatter(80, 2);
434  size_t accum = 0; // used to compute cumulative sum over threshold counts
435  for (const auto& e : node->cut_pts) {
436  formatter << accum;
437  accum += e.size(); // e.size() = number of thresholds for each feature
438  }
439  total_num_threshold = accum;
440  array_th_begin = formatter.str();
441  }
442  {
443  common_util::ArrayFormatter formatter(80, 2);
444  for (const auto& e : node->cut_pts) {
445  formatter << e.size();
446  }
447  array_th_len = formatter.str();
448  }
449  if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
450  PrependToBuffer(dest,
451  fmt::format(native::qnode_template,
452  "total_num_threshold"_a = total_num_threshold,
453  "threshold_type"_a = threshold_type),
454  0);
455  AppendToBuffer(dest,
456  fmt::format(native::quantize_loop_template,
457  "num_feature"_a = num_feature_), indent);
458  }
459  if (!array_threshold.empty()) {
460  PrependToBuffer(dest,
461  fmt::format("static const {threshold_type} threshold[] = {{\n"
462  "{array_threshold}\n"
463  "}};\n",
464  "array_threshold"_a = array_threshold,
465  "threshold_type"_a = threshold_type),
466  0);
467  }
468  if (!array_th_begin.empty()) {
469  PrependToBuffer(dest,
470  fmt::format("static const int th_begin[] = {{\n"
471  "{array_th_begin}\n"
472  "}};\n", "array_th_begin"_a = array_th_begin), 0);
473  }
474  if (!array_th_len.empty()) {
475  PrependToBuffer(dest,
476  fmt::format("static const int th_len[] = {{\n"
477  "{array_th_len}\n"
478  "}};\n", "array_th_len"_a = array_th_len), 0);
479  }
480  CHECK_EQ(node->children.size(), 1);
481  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
482  }
483 
484  template <typename ThresholdType, typename LeafOutputType>
485  void HandleCodeFolderNode(const CodeFolderNode* node,
486  const std::string& dest,
487  size_t indent) {
488  CHECK_EQ(node->children.size(), 1);
489  const int node_id = node->children[0]->node_id;
490  const int tree_id = node->children[0]->tree_id;
491 
492  /* render arrays needed for folding subtrees */
493  std::string array_nodes, array_cat_bitmap, array_cat_begin;
494  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
495  const std::string node_array_name
496  = fmt::format("node_tree{}_node{}", tree_id, node_id);
497  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
498  // make all categorical splits in a particular
499  // subtree
500  const std::string cat_bitmap_name
501  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
502  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
503  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
504  // belongs to the i-th (categorical) split
505  const std::string cat_begin_name
506  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
507 
508  std::string output_switch_statement;
509  Operator common_comp_op;
510  common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param.quantize, false,
511  "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
512  [this](const OutputNode<LeafOutputType>* node) { return RenderOutputStatement(node); },
513  &array_nodes, &array_cat_bitmap, &array_cat_begin,
514  &output_switch_statement, &common_comp_op);
515  if (!array_nodes.empty()) {
516  AppendToBuffer("header.h",
517  fmt::format("extern const struct Node {node_array_name}[];\n",
518  "node_array_name"_a = node_array_name), 0);
519  AppendToBuffer("arrays.c",
520  fmt::format("const struct Node {node_array_name}[] = {{\n"
521  "{array_nodes}\n"
522  "}};\n",
523  "node_array_name"_a = node_array_name,
524  "array_nodes"_a = array_nodes), 0);
525  }
526 
527  if (!array_cat_bitmap.empty()) {
528  AppendToBuffer("header.h",
529  fmt::format("extern const uint64_t {cat_bitmap_name}[];\n",
530  "cat_bitmap_name"_a = cat_bitmap_name), 0);
531  AppendToBuffer("arrays.c",
532  fmt::format("const uint64_t {cat_bitmap_name}[] = {{\n"
533  "{array_cat_bitmap}\n"
534  "}};\n",
535  "cat_bitmap_name"_a = cat_bitmap_name,
536  "array_cat_bitmap"_a = array_cat_bitmap), 0);
537  }
538 
539  if (!array_cat_begin.empty()) {
540  AppendToBuffer("header.h",
541  fmt::format("extern const size_t {cat_begin_name}[];\n",
542  "cat_begin_name"_a = cat_begin_name), 0);
543  AppendToBuffer("arrays.c",
544  fmt::format("const size_t {cat_begin_name}[] = {{\n"
545  "{array_cat_begin}\n"
546  "}};\n",
547  "cat_begin_name"_a = cat_begin_name,
548  "array_cat_begin"_a = array_cat_begin), 0);
549  }
550 
551  if (array_nodes.empty()) {
552  /* folded code consists of a single leaf node */
553  AppendToBuffer(dest,
554  fmt::format("nid = -1;\n"
555  "{output_switch_statement}\n",
556  "output_switch_statement"_a
557  = output_switch_statement), indent);
558  } else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
559  AppendToBuffer(dest,
560  fmt::format(native::eval_loop_template,
561  "node_array_name"_a = node_array_name,
562  "cat_bitmap_name"_a = cat_bitmap_name,
563  "cat_begin_name"_a = cat_begin_name,
564  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
565  "comp_op"_a = OpName(common_comp_op),
566  "output_switch_statement"_a
567  = output_switch_statement), indent);
568  } else {
569  AppendToBuffer(dest,
570  fmt::format(native::eval_loop_template_without_categorical_feature,
571  "node_array_name"_a = node_array_name,
572  "data_field"_a = (param.quantize > 0 ? "qvalue" : "fvalue"),
573  "comp_op"_a = OpName(common_comp_op),
574  "output_switch_statement"_a
575  = output_switch_statement), indent);
576  }
577  }
578 
579  template <typename ThresholdType>
580  inline std::string
581  ExtractNumericalCondition(const NumericalConditionNode<ThresholdType>* node) {
582  const std::string threshold_type
583  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
584  std::string result;
585  if (node->quantized) { // quantized threshold
586  std::string lhs = fmt::format("data[{split_index}].qvalue",
587  "split_index"_a = node->split_index);
588  result = fmt::format("{lhs} {opname} {threshold}",
589  "lhs"_a = lhs,
590  "opname"_a = OpName(node->op),
591  "threshold"_a = node->threshold.int_val);
592  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
593  // According to IEEE 754, the result of comparison [lhs] < infinity
594  // must be identical for all finite [lhs]. Same goes for operator >.
595  result = (CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
596  ? "1" : "0");
597  } else { // finite threshold
598  std::string lhs = fmt::format("data[{split_index}].fvalue",
599  "split_index"_a = node->split_index);
600  result
601  = fmt::format("{lhs} {opname} ({threshold_type}){threshold}",
602  "lhs"_a = lhs,
603  "opname"_a = OpName(node->op),
604  "threshold_type"_a = threshold_type,
605  "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
606  }
607  return result;
608  }
609 
610  inline std::string
611  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
612  std::string result;
613  std::vector<uint64_t> bitmap
614  = common_util::GetCategoricalBitmap(node->matching_categories);
615  CHECK_GE(bitmap.size(), 1);
616  bool all_zeros = true;
617  for (uint64_t e : bitmap) {
618  all_zeros &= (e == 0);
619  }
620  if (all_zeros) {
621  result = "0";
622  } else {
623  std::ostringstream oss;
624  const std::string right_categories_flag = (node->categories_list_right_child ? "!" : "");
625  if (node->default_left) {
626  oss << fmt::format(
627  "data[{split_index}].missing == -1 || {right_categories_flag}("
628  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
629  "split_index"_a = node->split_index,
630  "right_categories_flag"_a = right_categories_flag);
631  } else {
632  oss << fmt::format(
633  "data[{split_index}].missing != -1 && {right_categories_flag}("
634  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
635  "split_index"_a = node->split_index,
636  "right_categories_flag"_a = right_categories_flag);
637  }
638  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
639  << bitmap[0] << "U >> tmp) & 1) )";
640  for (size_t i = 1; i < bitmap.size(); ++i) {
641  oss << " || (tmp >= " << (i * 64)
642  << " && tmp < " << ((i + 1) * 64)
643  << " && (( (uint64_t)" << bitmap[i]
644  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
645  }
646  oss << ")";
647  result = oss.str();
648  }
649  return result;
650  }
651 
652  inline std::string
653  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
654  common_util::ArrayFormatter formatter(80, 2);
655  for (int fid = 0; fid < num_feature_; ++fid) {
656  formatter << (is_categorical[fid] ? 1 : 0);
657  }
658  return formatter.str();
659  }
660 
661  template <typename LeafOutputType>
662  inline std::string RenderOutputStatement(const OutputNode<LeafOutputType>* node) {
663  const std::string leaf_output_type
664  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
665  std::string output_statement;
666  if (task_param_.num_class > 1) {
667  if (node->is_vector) {
668  // multi-class classification with random forest
669  CHECK_EQ(node->vector.size(), static_cast<size_t>(task_param_.num_class))
670  << "Ill-formed model: leaf vector must be of length [num_class]";
671  for (size_t group_id = 0; group_id < task_param_.num_class; ++group_id) {
672  output_statement
673  += fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
674  "group_id"_a = group_id,
675  "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
676  "leaf_output_type"_a = leaf_output_type);
677  }
678  } else {
679  // multi-class classification with gradient boosted trees
680  output_statement
681  = fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
682  "group_id"_a = node->tree_id % task_param_.num_class,
683  "output"_a = common_util::ToStringHighPrecision(node->scalar),
684  "leaf_output_type"_a = leaf_output_type);
685  }
686  } else {
687  output_statement
688  = fmt::format("sum += ({leaf_output_type}){output};\n",
689  "output"_a = common_util::ToStringHighPrecision(node->scalar),
690  "leaf_output_type"_a = leaf_output_type);
691  }
692  return output_statement;
693  }
694 };
695 
697 .describe("AST-based compiler that produces C code")
698 .set_body([](const CompilerParam& param) -> Compiler* {
699  return new ASTNativeCompiler(param);
700  });
701 } // namespace compiler
702 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:658
Parameters for tree compiler.
branch annotator class
Definition: annotator.h:17
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:43
TaskType
Enum type representing the task type.
Definition: tree.h:94
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:63
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
template for header
TaskType task_type
Task type.
Definition: tree.h:652
interface of compiler
Definition: compiler.h:54
Interface of compiler that compiles a tree ensemble model.
template for main function
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:158
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:241
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
Definition: compiler.h:27
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:92
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Function to generate bitmaps for categorical splits.
TaskParameter task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:656
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
AST Builder class.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:130
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
std::string str()
obtain formatted text containing the rendered array
Definition: format_util.h:99
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:177
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:169
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:615
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:650
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:184
Operator
comparison operators
Definition: base.h:26