Treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <rapidjson/stringbuffer.h>
12 #include <rapidjson/writer.h>
13 #include <algorithm>
14 #include <fstream>
15 #include <unordered_map>
16 #include <queue>
17 #include <cstdio>
18 #include <cmath>
19 #include <cstdint>
20 #include "./ast_native.h"
21 #include "./pred_transform.h"
22 #include "./ast/builder.h"
23 #include "./native/main_template.h"
28 #include "./common/format_util.h"
30 
// DLLEXPORT_KEYWORD: text passed as the `dllexport` argument to the generated
// header / query-function prototype templates. It expands to
// "__declspec(dllexport) " when building with MSVC or targeting Windows, and to
// an empty string elsewhere. (Presumably it is prefixed to exported function
// declarations in header.h; the templates themselves are not shown here.)
31 #if defined(_MSC_VER) || defined(_WIN32)
32 #define DLLEXPORT_KEYWORD "__declspec(dllexport) "
33 #else
34 #define DLLEXPORT_KEYWORD ""
35 #endif
36 
37 using namespace fmt::literals;
38 
39 namespace treelite {
40 namespace compiler {
41 
43  public:
44  explicit ASTNativeCompilerImpl(const CompilerParam& param) : param_(param) {}
45 
// Compiles one concretely-typed model into generated C sources (main.c,
// header.h, tuN.c and arrays.c as needed) plus recipe.json, returned in a
// CompiledModel with backend "native".
// NOTE(review): this listing omits the method signature (original line 47) and
// the declaration of `builder` (original line 65). Judging from Compile(), this
// is CompileImpl, called with the typed handle produced by Model::Dispatch and
// returning CompiledModel. Confirm against the real file.
46  template <typename ThresholdType, typename LeafOutputType>
48  CompiledModel cm;
49  cm.backend = "native";
50 
// Reject model kinds this code generator cannot handle.
51  TREELITE_CHECK(model.task_type != TaskType::kMultiClfCategLeaf)
52  << "Model task type unsupported by ASTNativeCompiler";
53  TREELITE_CHECK(model.task_param.output_type == TaskParam::OutputType::kFloat)
54  << "ASTNativeCompiler only supports models with float output";
55 
// Cache model metadata read later by the Handle*/Render* helpers, and drop any
// files left over from a previous compilation.
56  num_feature_ = model.num_feature;
57  task_type_ = model.task_type;
58  task_param_ = model.task_param;
59  pred_transform_ = model.param.pred_transform;
60  sigmoid_alpha_ = model.param.sigmoid_alpha;
61  ratio_c_ = model.param.ratio_c;
62  global_bias_ = model.param.global_bias;
63  files_.clear();
// NOTE(review): array_is_categorical_ is not cleared here the way files_ is.
// It is only assigned below under a condition, so a value from an earlier
// Compile() call on the same object could persist. Verify.
64 
66  builder.BuildAST(model);
// The is_categorical[] rows are rendered only when FoldCode() returns true
// (presumably: some subtree was folded) or threshold quantization is enabled.
67  if (builder.FoldCode(param_.code_folding_req) || param_.quantize > 0) {
68  // is_categorical[i] : is i-th feature categorical?
69  array_is_categorical_
70  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
71  }
// The string "NULL" is the sentinel for "no annotation file". Otherwise, load
// per-branch data counts so that HandleCondNode can emit LIKELY/UNLIKELY hints.
// NOTE(review): the ifstream is not checked for a failed open before Load().
72  if (param_.annotate_in != "NULL") {
73  BranchAnnotator annotator;
74  std::ifstream fi(param_.annotate_in.c_str());
75  annotator.Load(fi);
76  const auto annotation = annotator.Get();
77  builder.LoadDataCounts(annotation);
78  TREELITE_LOG(INFO) << "Loading node frequencies from `"
79  << param_.annotate_in << "'";
80  }
// Partition the AST into translation units (param_.parallel_comp controls the
// split), then optionally replace thresholds by quantized bin indices.
81  builder.Split(param_.parallel_comp);
82  if (param_.quantize > 0) {
83  builder.QuantizeThresholds();
84  }
85 
86  {
// If the TREELITE_DUMP_AST environment variable is set, write the AST dump
// to the file it names.
87  const char* destfile = getenv("TREELITE_DUMP_AST");
88  if (destfile) {
89  std::ofstream os(destfile);
90  os << builder.GetDump() << std::endl;
91  }
92  }
93 
// Generate code starting from the root, writing into "main.c" at indent 0.
// arrays.c only exists if folded-subtree arrays were emitted, and it needs the
// shared header.
94  WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(), "main.c", 0);
95  if (files_.count("arrays.c") > 0) {
96  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
97  }
98 
99  {
100  /* write recipe.json */
// Layout: {"target": <native_lib_name>, "sources": [{"name", "length"}...]}.
// "name" is the .c file name without its extension; "length" is its newline
// count. files_ is an unordered_map, so the order of "sources" follows hash
// order. File names shorter than 2 characters would make compare() throw;
// all names generated here are longer.
101  rapidjson::StringBuffer os;
102  rapidjson::Writer<rapidjson::StringBuffer> writer(os);
103 
104  writer.StartObject();
105  writer.Key("target");
106  writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
107  writer.Key("sources");
108  writer.StartArray();
109  for (const auto& kv : files_) {
110  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
111  const size_t line_count
112  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
113  writer.StartObject();
114  writer.Key("name");
115  std::string name = kv.first.substr(0, kv.first.length() - 2);
116  writer.String(name.data(), name.size());
117  writer.Key("length");
118  writer.Uint64(line_count);
119  writer.EndObject();
120  }
121  }
122  writer.EndArray();
123  writer.EndObject();
124 
125  files_["recipe.json"] = CompiledModel::FileEntry(os.GetString());
126  }
// Hand all generated files to the result. files_ is left moved-from and is
// cleared again at the start of the next compilation.
127  cm.files = std::move(files_);
128  return cm;
129  }
130 
131  CompiledModel Compile(const Model& model) {
132  TREELITE_CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
133  << "Integer leaf outputs not yet supported";
134  this->pred_tranform_func_ = PredTransformFunction("native", model);
135  return model.Dispatch([this](const auto& model_handle) {
136  return this->CompileImpl(model_handle);
137  });
138  }
139 
140  CompilerParam QueryParam() const {
141  return param_;
142  }
143 
144  private:
145  CompilerParam param_;
146  int num_feature_;
147  TaskType task_type_;
148  TaskParam task_param_;
149  std::string pred_transform_;
150  float sigmoid_alpha_;
151  float ratio_c_;
152  float global_bias_;
153  std::string pred_tranform_func_;
154  std::string array_is_categorical_;
155  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
156 
157  template <typename ThresholdType, typename LeafOutputType>
158  void WalkAST(const ASTNode* node,
159  const std::string& dest,
160  size_t indent) {
161  const MainNode* t1;
162  const AccumulatorContextNode* t2;
163  const ConditionNode* t3;
164  const OutputNode<LeafOutputType>* t4;
165  const TranslationUnitNode* t5;
167  const CodeFolderNode* t7;
168  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
169  HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
170  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
171  HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
172  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
173  HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
174  } else if ( (t4 = dynamic_cast<const OutputNode<LeafOutputType>*>(node)) ) {
175  HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
176  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
177  HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
178  } else if ( (t6 = dynamic_cast<const QuantizerNode<ThresholdType>*>(node)) ) {
179  HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
180  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
181  HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
182  } else {
183  TREELITE_LOG(FATAL) << "Unrecognized AST node type";
184  }
185  }
186 
187  // append content to a given buffer, with given level of indentation
188  inline void AppendToBuffer(const std::string& dest,
189  const std::string& content,
190  size_t indent) {
191  files_[dest].content += common_util::IndentMultiLineString(content, indent);
192  }
193 
194  // prepend content to a given buffer, with given level of indentation
195  inline void PrependToBuffer(const std::string& dest,
196  const std::string& content,
197  size_t indent) {
198  files_[dest].content
199  = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
200  }
201 
202  template <typename ThresholdType, typename LeafOutputType>
203  void HandleMainNode(const MainNode* node,
204  const std::string& dest,
205  size_t indent) {
206  const std::string threshold_type
207  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
208  const std::string leaf_output_type
209  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
210  const std::string predict_function_signature
211  = (task_param_.num_class > 1) ?
212  fmt::format("size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
213  leaf_output_type)
214  : fmt::format("{} predict(union Entry* data, int pred_margin)",
215  leaf_output_type);
216 
217  if (!array_is_categorical_.empty()) {
218  array_is_categorical_
219  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
220  array_is_categorical_);
221  }
222 
223  const std::string query_functions_definition
224  = fmt::format(native::query_functions_definition_template,
225  "num_class"_a = task_param_.num_class,
226  "num_feature"_a = num_feature_,
227  "pred_transform"_a = pred_transform_,
228  "sigmoid_alpha"_a = sigmoid_alpha_,
229  "ratio_c"_a = ratio_c_,
230  "global_bias"_a = global_bias_,
231  "threshold_type_str"_a = TypeInfoToString(TypeToInfo<ThresholdType>()),
232  "leaf_output_type_str"_a = TypeInfoToString(TypeToInfo<LeafOutputType>()));
233 
234  AppendToBuffer(dest,
235  fmt::format(native::main_start_template,
236  "array_is_categorical"_a = array_is_categorical_,
237  "query_functions_definition"_a = query_functions_definition,
238  "pred_transform_function"_a = pred_tranform_func_,
239  "predict_function_signature"_a = predict_function_signature),
240  indent);
241  const std::string query_functions_prototype
242  = fmt::format(native::query_functions_prototype_template,
243  "dllexport"_a = DLLEXPORT_KEYWORD);
244  AppendToBuffer("header.h",
245  fmt::format(native::header_template,
246  "dllexport"_a = DLLEXPORT_KEYWORD,
247  "predict_function_signature"_a = predict_function_signature,
248  "query_functions_prototype"_a = query_functions_prototype,
249  "threshold_type"_a = threshold_type,
250  "threshold_type_Node"_a = (param_.quantize > 0 ? std::string("int") : threshold_type)),
251  indent);
252 
253  TREELITE_CHECK_EQ(node->children.size(), 1);
254  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
255 
256  std::string optional_average_field;
257  if (node->average_result) {
258  if (task_type_ == TaskType::kMultiClfGrovePerClass) {
259  TREELITE_CHECK(task_param_.grove_per_class);
260  TREELITE_CHECK_EQ(task_param_.leaf_vector_size, 1);
261  TREELITE_CHECK_GT(task_param_.num_class, 1);
262  TREELITE_CHECK_EQ(node->num_tree % task_param_.num_class, 0)
263  << "Expected the number of trees to be divisible by the number of classes";
264  int num_boosting_round = node->num_tree / static_cast<int>(task_param_.num_class);
265  optional_average_field = fmt::format(" / {}", num_boosting_round);
266  } else {
267  TREELITE_CHECK(task_type_ == TaskType::kBinaryClfRegr
268  || task_type_ == TaskType::kMultiClfProbDistLeaf);
269  TREELITE_CHECK_EQ(task_param_.num_class, task_param_.leaf_vector_size);
270  TREELITE_CHECK(!task_param_.grove_per_class);
271  optional_average_field = fmt::format(" / {}", node->num_tree);
272  }
273  }
274  if (task_param_.num_class > 1) {
275  AppendToBuffer(dest,
276  fmt::format(native::main_end_multiclass_template,
277  "num_class"_a = task_param_.num_class,
278  "optional_average_field"_a = optional_average_field,
279  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
280  "leaf_output_type"_a = leaf_output_type),
281  indent);
282  } else {
283  AppendToBuffer(dest,
284  fmt::format(native::main_end_template,
285  "optional_average_field"_a = optional_average_field,
286  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
287  "leaf_output_type"_a = leaf_output_type),
288  indent);
289  }
290  }
291 
292  template <typename ThresholdType, typename LeafOutputType>
293  void HandleACNode(const AccumulatorContextNode* node,
294  const std::string& dest,
295  size_t indent) {
296  const std::string leaf_output_type
297  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
298  if (task_param_.num_class > 1) {
299  AppendToBuffer(dest,
300  fmt::format("{leaf_output_type} sum[{num_class}] = {{0}};\n"
301  "unsigned int tmp;\n"
302  "int nid, cond, fid; /* used for folded subtrees */\n",
303  "num_class"_a = task_param_.num_class,
304  "leaf_output_type"_a = leaf_output_type), indent);
305  } else {
306  AppendToBuffer(dest,
307  fmt::format("{leaf_output_type} sum = ({leaf_output_type})0;\n"
308  "unsigned int tmp;\n"
309  "int nid, cond, fid; /* used for folded subtrees */\n",
310  "leaf_output_type"_a = leaf_output_type),
311  indent);
312  }
313  for (ASTNode* child : node->children) {
314  WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
315  }
316  }
317 
318  template <typename ThresholdType, typename LeafOutputType>
319  void HandleCondNode(const ConditionNode* node,
320  const std::string& dest,
321  size_t indent) {
323  std::string condition_with_na_check;
324  if ( (t = dynamic_cast<const NumericalConditionNode<ThresholdType>*>(node)) ) {
325  /* numerical split */
326  std::string condition = ExtractNumericalCondition(t);
327  const char* condition_with_na_check_template
328  = (node->default_left) ?
329  "!(data[{split_index}].missing != -1) || ({condition})"
330  : " (data[{split_index}].missing != -1) && ({condition})";
331  condition_with_na_check
332  = fmt::format(condition_with_na_check_template,
333  "split_index"_a = node->split_index,
334  "condition"_a = condition);
335  } else { /* categorical split */
336  const CategoricalConditionNode* t2 = dynamic_cast<const CategoricalConditionNode*>(node);
337  TREELITE_CHECK(t2);
338  condition_with_na_check = ExtractCategoricalCondition(t2);
339  }
340  if (node->children[0]->data_count && node->children[1]->data_count) {
341  const uint64_t left_freq = *node->children[0]->data_count;
342  const uint64_t right_freq = *node->children[1]->data_count;
343  condition_with_na_check
344  = fmt::format(" {keyword}( {condition} ) ",
345  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
346  "condition"_a = condition_with_na_check);
347  }
348  AppendToBuffer(dest,
349  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
350  TREELITE_CHECK_EQ(node->children.size(), 2);
351  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
352  AppendToBuffer(dest, "} else {\n", indent);
353  WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
354  AppendToBuffer(dest, "}\n", indent);
355  }
356 
357  template <typename ThresholdType, typename LeafOutputType>
358  void HandleOutputNode(const OutputNode<LeafOutputType>* node,
359  const std::string& dest,
360  size_t indent) {
361  AppendToBuffer(dest, RenderOutputStatement(node), indent);
362  TREELITE_CHECK_EQ(node->children.size(), 0);
363  }
364 
365  template <typename ThresholdType, typename LeafOutputType>
366  void HandleTUNode(const TranslationUnitNode* node,
367  const std::string& dest,
368  size_t indent) {
369  const int unit_id = node->unit_id;
370  const std::string new_file = fmt::format("tu{}.c", unit_id);
371  const std::string leaf_output_type
372  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
373 
374  std::string unit_function_name, unit_function_signature,
375  unit_function_call_signature;
376  if (task_param_.num_class > 1) {
377  unit_function_name
378  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
379  unit_function_signature
380  = fmt::format("void {function_name}(union Entry* data, {leaf_output_type}* result)",
381  "function_name"_a = unit_function_name,
382  "leaf_output_type"_a = leaf_output_type);
383  unit_function_call_signature
384  = fmt::format("{}(data, sum);\n", unit_function_name);
385  } else {
386  unit_function_name
387  = fmt::format("predict_margin_unit{}", unit_id);
388  unit_function_signature
389  = fmt::format("{leaf_output_type} {function_name}(union Entry* data)",
390  "function_name"_a = unit_function_name,
391  "leaf_output_type"_a = leaf_output_type);
392  unit_function_call_signature
393  = fmt::format("sum += {}(data);\n", unit_function_name);
394  }
395  AppendToBuffer(dest, unit_function_call_signature, indent);
396  AppendToBuffer(new_file,
397  fmt::format("#include \"header.h\"\n"
398  "{} {{\n", unit_function_signature), 0);
399  TREELITE_CHECK_EQ(node->children.size(), 1);
400  WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
401  if (task_param_.num_class > 1) {
402  AppendToBuffer(new_file,
403  fmt::format(" for (int i = 0; i < {num_class}; ++i) {{\n"
404  " result[i] += sum[i];\n"
405  " }}\n"
406  "}}\n",
407  "num_class"_a = task_param_.num_class), 0);
408  } else {
409  AppendToBuffer(new_file, " return sum;\n}\n", 0);
410  }
411  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
412  }
413 
414  template <typename ThresholdType, typename LeafOutputType>
415  void HandleQNode(const QuantizerNode<ThresholdType>* node,
416  const std::string& dest,
417  size_t indent) {
418  const std::string threshold_type
419  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
420  /* render arrays needed to convert feature values into bin indices */
421  std::string array_threshold, array_th_begin, array_th_len;
422  // threshold[] : list of all thresholds that occur at least once in the
423  // ensemble model. For each feature, an ascending list of unique
424  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
425  // of the threshold[] array stores the threshold list for feature i.
426  size_t total_num_threshold;
427  // to hold total number of (distinct) thresholds
428  {
429  common_util::ArrayFormatter formatter(80, 2);
430  for (const auto& e : node->cut_pts) {
431  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
432  // cut_pts[i][k] stores the k-th threshold of feature i.
433  for (auto v : e) {
434  formatter << v;
435  }
436  }
437  array_threshold = formatter.str();
438  }
439  {
440  common_util::ArrayFormatter formatter(80, 2);
441  size_t accum = 0; // used to compute cumulative sum over threshold counts
442  for (const auto& e : node->cut_pts) {
443  formatter << accum;
444  accum += e.size(); // e.size() = number of thresholds for each feature
445  }
446  total_num_threshold = accum;
447  array_th_begin = formatter.str();
448  }
449  {
450  common_util::ArrayFormatter formatter(80, 2);
451  for (const auto& e : node->cut_pts) {
452  formatter << e.size();
453  }
454  array_th_len = formatter.str();
455  }
456  if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
457  PrependToBuffer(dest,
458  fmt::format(native::qnode_template,
459  "total_num_threshold"_a = total_num_threshold,
460  "threshold_type"_a = threshold_type),
461  0);
462  AppendToBuffer(dest,
463  fmt::format(native::quantize_loop_template,
464  "num_feature"_a = num_feature_), indent);
465  }
466  if (!array_threshold.empty()) {
467  PrependToBuffer(dest,
468  fmt::format("static const {threshold_type} threshold[] = {{\n"
469  "{array_threshold}\n"
470  "}};\n",
471  "array_threshold"_a = array_threshold,
472  "threshold_type"_a = threshold_type),
473  0);
474  }
475  if (!array_th_begin.empty()) {
476  PrependToBuffer(dest,
477  fmt::format("static const int th_begin[] = {{\n"
478  "{array_th_begin}\n"
479  "}};\n", "array_th_begin"_a = array_th_begin), 0);
480  }
481  if (!array_th_len.empty()) {
482  PrependToBuffer(dest,
483  fmt::format("static const int th_len[] = {{\n"
484  "{array_th_len}\n"
485  "}};\n", "array_th_len"_a = array_th_len), 0);
486  }
487  TREELITE_CHECK_EQ(node->children.size(), 1);
488  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
489  }
490 
491  template <typename ThresholdType, typename LeafOutputType>
492  void HandleCodeFolderNode(const CodeFolderNode* node,
493  const std::string& dest,
494  size_t indent) {
495  TREELITE_CHECK_EQ(node->children.size(), 1);
496  const int node_id = node->children[0]->node_id;
497  const int tree_id = node->children[0]->tree_id;
498 
499  /* render arrays needed for folding subtrees */
500  std::string array_nodes, array_cat_bitmap, array_cat_begin;
501  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
502  const std::string node_array_name
503  = fmt::format("node_tree{}_node{}", tree_id, node_id);
504  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
505  // make all categorical splits in a particular
506  // subtree
507  const std::string cat_bitmap_name
508  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
509  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
510  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
511  // belongs to the i-th (categorical) split
512  const std::string cat_begin_name
513  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
514 
515  std::string output_switch_statement;
516  Operator common_comp_op;
517  common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param_.quantize,
518  false, "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
519  [this](const OutputNode<LeafOutputType>* node) { return RenderOutputStatement(node); },
520  &array_nodes, &array_cat_bitmap, &array_cat_begin, &output_switch_statement,
521  &common_comp_op);
522  if (!array_nodes.empty()) {
523  AppendToBuffer("header.h",
524  fmt::format("extern const struct Node {node_array_name}[];\n",
525  "node_array_name"_a = node_array_name), 0);
526  AppendToBuffer("arrays.c",
527  fmt::format("const struct Node {node_array_name}[] = {{\n"
528  "{array_nodes}\n"
529  "}};\n",
530  "node_array_name"_a = node_array_name,
531  "array_nodes"_a = array_nodes), 0);
532  }
533 
534  if (!array_cat_bitmap.empty()) {
535  AppendToBuffer("header.h",
536  fmt::format("extern const uint64_t {cat_bitmap_name}[];\n",
537  "cat_bitmap_name"_a = cat_bitmap_name), 0);
538  AppendToBuffer("arrays.c",
539  fmt::format("const uint64_t {cat_bitmap_name}[] = {{\n"
540  "{array_cat_bitmap}\n"
541  "}};\n",
542  "cat_bitmap_name"_a = cat_bitmap_name,
543  "array_cat_bitmap"_a = array_cat_bitmap), 0);
544  }
545 
546  if (!array_cat_begin.empty()) {
547  AppendToBuffer("header.h",
548  fmt::format("extern const size_t {cat_begin_name}[];\n",
549  "cat_begin_name"_a = cat_begin_name), 0);
550  AppendToBuffer("arrays.c",
551  fmt::format("const size_t {cat_begin_name}[] = {{\n"
552  "{array_cat_begin}\n"
553  "}};\n",
554  "cat_begin_name"_a = cat_begin_name,
555  "array_cat_begin"_a = array_cat_begin), 0);
556  }
557 
558  if (array_nodes.empty()) {
559  /* folded code consists of a single leaf node */
560  AppendToBuffer(dest,
561  fmt::format("nid = -1;\n"
562  "{output_switch_statement}\n",
563  "output_switch_statement"_a
564  = output_switch_statement), indent);
565  } else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
566  AppendToBuffer(dest,
567  fmt::format(native::eval_loop_template,
568  "node_array_name"_a = node_array_name,
569  "cat_bitmap_name"_a = cat_bitmap_name,
570  "cat_begin_name"_a = cat_begin_name,
571  "data_field"_a = (param_.quantize > 0 ? "qvalue" : "fvalue"),
572  "comp_op"_a = OpName(common_comp_op),
573  "output_switch_statement"_a
574  = output_switch_statement), indent);
575  } else {
576  AppendToBuffer(dest,
577  fmt::format(native::eval_loop_template_without_categorical_feature,
578  "node_array_name"_a = node_array_name,
579  "data_field"_a = (param_.quantize > 0 ? "qvalue" : "fvalue"),
580  "comp_op"_a = OpName(common_comp_op),
581  "output_switch_statement"_a
582  = output_switch_statement), indent);
583  }
584  }
585 
586  template <typename ThresholdType>
587  inline std::string
588  ExtractNumericalCondition(const NumericalConditionNode<ThresholdType>* node) {
589  const std::string threshold_type
590  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
591  std::string result;
592  if (node->quantized) { // quantized threshold
593  std::string lhs = fmt::format("data[{split_index}].qvalue",
594  "split_index"_a = node->split_index);
595  result = fmt::format("{lhs} {opname} {threshold}",
596  "lhs"_a = lhs,
597  "opname"_a = OpName(node->op),
598  "threshold"_a = node->threshold.int_val);
599  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
600  // According to IEEE 754, the result of comparison [lhs] < infinity
601  // must be identical for all finite [lhs]. Same goes for operator >.
602  result = (CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
603  ? "1" : "0");
604  } else { // finite threshold
605  std::string lhs = fmt::format("data[{split_index}].fvalue",
606  "split_index"_a = node->split_index);
607  result
608  = fmt::format("{lhs} {opname} ({threshold_type}){threshold}",
609  "lhs"_a = lhs,
610  "opname"_a = OpName(node->op),
611  "threshold_type"_a = threshold_type,
612  "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
613  }
614  return result;
615  }
616 
617  inline std::string
618  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
619  std::string result;
620  std::vector<uint64_t> bitmap
621  = common_util::GetCategoricalBitmap(node->matching_categories);
622  TREELITE_CHECK_GE(bitmap.size(), 1);
623  bool all_zeros = true;
624  for (uint64_t e : bitmap) {
625  all_zeros &= (e == 0);
626  }
627  if (all_zeros) {
628  result = "0";
629  } else {
630  std::ostringstream oss;
631  const std::string right_categories_flag = (node->categories_list_right_child ? "!" : "");
632  if (node->default_left) {
633  oss << fmt::format(
634  "data[{split_index}].missing == -1 || {right_categories_flag}("
635  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
636  "split_index"_a = node->split_index,
637  "right_categories_flag"_a = right_categories_flag);
638  } else {
639  oss << fmt::format(
640  "data[{split_index}].missing != -1 && {right_categories_flag}("
641  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
642  "split_index"_a = node->split_index,
643  "right_categories_flag"_a = right_categories_flag);
644  }
645 
646  oss << fmt::format(
647  "((data[{split_index}].fvalue >= 0) && "
648  "(fabsf(data[{split_index}].fvalue) <= (float)(1U << FLT_MANT_DIG)) && (",
649  "split_index"_a = node->split_index);
650  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
651  << bitmap[0] << "U >> tmp) & 1) )";
652  for (size_t i = 1; i < bitmap.size(); ++i) {
653  oss << " || (tmp >= " << (i * 64)
654  << " && tmp < " << ((i + 1) * 64)
655  << " && (( (uint64_t)" << bitmap[i]
656  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
657  }
658  oss << ")))";
659  result = oss.str();
660  }
661  return result;
662  }
663 
664  inline std::string
665  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
666  common_util::ArrayFormatter formatter(80, 2);
667  for (int fid = 0; fid < num_feature_; ++fid) {
668  formatter << (is_categorical[fid] ? 1 : 0);
669  }
670  return formatter.str();
671  }
672 
673  template <typename LeafOutputType>
674  inline std::string RenderOutputStatement(const OutputNode<LeafOutputType>* node) {
675  const std::string leaf_output_type
676  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
677  std::string output_statement;
678  if (task_param_.num_class > 1) {
679  if (node->is_vector) {
680  // multi-class classification with random forest
681  TREELITE_CHECK_EQ(node->vector.size(), static_cast<size_t>(task_param_.num_class))
682  << "Ill-formed model: leaf vector must be of length [num_class]";
683  for (size_t group_id = 0; group_id < task_param_.num_class; ++group_id) {
684  output_statement
685  += fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
686  "group_id"_a = group_id,
687  "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
688  "leaf_output_type"_a = leaf_output_type);
689  }
690  } else {
691  // multi-class classification with gradient boosted trees
692  output_statement
693  = fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
694  "group_id"_a = node->tree_id % task_param_.num_class,
695  "output"_a = common_util::ToStringHighPrecision(node->scalar),
696  "leaf_output_type"_a = leaf_output_type);
697  }
698  } else {
699  output_statement
700  = fmt::format("sum += ({leaf_output_type}){output};\n",
701  "output"_a = common_util::ToStringHighPrecision(node->scalar),
702  "leaf_output_type"_a = leaf_output_type);
703  }
704  return output_statement;
705  }
706 };
707 
708 ASTNativeCompiler::ASTNativeCompiler(const CompilerParam& param)
709  : pimpl_(std::make_unique<ASTNativeCompilerImpl>(param)) {
710  if (param.verbose > 0) {
711  TREELITE_LOG(INFO) << "Using ASTNativeCompiler";
712  }
713  if (param.dump_array_as_elf > 0) {
714  TREELITE_LOG(INFO) << "Warning: 'dump_array_as_elf' parameter is not applicable "
715  "for ASTNativeCompiler";
716  }
717 }
718 
719 ASTNativeCompiler::~ASTNativeCompiler() = default;
720 
722 ASTNativeCompiler::Compile(const Model &model) {
723  return pimpl_->Compile(model);
724 }
725 
727 ASTNativeCompiler::QueryParam() const {
728  return pimpl_->QueryParam();
729 }
730 
731 } // namespace compiler
732 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:801
Parameters for tree compiler.
branch annotator class
Definition: annotator.h:21
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:192
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:181
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:44
TaskType
Enum type representing the task type.
Definition: tree.h:107
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:78
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:207
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:200
template for header
TaskType task_type
Task type.
Definition: tree.h:795
Interface of compiler that compiles a tree ensemble model.
int32_t num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:793
template for main function
std::vector< std::vector< uint64_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:55
template for evaluation logic for folded code
code template for QuantizerNode
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
Definition: compiler.h:26
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:40
AST Builder class.
std::string str()
obtain formatted text containing the rendered array
Definition: format_util.h:99
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:799
C code generator.
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:734
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
void Load(std::istream &fi)
load branch annotation from a JSON file
Definition: annotator.cc:244
Operator
comparison operators
Definition: base.h:27