Treelite
ast_native.cc
Go to the documentation of this file.
1 
7 #include <treelite/compiler.h>
9 #include <treelite/annotator.h>
10 #include <fmt/format.h>
11 #include <rapidjson/stringbuffer.h>
12 #include <rapidjson/writer.h>
13 #include <algorithm>
14 #include <fstream>
15 #include <unordered_map>
16 #include <queue>
17 #include <cstdio>
18 #include <cmath>
19 #include <cstdint>
20 #include "./ast_native.h"
21 #include "./pred_transform.h"
22 #include "./ast/builder.h"
23 #include "./native/main_template.h"
28 #include "./common/format_util.h"
30 
// Prefix substituted for the {dllexport} placeholder of the generated header
// templates. Windows toolchains need __declspec(dllexport) for the generated
// functions to be exported from the DLL; other platforms need no annotation.
#if !defined(_MSC_VER) && !defined(_WIN32)
#define DLLEXPORT_KEYWORD ""
#else
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#endif
36 
37 using namespace fmt::literals;
38 
39 namespace treelite {
40 namespace compiler {
41 
43  public:
44  explicit ASTNativeCompilerImpl(const CompilerParam& param) : param_(param) {}
45 
// CompileImpl: type-specialized compilation driver. Validates the model,
// snapshots model-level settings, builds and transforms the AST, renders the
// C sources into files_ (starting at "main.c"), appends recipe.json and returns
// everything as a CompiledModel whose backend is "native".
46  template <typename ThresholdType, typename LeafOutputType>
48  CompiledModel cm;
49  cm.backend = "native";
50 
51  TREELITE_CHECK(model.task_type != TaskType::kMultiClfCategLeaf)
52  << "Model task type unsupported by ASTNativeCompiler";
53  TREELITE_CHECK(model.task_param.output_type == TaskParam::OutputType::kFloat)
54  << "ASTNativeCompiler only supports models with float output";
55 
// Snapshot the settings read later by the AST handlers, and drop the output
// buffers left over from any previous compilation.
56  num_feature_ = model.num_feature;
57  task_type_ = model.task_type;
58  task_param_ = model.task_param;
59  pred_transform_ = model.param.pred_transform;
60  sigmoid_alpha_ = model.param.sigmoid_alpha;
61  global_bias_ = model.param.global_bias;
62  files_.clear();
63 
65  builder.BuildAST(model);
// FoldCode() runs first (it has side effects on the AST); the is_categorical[]
// table is only rendered when folding happened or quantization is enabled.
// NOTE(review): array_is_categorical_ is only assigned on this branch and,
// unlike files_, is not reset above, so a second CompileImpl call on the same
// instance can see a stale value from the previous model. Consider clearing
// it next to files_.clear().
66  if (builder.FoldCode(param_.code_folding_req) || param_.quantize > 0) {
67  // is_categorical[i] : is i-th feature categorical?
68  array_is_categorical_
69  = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
70  }
// Optional branch annotation ("NULL" means none): load per-node counts from
// the annotation file, presumably populating the data_count fields that
// HandleCondNode turns into LIKELY/UNLIKELY hints.
// NOTE(review): the ifstream is not checked for a failed open here.
71  if (param_.annotate_in != "NULL") {
72  BranchAnnotator annotator;
73  std::ifstream fi(param_.annotate_in.c_str());
74  annotator.Load(fi);
75  const auto annotation = annotator.Get();
76  builder.LoadDataCounts(annotation);
77  TREELITE_LOG(INFO) << "Loading node frequencies from `"
78  << param_.annotate_in << "'";
79  }
// Presumably partitions the AST into translation units (rendered by
// HandleTUNode as tuN.c files); then optionally quantize thresholds.
80  builder.Split(param_.parallel_comp);
81  if (param_.quantize > 0) {
82  builder.QuantizeThresholds();
83  }
84 
// Debug hook: when TREELITE_DUMP_AST names a file, write the AST dump there.
85  {
86  const char* destfile = getenv("TREELITE_DUMP_AST");
87  if (destfile) {
88  std::ofstream os(destfile);
89  os << builder.GetDump() << std::endl;
90  }
91  }
92 
// Render the AST into "main.c"; handlers also create "header.h" and may
// create "tuN.c" and "arrays.c". arrays.c needs the header for struct Node.
93  WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(), "main.c", 0);
94  if (files_.count("arrays.c") > 0) {
95  PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0);
96  }
97 
// recipe.json: the target library name plus, for every generated .c file,
// its name without extension and its newline count. files_ is an
// unordered_map, so the order of the "sources" entries is unspecified.
98  {
99  /* write recipe.json */
100  rapidjson::StringBuffer os;
101  rapidjson::Writer<rapidjson::StringBuffer> writer(os);
102 
103  writer.StartObject();
104  writer.Key("target");
105  writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
106  writer.Key("sources");
107  writer.StartArray();
108  for (const auto& kv : files_) {
// Selects *.c files; assumes every key is at least 2 characters long
// (std::string::compare throws std::out_of_range otherwise).
109  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
110  const size_t line_count
111  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
112  writer.StartObject();
113  writer.Key("name");
114  std::string name = kv.first.substr(0, kv.first.length() - 2);
115  writer.String(name.data(), name.size());
116  writer.Key("length");
117  writer.Uint64(line_count);
118  writer.EndObject();
119  }
120  }
121  writer.EndArray();
122  writer.EndObject();
123 
124  files_["recipe.json"] = CompiledModel::FileEntry(os.GetString());
125  }
// Hand the buffers to the result; files_ is left moved-from and is cleared
// at the start of the next compilation.
126  cm.files = std::move(files_);
127  return cm;
128  }
129 
130  CompiledModel Compile(const Model& model) {
131  TREELITE_CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
132  << "Integer leaf outputs not yet supported";
133  this->pred_tranform_func_ = PredTransformFunction("native", model);
134  return model.Dispatch([this](const auto& model_handle) {
135  return this->CompileImpl(model_handle);
136  });
137  }
138 
139  CompilerParam QueryParam() const {
140  return param_;
141  }
142 
143  private:
144  CompilerParam param_;
145  int num_feature_;
146  TaskType task_type_;
147  TaskParam task_param_;
148  std::string pred_transform_;
149  float sigmoid_alpha_;
150  float global_bias_;
151  std::string pred_tranform_func_;
152  std::string array_is_categorical_;
153  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
154 
155  template <typename ThresholdType, typename LeafOutputType>
156  void WalkAST(const ASTNode* node,
157  const std::string& dest,
158  size_t indent) {
159  const MainNode* t1;
160  const AccumulatorContextNode* t2;
161  const ConditionNode* t3;
162  const OutputNode<LeafOutputType>* t4;
163  const TranslationUnitNode* t5;
165  const CodeFolderNode* t7;
166  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
167  HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
168  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
169  HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
170  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
171  HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
172  } else if ( (t4 = dynamic_cast<const OutputNode<LeafOutputType>*>(node)) ) {
173  HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
174  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
175  HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
176  } else if ( (t6 = dynamic_cast<const QuantizerNode<ThresholdType>*>(node)) ) {
177  HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
178  } else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
179  HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
180  } else {
181  TREELITE_LOG(FATAL) << "Unrecognized AST node type";
182  }
183  }
184 
185  // append content to a given buffer, with given level of indentation
186  inline void AppendToBuffer(const std::string& dest,
187  const std::string& content,
188  size_t indent) {
189  files_[dest].content += common_util::IndentMultiLineString(content, indent);
190  }
191 
192  // prepend content to a given buffer, with given level of indentation
193  inline void PrependToBuffer(const std::string& dest,
194  const std::string& content,
195  size_t indent) {
196  files_[dest].content
197  = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
198  }
199 
200  template <typename ThresholdType, typename LeafOutputType>
201  void HandleMainNode(const MainNode* node,
202  const std::string& dest,
203  size_t indent) {
204  const std::string threshold_type
205  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
206  const std::string leaf_output_type
207  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
208  const std::string predict_function_signature
209  = (task_param_.num_class > 1) ?
210  fmt::format("size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
211  leaf_output_type)
212  : fmt::format("{} predict(union Entry* data, int pred_margin)",
213  leaf_output_type);
214 
215  if (!array_is_categorical_.empty()) {
216  array_is_categorical_
217  = fmt::format("const unsigned char is_categorical[] = {{\n{}\n}}",
218  array_is_categorical_);
219  }
220 
221  const std::string query_functions_definition
222  = fmt::format(native::query_functions_definition_template,
223  "num_class"_a = task_param_.num_class,
224  "num_feature"_a = num_feature_,
225  "pred_transform"_a = pred_transform_,
226  "sigmoid_alpha"_a = sigmoid_alpha_,
227  "global_bias"_a = global_bias_,
228  "threshold_type_str"_a = TypeInfoToString(TypeToInfo<ThresholdType>()),
229  "leaf_output_type_str"_a = TypeInfoToString(TypeToInfo<LeafOutputType>()));
230 
231  AppendToBuffer(dest,
232  fmt::format(native::main_start_template,
233  "array_is_categorical"_a = array_is_categorical_,
234  "query_functions_definition"_a = query_functions_definition,
235  "pred_transform_function"_a = pred_tranform_func_,
236  "predict_function_signature"_a = predict_function_signature),
237  indent);
238  const std::string query_functions_prototype
239  = fmt::format(native::query_functions_prototype_template,
240  "dllexport"_a = DLLEXPORT_KEYWORD);
241  AppendToBuffer("header.h",
242  fmt::format(native::header_template,
243  "dllexport"_a = DLLEXPORT_KEYWORD,
244  "predict_function_signature"_a = predict_function_signature,
245  "query_functions_prototype"_a = query_functions_prototype,
246  "threshold_type"_a = threshold_type,
247  "threshold_type_Node"_a = (param_.quantize > 0 ? std::string("int") : threshold_type)),
248  indent);
249 
250  TREELITE_CHECK_EQ(node->children.size(), 1);
251  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
252 
253  std::string optional_average_field;
254  if (node->average_result) {
255  if (task_type_ == TaskType::kMultiClfGrovePerClass) {
256  TREELITE_CHECK(task_param_.grove_per_class);
257  TREELITE_CHECK_EQ(task_param_.leaf_vector_size, 1);
258  TREELITE_CHECK_GT(task_param_.num_class, 1);
259  TREELITE_CHECK_EQ(node->num_tree % task_param_.num_class, 0)
260  << "Expected the number of trees to be divisible by the number of classes";
261  int num_boosting_round = node->num_tree / static_cast<int>(task_param_.num_class);
262  optional_average_field = fmt::format(" / {}", num_boosting_round);
263  } else {
264  TREELITE_CHECK(task_type_ == TaskType::kBinaryClfRegr
265  || task_type_ == TaskType::kMultiClfProbDistLeaf);
266  TREELITE_CHECK_EQ(task_param_.num_class, task_param_.leaf_vector_size);
267  TREELITE_CHECK(!task_param_.grove_per_class);
268  optional_average_field = fmt::format(" / {}", node->num_tree);
269  }
270  }
271  if (task_param_.num_class > 1) {
272  AppendToBuffer(dest,
273  fmt::format(native::main_end_multiclass_template,
274  "num_class"_a = task_param_.num_class,
275  "optional_average_field"_a = optional_average_field,
276  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
277  "leaf_output_type"_a = leaf_output_type),
278  indent);
279  } else {
280  AppendToBuffer(dest,
281  fmt::format(native::main_end_template,
282  "optional_average_field"_a = optional_average_field,
283  "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
284  "leaf_output_type"_a = leaf_output_type),
285  indent);
286  }
287  }
288 
289  template <typename ThresholdType, typename LeafOutputType>
290  void HandleACNode(const AccumulatorContextNode* node,
291  const std::string& dest,
292  size_t indent) {
293  const std::string leaf_output_type
294  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
295  if (task_param_.num_class > 1) {
296  AppendToBuffer(dest,
297  fmt::format("{leaf_output_type} sum[{num_class}] = {{0}};\n"
298  "unsigned int tmp;\n"
299  "int nid, cond, fid; /* used for folded subtrees */\n",
300  "num_class"_a = task_param_.num_class,
301  "leaf_output_type"_a = leaf_output_type), indent);
302  } else {
303  AppendToBuffer(dest,
304  fmt::format("{leaf_output_type} sum = ({leaf_output_type})0;\n"
305  "unsigned int tmp;\n"
306  "int nid, cond, fid; /* used for folded subtrees */\n",
307  "leaf_output_type"_a = leaf_output_type),
308  indent);
309  }
310  for (ASTNode* child : node->children) {
311  WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
312  }
313  }
314 
315  template <typename ThresholdType, typename LeafOutputType>
316  void HandleCondNode(const ConditionNode* node,
317  const std::string& dest,
318  size_t indent) {
320  std::string condition_with_na_check;
321  if ( (t = dynamic_cast<const NumericalConditionNode<ThresholdType>*>(node)) ) {
322  /* numerical split */
323  std::string condition = ExtractNumericalCondition(t);
324  const char* condition_with_na_check_template
325  = (node->default_left) ?
326  "!(data[{split_index}].missing != -1) || ({condition})"
327  : " (data[{split_index}].missing != -1) && ({condition})";
328  condition_with_na_check
329  = fmt::format(condition_with_na_check_template,
330  "split_index"_a = node->split_index,
331  "condition"_a = condition);
332  } else { /* categorical split */
333  const CategoricalConditionNode* t2 = dynamic_cast<const CategoricalConditionNode*>(node);
334  TREELITE_CHECK(t2);
335  condition_with_na_check = ExtractCategoricalCondition(t2);
336  }
337  if (node->children[0]->data_count && node->children[1]->data_count) {
338  const uint64_t left_freq = *node->children[0]->data_count;
339  const uint64_t right_freq = *node->children[1]->data_count;
340  condition_with_na_check
341  = fmt::format(" {keyword}( {condition} ) ",
342  "keyword"_a = ((left_freq > right_freq) ? "LIKELY" : "UNLIKELY"),
343  "condition"_a = condition_with_na_check);
344  }
345  AppendToBuffer(dest,
346  fmt::format("if ({}) {{\n", condition_with_na_check), indent);
347  TREELITE_CHECK_EQ(node->children.size(), 2);
348  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
349  AppendToBuffer(dest, "} else {\n", indent);
350  WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
351  AppendToBuffer(dest, "}\n", indent);
352  }
353 
354  template <typename ThresholdType, typename LeafOutputType>
355  void HandleOutputNode(const OutputNode<LeafOutputType>* node,
356  const std::string& dest,
357  size_t indent) {
358  AppendToBuffer(dest, RenderOutputStatement(node), indent);
359  TREELITE_CHECK_EQ(node->children.size(), 0);
360  }
361 
362  template <typename ThresholdType, typename LeafOutputType>
363  void HandleTUNode(const TranslationUnitNode* node,
364  const std::string& dest,
365  size_t indent) {
366  const int unit_id = node->unit_id;
367  const std::string new_file = fmt::format("tu{}.c", unit_id);
368  const std::string leaf_output_type
369  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
370 
371  std::string unit_function_name, unit_function_signature,
372  unit_function_call_signature;
373  if (task_param_.num_class > 1) {
374  unit_function_name
375  = fmt::format("predict_margin_multiclass_unit{}", unit_id);
376  unit_function_signature
377  = fmt::format("void {function_name}(union Entry* data, {leaf_output_type}* result)",
378  "function_name"_a = unit_function_name,
379  "leaf_output_type"_a = leaf_output_type);
380  unit_function_call_signature
381  = fmt::format("{}(data, sum);\n", unit_function_name);
382  } else {
383  unit_function_name
384  = fmt::format("predict_margin_unit{}", unit_id);
385  unit_function_signature
386  = fmt::format("{leaf_output_type} {function_name}(union Entry* data)",
387  "function_name"_a = unit_function_name,
388  "leaf_output_type"_a = leaf_output_type);
389  unit_function_call_signature
390  = fmt::format("sum += {}(data);\n", unit_function_name);
391  }
392  AppendToBuffer(dest, unit_function_call_signature, indent);
393  AppendToBuffer(new_file,
394  fmt::format("#include \"header.h\"\n"
395  "{} {{\n", unit_function_signature), 0);
396  TREELITE_CHECK_EQ(node->children.size(), 1);
397  WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
398  if (task_param_.num_class > 1) {
399  AppendToBuffer(new_file,
400  fmt::format(" for (int i = 0; i < {num_class}; ++i) {{\n"
401  " result[i] += sum[i];\n"
402  " }}\n"
403  "}}\n",
404  "num_class"_a = task_param_.num_class), 0);
405  } else {
406  AppendToBuffer(new_file, " return sum;\n}\n", 0);
407  }
408  AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0);
409  }
410 
411  template <typename ThresholdType, typename LeafOutputType>
412  void HandleQNode(const QuantizerNode<ThresholdType>* node,
413  const std::string& dest,
414  size_t indent) {
415  const std::string threshold_type
416  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
417  /* render arrays needed to convert feature values into bin indices */
418  std::string array_threshold, array_th_begin, array_th_len;
419  // threshold[] : list of all thresholds that occur at least once in the
420  // ensemble model. For each feature, an ascending list of unique
421  // thresholds is generated. The range th_begin[i]:(th_begin[i]+th_len[i])
422  // of the threshold[] array stores the threshold list for feature i.
423  size_t total_num_threshold;
424  // to hold total number of (distinct) thresholds
425  {
426  common_util::ArrayFormatter formatter(80, 2);
427  for (const auto& e : node->cut_pts) {
428  // cut_pts had been generated in ASTBuilder::QuantizeThresholds
429  // cut_pts[i][k] stores the k-th threshold of feature i.
430  for (auto v : e) {
431  formatter << v;
432  }
433  }
434  array_threshold = formatter.str();
435  }
436  {
437  common_util::ArrayFormatter formatter(80, 2);
438  size_t accum = 0; // used to compute cumulative sum over threshold counts
439  for (const auto& e : node->cut_pts) {
440  formatter << accum;
441  accum += e.size(); // e.size() = number of thresholds for each feature
442  }
443  total_num_threshold = accum;
444  array_th_begin = formatter.str();
445  }
446  {
447  common_util::ArrayFormatter formatter(80, 2);
448  for (const auto& e : node->cut_pts) {
449  formatter << e.size();
450  }
451  array_th_len = formatter.str();
452  }
453  if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
454  PrependToBuffer(dest,
455  fmt::format(native::qnode_template,
456  "total_num_threshold"_a = total_num_threshold,
457  "threshold_type"_a = threshold_type),
458  0);
459  AppendToBuffer(dest,
460  fmt::format(native::quantize_loop_template,
461  "num_feature"_a = num_feature_), indent);
462  }
463  if (!array_threshold.empty()) {
464  PrependToBuffer(dest,
465  fmt::format("static const {threshold_type} threshold[] = {{\n"
466  "{array_threshold}\n"
467  "}};\n",
468  "array_threshold"_a = array_threshold,
469  "threshold_type"_a = threshold_type),
470  0);
471  }
472  if (!array_th_begin.empty()) {
473  PrependToBuffer(dest,
474  fmt::format("static const int th_begin[] = {{\n"
475  "{array_th_begin}\n"
476  "}};\n", "array_th_begin"_a = array_th_begin), 0);
477  }
478  if (!array_th_len.empty()) {
479  PrependToBuffer(dest,
480  fmt::format("static const int th_len[] = {{\n"
481  "{array_th_len}\n"
482  "}};\n", "array_th_len"_a = array_th_len), 0);
483  }
484  TREELITE_CHECK_EQ(node->children.size(), 1);
485  WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
486  }
487 
488  template <typename ThresholdType, typename LeafOutputType>
489  void HandleCodeFolderNode(const CodeFolderNode* node,
490  const std::string& dest,
491  size_t indent) {
492  TREELITE_CHECK_EQ(node->children.size(), 1);
493  const int node_id = node->children[0]->node_id;
494  const int tree_id = node->children[0]->tree_id;
495 
496  /* render arrays needed for folding subtrees */
497  std::string array_nodes, array_cat_bitmap, array_cat_begin;
498  // node_treeXX_nodeXX[] : information of nodes for a particular subtree
499  const std::string node_array_name
500  = fmt::format("node_tree{}_node{}", tree_id, node_id);
501  // cat_bitmap_treeXX_nodeXX[] : list of all 64-bit integer bitmaps, used to
502  // make all categorical splits in a particular
503  // subtree
504  const std::string cat_bitmap_name
505  = fmt::format("cat_bitmap_tree{}_node{}", tree_id, node_id);
506  // cat_begin_treeXX_nodeXX[] : shows which bitmaps belong to each split.
507  // cat_bitmap[ cat_begin[i]:cat_begin[i+1] ]
508  // belongs to the i-th (categorical) split
509  const std::string cat_begin_name
510  = fmt::format("cat_begin_tree{}_node{}", tree_id, node_id);
511 
512  std::string output_switch_statement;
513  Operator common_comp_op;
514  common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param_.quantize,
515  false, "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
516  [this](const OutputNode<LeafOutputType>* node) { return RenderOutputStatement(node); },
517  &array_nodes, &array_cat_bitmap, &array_cat_begin, &output_switch_statement,
518  &common_comp_op);
519  if (!array_nodes.empty()) {
520  AppendToBuffer("header.h",
521  fmt::format("extern const struct Node {node_array_name}[];\n",
522  "node_array_name"_a = node_array_name), 0);
523  AppendToBuffer("arrays.c",
524  fmt::format("const struct Node {node_array_name}[] = {{\n"
525  "{array_nodes}\n"
526  "}};\n",
527  "node_array_name"_a = node_array_name,
528  "array_nodes"_a = array_nodes), 0);
529  }
530 
531  if (!array_cat_bitmap.empty()) {
532  AppendToBuffer("header.h",
533  fmt::format("extern const uint64_t {cat_bitmap_name}[];\n",
534  "cat_bitmap_name"_a = cat_bitmap_name), 0);
535  AppendToBuffer("arrays.c",
536  fmt::format("const uint64_t {cat_bitmap_name}[] = {{\n"
537  "{array_cat_bitmap}\n"
538  "}};\n",
539  "cat_bitmap_name"_a = cat_bitmap_name,
540  "array_cat_bitmap"_a = array_cat_bitmap), 0);
541  }
542 
543  if (!array_cat_begin.empty()) {
544  AppendToBuffer("header.h",
545  fmt::format("extern const size_t {cat_begin_name}[];\n",
546  "cat_begin_name"_a = cat_begin_name), 0);
547  AppendToBuffer("arrays.c",
548  fmt::format("const size_t {cat_begin_name}[] = {{\n"
549  "{array_cat_begin}\n"
550  "}};\n",
551  "cat_begin_name"_a = cat_begin_name,
552  "array_cat_begin"_a = array_cat_begin), 0);
553  }
554 
555  if (array_nodes.empty()) {
556  /* folded code consists of a single leaf node */
557  AppendToBuffer(dest,
558  fmt::format("nid = -1;\n"
559  "{output_switch_statement}\n",
560  "output_switch_statement"_a
561  = output_switch_statement), indent);
562  } else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
563  AppendToBuffer(dest,
564  fmt::format(native::eval_loop_template,
565  "node_array_name"_a = node_array_name,
566  "cat_bitmap_name"_a = cat_bitmap_name,
567  "cat_begin_name"_a = cat_begin_name,
568  "data_field"_a = (param_.quantize > 0 ? "qvalue" : "fvalue"),
569  "comp_op"_a = OpName(common_comp_op),
570  "output_switch_statement"_a
571  = output_switch_statement), indent);
572  } else {
573  AppendToBuffer(dest,
574  fmt::format(native::eval_loop_template_without_categorical_feature,
575  "node_array_name"_a = node_array_name,
576  "data_field"_a = (param_.quantize > 0 ? "qvalue" : "fvalue"),
577  "comp_op"_a = OpName(common_comp_op),
578  "output_switch_statement"_a
579  = output_switch_statement), indent);
580  }
581  }
582 
583  template <typename ThresholdType>
584  inline std::string
585  ExtractNumericalCondition(const NumericalConditionNode<ThresholdType>* node) {
586  const std::string threshold_type
587  = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
588  std::string result;
589  if (node->quantized) { // quantized threshold
590  std::string lhs = fmt::format("data[{split_index}].qvalue",
591  "split_index"_a = node->split_index);
592  result = fmt::format("{lhs} {opname} {threshold}",
593  "lhs"_a = lhs,
594  "opname"_a = OpName(node->op),
595  "threshold"_a = node->threshold.int_val);
596  } else if (std::isinf(node->threshold.float_val)) { // infinite threshold
597  // According to IEEE 754, the result of comparison [lhs] < infinity
598  // must be identical for all finite [lhs]. Same goes for operator >.
599  result = (CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
600  ? "1" : "0");
601  } else { // finite threshold
602  std::string lhs = fmt::format("data[{split_index}].fvalue",
603  "split_index"_a = node->split_index);
604  result
605  = fmt::format("{lhs} {opname} ({threshold_type}){threshold}",
606  "lhs"_a = lhs,
607  "opname"_a = OpName(node->op),
608  "threshold_type"_a = threshold_type,
609  "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
610  }
611  return result;
612  }
613 
614  inline std::string
615  ExtractCategoricalCondition(const CategoricalConditionNode* node) {
616  std::string result;
617  std::vector<uint64_t> bitmap
618  = common_util::GetCategoricalBitmap(node->matching_categories);
619  TREELITE_CHECK_GE(bitmap.size(), 1);
620  bool all_zeros = true;
621  for (uint64_t e : bitmap) {
622  all_zeros &= (e == 0);
623  }
624  if (all_zeros) {
625  result = "0";
626  } else {
627  std::ostringstream oss;
628  const std::string right_categories_flag = (node->categories_list_right_child ? "!" : "");
629  if (node->default_left) {
630  oss << fmt::format(
631  "data[{split_index}].missing == -1 || {right_categories_flag}("
632  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
633  "split_index"_a = node->split_index,
634  "right_categories_flag"_a = right_categories_flag);
635  } else {
636  oss << fmt::format(
637  "data[{split_index}].missing != -1 && {right_categories_flag}("
638  "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
639  "split_index"_a = node->split_index,
640  "right_categories_flag"_a = right_categories_flag);
641  }
642  oss << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
643  << bitmap[0] << "U >> tmp) & 1) )";
644  for (size_t i = 1; i < bitmap.size(); ++i) {
645  oss << " || (tmp >= " << (i * 64)
646  << " && tmp < " << ((i + 1) * 64)
647  << " && (( (uint64_t)" << bitmap[i]
648  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
649  }
650  oss << ")";
651  result = oss.str();
652  }
653  return result;
654  }
655 
656  inline std::string
657  RenderIsCategoricalArray(const std::vector<bool>& is_categorical) {
658  common_util::ArrayFormatter formatter(80, 2);
659  for (int fid = 0; fid < num_feature_; ++fid) {
660  formatter << (is_categorical[fid] ? 1 : 0);
661  }
662  return formatter.str();
663  }
664 
665  template <typename LeafOutputType>
666  inline std::string RenderOutputStatement(const OutputNode<LeafOutputType>* node) {
667  const std::string leaf_output_type
668  = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
669  std::string output_statement;
670  if (task_param_.num_class > 1) {
671  if (node->is_vector) {
672  // multi-class classification with random forest
673  TREELITE_CHECK_EQ(node->vector.size(), static_cast<size_t>(task_param_.num_class))
674  << "Ill-formed model: leaf vector must be of length [num_class]";
675  for (size_t group_id = 0; group_id < task_param_.num_class; ++group_id) {
676  output_statement
677  += fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
678  "group_id"_a = group_id,
679  "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
680  "leaf_output_type"_a = leaf_output_type);
681  }
682  } else {
683  // multi-class classification with gradient boosted trees
684  output_statement
685  = fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n",
686  "group_id"_a = node->tree_id % task_param_.num_class,
687  "output"_a = common_util::ToStringHighPrecision(node->scalar),
688  "leaf_output_type"_a = leaf_output_type);
689  }
690  } else {
691  output_statement
692  = fmt::format("sum += ({leaf_output_type}){output};\n",
693  "output"_a = common_util::ToStringHighPrecision(node->scalar),
694  "leaf_output_type"_a = leaf_output_type);
695  }
696  return output_statement;
697  }
698 };
699 
700 ASTNativeCompiler::ASTNativeCompiler(const CompilerParam& param)
701  : pimpl_(std::make_unique<ASTNativeCompilerImpl>(param)) {
702  if (param.verbose > 0) {
703  TREELITE_LOG(INFO) << "Using ASTNativeCompiler";
704  }
705  if (param.dump_array_as_elf > 0) {
706  TREELITE_LOG(INFO) << "Warning: 'dump_array_as_elf' parameter is not applicable "
707  "for ASTNativeCompiler";
708  }
709 }
710 
// Defined out of line in this file, where ASTNativeCompilerImpl is a complete
// type. Presumably this is required because pimpl_ is a smart pointer to the
// implementation; its declaration is in the header.
711 ASTNativeCompiler::~ASTNativeCompiler() = default;
712 
714 ASTNativeCompiler::Compile(const Model &model) {
// Delegate compilation to the implementation object.
715  return pimpl_->Compile(model);
716 }
717 
719 ASTNativeCompiler::QueryParam() const {
// Delegate to the implementation, which returns a copy of its CompilerParam.
720  return pimpl_->QueryParam();
721 }
722 
723 } // namespace compiler
724 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:702
Parameters for tree compiler.
branch annotator class
Definition: annotator.h:21
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:183
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:172
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:43
TaskType
Enum type representing the task type.
Definition: tree.h:98
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
Definition: base.h:77
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:198
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:191
template for header
TaskType task_type
Task type.
Definition: tree.h:696
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< std::vector< uint64_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:55
template for evaluation logic for folded code
code template for QuantizerNode
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
Definition: compiler.h:26
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
AST Builder class.
std::string str()
obtain formatted text containing the rendered array
Definition: format_util.h:99
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:700
C code generator.
Branch annotation tools.
thin wrapper for tree ensemble model
Definition: tree.h:647
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
void Load(std::istream &fi)
load branch annotation from a JSON file
Definition: annotator.cc:238
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:694
Operator
comparison operators
Definition: base.h:26