treelite
ast_native.cc
#include <treelite/compiler.h>
#include <treelite/common.h>
#include <treelite/annotator.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "./param.h"
#include "./pred_transform.h"
#include "./ast/builder.h"
#include "./native/get_num_feature.h"
#include "./native/get_num_output_group.h"
11 
12 namespace treelite {
13 namespace compiler {
14 
15 DMLC_REGISTRY_FILE_TAG(ast_native);
16 
17 class ASTNativeCompiler : public Compiler {
18  public:
19  explicit ASTNativeCompiler(const CompilerParam& param)
20  : param(param) {
21  if (param.verbose > 0) {
22  LOG(INFO) << "Using ASTNativeCompiler";
23  }
24  }
25 
26  CompiledModel Compile(const Model& model) override {
27  CompiledModel cm;
28  cm.backend = "native";
29  cm.files["main.c"] = "";
30 
31  num_output_group_ = model.num_output_group;
32  pred_tranform_func_ = PredTransformFunction("native", model);
33  files_.clear();
34 
35  ASTBuilder builder;
36  builder.Build(model);
37  if (param.annotate_in != "NULL") {
38  BranchAnnotator annotator;
39  std::unique_ptr<dmlc::Stream> fi(
40  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
41  annotator.Load(fi.get());
42  const auto annotation = annotator.Get();
43  builder.AnnotateBranches(annotation);
44  LOG(INFO) << "Using branch annotation file `"
45  << param.annotate_in << "'";
46  }
47  builder.Split(param.parallel_comp);
48  if (param.quantize > 0) {
49  builder.QuantizeThresholds();
50  }
51  #include "./native/header.h"
52  files_["header.h"] = header;
53  WalkAST(builder.GetRootNode(), "main.c", 0);
54 
55  {
56  /* write recipe.json */
57  std::vector<std::unordered_map<std::string, std::string>> source_list;
58  for (auto kv : files_) {
59  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
60  const size_t line_count
61  = std::count(kv.second.begin(), kv.second.end(), '\n');
62  source_list.push_back({ {"name",
63  kv.first.substr(0, kv.first.length() - 2)},
64  {"length", std::to_string(line_count)} });
65  }
66  }
67  std::ostringstream oss;
68  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
69  writer->BeginObject();
70  writer->WriteObjectKeyValue("target", std::string("predictor"));
71  writer->WriteObjectKeyValue("sources", source_list);
72  writer->EndObject();
73  files_["recipe.json"] = oss.str();
74  }
75  cm.files = std::move(files_);
76  return cm;
77  }
78  private:
79  CompilerParam param;
80  int num_output_group_;
81  std::string pred_tranform_func_;
82  std::unordered_map<std::string, std::string> files_;
83 
84  void WalkAST(const ASTNode* node,
85  const std::string& dest,
86  int indent) {
87  const MainNode* t1;
88  const AccumulatorContextNode* t2;
89  const ConditionNode* t3;
90  const OutputNode* t4;
91  const TranslationUnitNode* t5;
92  const QuantizerNode* t6;
93  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
94  HandleMainNode(t1, dest, indent);
95  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
96  HandleACNode(t2, dest, indent);
97  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
98  HandleCondNode(t3, dest, indent);
99  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
100  HandleOutputNode(t4, dest, indent);
101  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
102  HandleTUNode(t5, dest, indent);
103  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
104  HandleQNode(t6, dest, indent);
105  } else {
106  LOG(FATAL) << "oops";
107  }
108  }
109 
110  void HandleMainNode(const MainNode* node,
111  const std::string& dest,
112  int indent) {
113  const std::string prototype
114  = (num_output_group_ > 1) ?
115  "size_t predict_multiclass(union Entry* data, int pred_margin, "
116  "float* result)"
117  : "float predict(union Entry* data, int pred_margin)";
118  files_[dest] += std::string(indent, ' ') + "#include \"header.h\"\n\n";
119  files_[dest] += get_num_output_group_func(num_output_group_) + "\n"
120  + get_num_feature_func(node->num_feature) + "\n"
121  + pred_tranform_func_ + "\n"
122  + std::string(indent, ' ') + prototype + " {\n";
123  files_["header.h"] += get_num_output_group_func_prototype();
124  files_["header.h"] += get_num_feature_func_prototype();
125  files_["header.h"] += prototype + ";\n";
126  CHECK_EQ(node->children.size(), 1);
127  WalkAST(node->children[0], dest, indent + 2);
128  std::ostringstream oss;
129  if (num_output_group_ > 1) {
130  oss << std::string(indent + 2, ' ')
131  << "for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
132  << std::string(indent + 4, ' ') << "result[i] = sum[i]";
133  if (node->average_result) {
134  oss << " / " << node->num_tree;
135  }
136  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
137  << std::string(indent + 2, ' ') << "}\n"
138  << std::string(indent + 2, ' ') << "if (!pred_margin) {\n"
139  << std::string(indent + 2, ' ')
140  << " return pred_transform(result);\n"
141  << std::string(indent + 2, ' ') << "} else {\n"
142  << std::string(indent + 2, ' ')
143  << " return " << num_output_group_ << ";\n"
144  << std::string(indent + 2, ' ') << "}\n";
145  } else {
146  oss << std::string(indent + 2, ' ') << "sum = sum";
147  if (node->average_result) {
148  oss << " / " << node->num_tree;
149  }
150  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
151  << std::string(indent + 2, ' ') << "if (!pred_margin) {\n"
152  << std::string(indent + 2, ' ') << " return pred_transform(sum);\n"
153  << std::string(indent + 2, ' ') << "} else {\n"
154  << std::string(indent + 2, ' ') << " return sum;\n"
155  << std::string(indent + 2, ' ') << "}\n";
156  }
157  oss << std::string(indent, ' ') << "}\n";
158  files_[dest] += oss.str();
159  }
160 
161  void HandleACNode(const AccumulatorContextNode* node,
162  const std::string& dest,
163  int indent) {
164  std::ostringstream oss;
165  if (num_output_group_ > 1) {
166  oss << std::string(indent, ' ')
167  << "float sum[" << num_output_group_ << "] = {0.0f};\n";
168  } else {
169  oss << std::string(indent, ' ') << "float sum = 0.0f;\n";
170  }
171  oss << std::string(indent, ' ') << "unsigned int tmp;\n";
172  files_[dest] += oss.str();
173  for (ASTNode* child : node->children) {
174  WalkAST(child, dest, indent);
175  }
176  }
177 
178  void HandleCondNode(const ConditionNode* node,
179  const std::string& dest,
180  int indent) {
181  const unsigned split_index = node->split_index;
182  const std::string na_check
183  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
184 
185  const NumericalConditionNode* t;
186  std::ostringstream oss; // prepare logical statement for evaluating split
187  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
188  if (t->quantized) { // quantized threshold
189  oss << "data[" << split_index << "].qvalue "
190  << OpName(t->op) << " " << t->threshold.int_val;
191  } else if (std::isinf(t->threshold.float_val)) { // infinite threshold
192  // According to IEEE 754, the result of comparison [lhs] < infinity
193  // must be identical for all finite [lhs]. Same goes for operator >.
194  oss << (common::CompareWithOp(0.0, t->op, t->threshold.float_val)
195  ? "1" : "0");
196  } else { // finite threshold
197  // to restore default precision
198  const std::streamsize ss = std::cout.precision();
199  oss << "data[" << split_index << "].fvalue "
200  << OpName(t->op) << " "
201  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
202  << t->threshold.float_val << std::setprecision(ss);
203  }
204  } else { // categorical split
205  const CategoricalConditionNode* t2
206  = dynamic_cast<const CategoricalConditionNode*>(node);
207  CHECK(t2);
208  std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
209  CHECK_GE(bitmap.size(), 1);
210  bool all_zeros = true;
211  for (uint64_t e : bitmap) {
212  all_zeros &= (e == 0);
213  }
214  if (all_zeros) {
215  oss << "0";
216  } else {
217  oss << "(tmp = (unsigned int)(data[" << split_index << "].fvalue) ), "
218  << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
219  << bitmap[0] << "U >> tmp) & 1) )";
220  for (size_t i = 1; i < bitmap.size(); ++i) {
221  oss << " || (tmp >= " << (i * 64)
222  << " && tmp < " << ((i + 1) * 64)
223  << " && (( (uint64_t)" << bitmap[i]
224  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
225  }
226  }
227  }
228  std::string cond
229  = ((node->default_left) ? (std::string("!(") + na_check + ") || (")
230  : (std::string(" (") + na_check + ") && ("))
231  + oss.str() + ")";
232  switch (node->branch_hint) {
233  case BranchHint::kLikely:
234  cond = std::string(" LIKELY( ") + cond + " ) ";
235  break;
236  case BranchHint::kUnlikely:
237  cond = std::string(" UNLIKELY( ") + cond + " ) ";
238  break;
239  }
240  files_[dest] += std::string(indent, ' ') + "if (" + cond + ") {\n";
241  CHECK_EQ(node->children.size(), 2);
242  WalkAST(node->children[0], dest, indent + 2);
243  files_[dest] += std::string(indent, ' ') + "} else {\n";
244  WalkAST(node->children[1], dest, indent + 2);
245  files_[dest] += std::string(indent, ' ') + "}\n";
246  }
247 
248  void HandleOutputNode(const OutputNode* node,
249  const std::string& dest,
250  int indent) {
251  std::ostringstream oss;
252  if (num_output_group_ > 1) {
253  if (node->is_vector) {
254  // multi-class classification with random forest
255  const std::vector<tl_float>& leaf_vector = node->vector;
256  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group_))
257  << "Ill-formed model: leaf vector must be of length [num_output_group]";
258  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
259  oss << std::string(indent, ' ')
260  << "sum[" << group_id << "] += (float)"
261  << common::ToString(leaf_vector[group_id]) << ";\n";
262  }
263  } else {
264  // multi-class classification with gradient boosted trees
265  oss << std::string(indent, ' ') << "sum["
266  << (node->tree_id % num_output_group_) << "] += (float)"
267  << common::ToString(node->scalar) << ";\n";
268  }
269  } else {
270  oss << std::string(indent, ' ') << "sum += (float)"
271  << common::ToString(node->scalar) << ";\n";
272  }
273  files_[dest] += oss.str();
274  CHECK_EQ(node->children.size(), 0);
275  }
276 
277  void HandleTUNode(const TranslationUnitNode* node,
278  const std::string& dest,
279  int indent) {
280  const std::string new_file
281  = std::string("tu") + std::to_string(node->unit_id) + ".c";
282  std::ostringstream caller_buf, callee_buf, func_name, prototype;
283  callee_buf << "#include \"header.h\"\n";
284  if (num_output_group_ > 1) {
285  func_name << "predict_margin_multiclass_unit" << node->unit_id;
286  caller_buf << std::string(indent, ' ')
287  << func_name.str() << "(data, sum);\n";
288  prototype << "void " << func_name.str()
289  << "(union Entry* data, float* result)";
290  } else {
291  func_name << "predict_margin_unit" << node->unit_id;
292  caller_buf << std::string(indent, ' ')
293  << "sum += " << func_name.str() << "(data);\n";
294  prototype << "float " << func_name.str() << "(union Entry* data)";
295  }
296  callee_buf << prototype.str() << " {\n";
297  files_[dest] += caller_buf.str();
298  files_[new_file] += callee_buf.str();
299  CHECK_EQ(node->children.size(), 1);
300  WalkAST(node->children[0], new_file, 2);
301  callee_buf.str(""); callee_buf.clear();
302  if (num_output_group_ > 1) {
303  callee_buf << " for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
304  << " result[i] = sum[i];\n"
305  << " }\n";
306  } else {
307  callee_buf << " return sum;\n";
308  }
309  callee_buf << "}\n";
310  files_[new_file] += callee_buf.str();
311  files_["header.h"] += prototype.str() + ";\n";
312  }
313 
314  void HandleQNode(const QuantizerNode* node,
315  const std::string& dest,
316  int indent) {
317  std::ostringstream oss; // prepare a preamble
318  const int num_feature = node->is_categorical.size();
319  size_t length = 2;
320  oss << "static const unsigned char is_categorical[] = {\n ";
321  for (int fid = 0; fid < num_feature; ++fid) {
322  if (node->is_categorical[fid]) {
323  common::WrapText(&oss, &length, "1, ", 2, 80);
324  } else {
325  common::WrapText(&oss, &length, "0, ", 2, 80);
326  }
327  }
328  oss << "\n};\n";
329  length = 2;
330  oss << "static const float threshold[] = {\n ";
331  for (const auto& e : node->cut_pts) {
332  for (tl_float v : e) {
333  common::WrapText(&oss, &length, common::ToString(v) + ", ", 2, 80);
334  }
335  }
336  oss << "\n};\n";
337  length = 2;
338  size_t accum = 0;
339  oss << "static const int th_begin[] = {\n ";
340  for (const auto& e : node->cut_pts) {
341  common::WrapText(&oss, &length, std::to_string(accum) + ", ", 2, 80);
342  accum += e.size();
343  }
344  oss << "\n};\n";
345  length = 2;
346  oss << "static const int th_len[] = {\n ";
347  for (const auto& e : node->cut_pts) {
348  common::WrapText(&oss, &length, std::to_string(e.size()) + ", ", 2, 80);
349  }
350  oss << "\n};\n";
351  #include "./native/quantize_func.h"
352  oss << quantize_func << files_[dest] << std::string(indent, ' ')
353  << "for (int i = 0; i < " << num_feature << "; ++i) {\n"
354  << std::string(indent + 2, ' ')
355  << "if (data[i].missing != -1 && !is_categorical[i]) {\n"
356  << std::string(indent + 4, ' ')
357  << "data[i].qvalue = quantize(data[i].fvalue, i);\n"
358  << std::string(indent + 2, ' ') + "}\n"
359  << std::string(indent, ' ') + "}\n";
360  files_[dest] = oss.str();
361  CHECK_EQ(node->children.size(), 1);
362  WalkAST(node->children[0], dest, indent);
363  }
364 
// Pack a list of category indices into a bitmap made of 64-bit words.
//
// Bit (c % 64) of word (c / 64) is set for every category c in
// left_categories. The result always has at least one word; an empty list
// yields a single zero word.
//
// The largest category is searched for explicitly instead of being taken
// from the last element, so the list does not have to be sorted, and the
// word count is computed in size_t so that it cannot wrap around for very
// large category values. The function reads no instance state, hence static.
static inline std::vector<uint64_t>
to_bitmap(const std::vector<uint32_t>& left_categories) {
  if (left_categories.empty()) {
    return std::vector<uint64_t>{0};
  }
  const uint32_t max_left_category
    = *std::max_element(left_categories.begin(), left_categories.end());
  const size_t num_words = static_cast<size_t>(max_left_category) / 64 + 1;
  std::vector<uint64_t> bitmap(num_words, 0);
  for (const uint32_t cat : left_categories) {
    const size_t idx = cat / 64;
    const uint32_t offset = cat % 64;
    bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
  }
  return bitmap;
}
381 };
382 
384 .describe("AST-based compiler that produces C code")
385 .set_body([](const CompilerParam& param) -> Compiler* {
386  return new ASTNativeCompiler(param);
387  });
388 } // namespace compiler
389 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
branch annotator class
Definition: annotator.h:16
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
parameters for tree compiler
Definition: param.h:16
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:32
Interface of compiler that compiles a tree ensemble model.
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:70
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:26
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
Branch annotation tools.
Some useful utilities.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
int verbose
if >0, produce extra messages
Definition: param.h:34