treelite
ast_native.cc
#include <treelite/compiler.h>
#include <treelite/annotator.h>
#include <algorithm>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "./param.h"
#include "./pred_transform.h"
#include "./ast/builder.h"
#include "./native/get_num_feature.h"
#include "./native/get_num_output_group.h"
9 
10 namespace treelite {
11 namespace compiler {
12 
13 DMLC_REGISTRY_FILE_TAG(ast_native);
14 
15 class ASTNativeCompiler : public Compiler {
16  public:
17  explicit ASTNativeCompiler(const CompilerParam& param)
18  : param(param) {
19  if (param.verbose > 0) {
20  LOG(INFO) << "Using ASTNativeCompiler";
21  }
22  }
23 
24  CompiledModel Compile(const Model& model) override {
25  CompiledModel cm;
26  cm.backend = "native";
27  cm.files["main.c"] = "";
28 
29  num_output_group_ = model.num_output_group;
30  pred_tranform_func_ = PredTransformFunction("native", model);
31  files_.clear();
32 
33  ASTBuilder builder;
34  builder.Build(model);
35  if (param.annotate_in != "NULL") {
36  BranchAnnotator annotator;
37  std::unique_ptr<dmlc::Stream> fi(
38  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
39  annotator.Load(fi.get());
40  const auto annotation = annotator.Get();
41  builder.AnnotateBranches(annotation);
42  LOG(INFO) << "Using branch annotation file `"
43  << param.annotate_in << "'";
44  }
45  builder.Split(param.parallel_comp);
46  if (param.quantize > 0) {
47  builder.QuantizeThresholds();
48  }
49  #include "./native/header.h"
50  files_["header.h"] = header;
51  WalkAST(builder.GetRootNode(), "main.c", 0);
52 
53  {
54  /* write recipe.json */
55  std::vector<std::unordered_map<std::string, std::string>> source_list;
56  for (auto kv : files_) {
57  if (kv.first.compare(kv.first.length() - 2, 2, ".c") == 0) {
58  const size_t line_count
59  = std::count(kv.second.begin(), kv.second.end(), '\n');
60  source_list.push_back({ {"name",
61  kv.first.substr(0, kv.first.length() - 2)},
62  {"length", std::to_string(line_count)} });
63  }
64  }
65  std::ostringstream oss;
66  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
67  writer->BeginObject();
68  writer->WriteObjectKeyValue("target", std::string("predictor"));
69  writer->WriteObjectKeyValue("sources", source_list);
70  writer->EndObject();
71  files_["recipe.json"] = oss.str();
72  }
73  cm.files = std::move(files_);
74  return cm;
75  }
76  private:
77  CompilerParam param;
78  int num_output_group_;
79  std::string pred_tranform_func_;
80  std::unordered_map<std::string, std::string> files_;
81 
82  void WalkAST(const ASTNode* node,
83  const std::string& dest,
84  int indent) {
85  const MainNode* t1;
86  const AccumulatorContextNode* t2;
87  const ConditionNode* t3;
88  const OutputNode* t4;
89  const TranslationUnitNode* t5;
90  const QuantizerNode* t6;
91  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
92  HandleMainNode(t1, dest, indent);
93  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
94  HandleACNode(t2, dest, indent);
95  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
96  HandleCondNode(t3, dest, indent);
97  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
98  HandleOutputNode(t4, dest, indent);
99  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
100  HandleTUNode(t5, dest, indent);
101  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
102  HandleQNode(t6, dest, indent);
103  } else {
104  LOG(FATAL) << "oops";
105  }
106  }
107 
108  void HandleMainNode(const MainNode* node,
109  const std::string& dest,
110  int indent) {
111  const std::string prototype
112  = (num_output_group_ > 1) ?
113  "size_t predict_multiclass(union Entry* data, int pred_margin, "
114  "float* result)"
115  : "float predict(union Entry* data, int pred_margin)";
116  files_[dest] += std::string(indent, ' ') + "#include \"header.h\"\n\n";
117  files_[dest] += get_num_output_group_func(num_output_group_) + "\n"
118  + get_num_feature_func(node->num_feature) + "\n"
119  + pred_tranform_func_ + "\n"
120  + std::string(indent, ' ') + prototype + " {\n";
121  files_["header.h"] += get_num_output_group_func_prototype();
122  files_["header.h"] += get_num_feature_func_prototype();
123  files_["header.h"] += prototype + ";\n";
124  CHECK_EQ(node->children.size(), 1);
125  WalkAST(node->children[0], dest, indent + 2);
126  std::ostringstream oss;
127  if (num_output_group_ > 1) {
128  oss << std::string(indent + 2, ' ')
129  << "for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
130  << std::string(indent + 4, ' ') << "result[i] = sum[i]";
131  if (node->average_result) {
132  oss << " / " << node->num_tree;
133  }
134  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
135  << std::string(indent + 2, ' ') << "}\n"
136  << std::string(indent + 2, ' ') << "if (!pred_margin) {\n"
137  << std::string(indent + 2, ' ')
138  << " return pred_transform(result);\n"
139  << std::string(indent + 2, ' ') << "} else {\n"
140  << std::string(indent + 2, ' ')
141  << " return " << num_output_group_ << ";\n"
142  << std::string(indent + 2, ' ') << "}\n";
143  } else {
144  oss << std::string(indent + 2, ' ') << "sum = sum";
145  if (node->average_result) {
146  oss << " / " << node->num_tree;
147  }
148  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
149  << std::string(indent + 2, ' ') << "if (!pred_margin) {\n"
150  << std::string(indent + 2, ' ') << " return pred_transform(sum);\n"
151  << std::string(indent + 2, ' ') << "} else {\n"
152  << std::string(indent + 2, ' ') << " return sum;\n"
153  << std::string(indent + 2, ' ') << "}\n";
154  }
155  oss << std::string(indent, ' ') << "}\n";
156  files_[dest] += oss.str();
157  }
158 
159  void HandleACNode(const AccumulatorContextNode* node,
160  const std::string& dest,
161  int indent) {
162  std::ostringstream oss;
163  if (num_output_group_ > 1) {
164  oss << std::string(indent, ' ')
165  << "float sum[" << num_output_group_ << "] = {0.0f};\n";
166  } else {
167  oss << std::string(indent, ' ') << "float sum = 0.0f;\n";
168  }
169  oss << std::string(indent, ' ') << "unsigned int tmp;\n";
170  files_[dest] += oss.str();
171  for (ASTNode* child : node->children) {
172  WalkAST(child, dest, indent);
173  }
174  }
175 
176  void HandleCondNode(const ConditionNode* node,
177  const std::string& dest,
178  int indent) {
179  const unsigned split_index = node->split_index;
180  const std::string na_check
181  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
182 
183  const NumericalConditionNode* t;
184  std::ostringstream oss; // prepare logical statement for evaluating split
185  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
186  if (t->quantized) {
187  oss << "data[" << split_index << "].qvalue "
188  << OpName(t->op) << " " << t->threshold.int_val;
189  } else {
190  // to restore default precision
191  const std::streamsize ss = std::cout.precision();
192  oss << "data[" << split_index << "].fvalue "
193  << OpName(t->op) << " "
194  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
195  << t->threshold.float_val << std::setprecision(ss);
196  }
197  } else { // categorical split
198  const CategoricalConditionNode* t2
199  = dynamic_cast<const CategoricalConditionNode*>(node);
200  CHECK(t2);
201  std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
202  CHECK_GE(bitmap.size(), 1);
203  bool all_zeros = true;
204  for (uint64_t e : bitmap) {
205  all_zeros &= (e == 0);
206  }
207  if (all_zeros) {
208  oss << "0";
209  } else {
210  oss << "(tmp = (unsigned int)(data[" << split_index << "].fvalue) ), "
211  << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
212  << bitmap[0] << "U >> tmp) & 1) )";
213  for (size_t i = 1; i < bitmap.size(); ++i) {
214  oss << " || (tmp >= " << (i * 64)
215  << " && tmp < " << ((i + 1) * 64)
216  << " && (( (uint64_t)" << bitmap[i]
217  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
218  }
219  }
220  }
221  std::string cond
222  = ((node->default_left) ? (std::string("!(") + na_check + ") || (")
223  : (std::string(" (") + na_check + ") && ("))
224  + oss.str() + ")";
225  switch (node->branch_hint) {
226  case BranchHint::kLikely:
227  cond = std::string(" LIKELY( ") + cond + " ) ";
228  break;
229  case BranchHint::kUnlikely:
230  cond = std::string(" UNLIKELY( ") + cond + " ) ";
231  break;
232  }
233  files_[dest] += std::string(indent, ' ') + "if (" + cond + ") {\n";
234  CHECK_EQ(node->children.size(), 2);
235  WalkAST(node->children[0], dest, indent + 2);
236  files_[dest] += std::string(indent, ' ') + "} else {\n";
237  WalkAST(node->children[1], dest, indent + 2);
238  files_[dest] += std::string(indent, ' ') + "}\n";
239  }
240 
241  void HandleOutputNode(const OutputNode* node,
242  const std::string& dest,
243  int indent) {
244  std::ostringstream oss;
245  if (num_output_group_ > 1) {
246  if (node->is_vector) {
247  // multi-class classification with random forest
248  const std::vector<tl_float>& leaf_vector = node->vector;
249  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group_))
250  << "Ill-formed model: leaf vector must be of length [num_output_group]";
251  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
252  oss << std::string(indent, ' ')
253  << "sum[" << group_id << "] += (float)"
254  << common::ToString(leaf_vector[group_id]) << ";\n";
255  }
256  } else {
257  // multi-class classification with gradient boosted trees
258  oss << std::string(indent, ' ') << "sum["
259  << (node->tree_id % num_output_group_) << "] += (float)"
260  << common::ToString(node->scalar) << ";\n";
261  }
262  } else {
263  oss << std::string(indent, ' ') << "sum += (float)"
264  << common::ToString(node->scalar) << ";\n";
265  }
266  files_[dest] += oss.str();
267  CHECK_EQ(node->children.size(), 0);
268  }
269 
270  void HandleTUNode(const TranslationUnitNode* node,
271  const std::string& dest,
272  int indent) {
273  const std::string new_file
274  = std::string("tu") + std::to_string(node->unit_id) + ".c";
275  std::ostringstream caller_buf, callee_buf, func_name, prototype;
276  callee_buf << "#include \"header.h\"\n";
277  if (num_output_group_ > 1) {
278  func_name << "predict_margin_multiclass_unit" << node->unit_id;
279  caller_buf << std::string(indent, ' ')
280  << func_name.str() << "(data, sum);\n";
281  prototype << "void " << func_name.str()
282  << "(union Entry* data, float* result)";
283  } else {
284  func_name << "predict_margin_unit" << node->unit_id;
285  caller_buf << std::string(indent, ' ')
286  << "sum += " << func_name.str() << "(data);\n";
287  prototype << "float " << func_name.str() << "(union Entry* data)";
288  }
289  callee_buf << prototype.str() << " {\n";
290  files_[dest] += caller_buf.str();
291  files_[new_file] += callee_buf.str();
292  CHECK_EQ(node->children.size(), 1);
293  WalkAST(node->children[0], new_file, 2);
294  callee_buf.str(""); callee_buf.clear();
295  if (num_output_group_ > 1) {
296  callee_buf << " for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
297  << " result[i] = sum[i];\n"
298  << " }\n";
299  } else {
300  callee_buf << " return sum;\n";
301  }
302  callee_buf << "}\n";
303  files_[new_file] += callee_buf.str();
304  files_["header.h"] += prototype.str() + ";\n";
305  }
306 
307  void HandleQNode(const QuantizerNode* node,
308  const std::string& dest,
309  int indent) {
310  std::ostringstream oss; // prepare a preamble
311  const int num_feature = node->is_categorical.size();
312  size_t length = 2;
313  oss << "static const unsigned char is_categorical[] = {\n ";
314  for (int fid = 0; fid < num_feature; ++fid) {
315  if (node->is_categorical[fid]) {
316  common::WrapText(&oss, &length, "1, ", 2, 80);
317  } else {
318  common::WrapText(&oss, &length, "0, ", 2, 80);
319  }
320  }
321  oss << "\n};\n";
322  length = 2;
323  oss << "static const float threshold[] = {\n ";
324  for (const auto& e : node->cut_pts) {
325  for (tl_float v : e) {
326  common::WrapText(&oss, &length, common::ToString(v) + ", ", 2, 80);
327  }
328  }
329  oss << "\n};\n";
330  length = 2;
331  size_t accum = 0;
332  oss << "static const int th_begin[] = {\n ";
333  for (const auto& e : node->cut_pts) {
334  common::WrapText(&oss, &length, std::to_string(accum) + ", ", 2, 80);
335  accum += e.size();
336  }
337  oss << "\n};\n";
338  length = 2;
339  oss << "static const int th_len[] = {\n ";
340  for (const auto& e : node->cut_pts) {
341  common::WrapText(&oss, &length, std::to_string(e.size()) + ", ", 2, 80);
342  }
343  oss << "\n};\n";
344  #include "./native/quantize_func.h"
345  oss << quantize_func << files_[dest] << std::string(indent, ' ')
346  << "for (int i = 0; i < " << num_feature << "; ++i) {\n"
347  << std::string(indent + 2, ' ')
348  << "if (data[i].missing != -1 && !is_categorical[i]) {\n"
349  << std::string(indent + 4, ' ')
350  << "data[i].qvalue = quantize(data[i].fvalue, i);\n"
351  << std::string(indent + 2, ' ') + "}\n"
352  << std::string(indent, ' ') + "}\n";
353  files_[dest] = oss.str();
354  CHECK_EQ(node->children.size(), 1);
355  WalkAST(node->children[0], dest, indent);
356  }
357 
358  inline std::vector<uint64_t>
359  to_bitmap(const std::vector<uint32_t>& left_categories) const {
360  const size_t num_left_categories = left_categories.size();
361  const uint32_t max_left_category = left_categories[num_left_categories - 1];
362  std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
363  for (size_t i = 0; i < left_categories.size(); ++i) {
364  const uint32_t cat = left_categories[i];
365  const size_t idx = cat / 64;
366  const uint32_t offset = cat % 64;
367  bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
368  }
369  return bitmap;
370  }
371 };
372 
374 .describe("AST-based compiler that produces C code")
375 .set_body([](const CompilerParam& param) -> Compiler* {
376  return new ASTNativeCompiler(param);
377  });
378 } // namespace compiler
379 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
branch annotator class
Definition: annotator.h:16
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
parameters for tree compiler
Definition: param.h:16
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:32
Interface of compiler that compiles a tree ensemble model.
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:157
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:70
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_native.cc:24
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
Branch annotation tools.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
int verbose
if >0, produce extra messages
Definition: param.h:34