treelite
recursive.cc
Go to the documentation of this file.
1 
8 #include <treelite/common.h>
9 #include <treelite/annotator.h>
10 #include <treelite/compiler.h>
11 #include <treelite/tree.h>
12 #include <treelite/semantic.h>
13 #include <dmlc/registry.h>
14 #include <queue>
15 #include <algorithm>
16 #include <iterator>
17 #include <iomanip>
18 #include <cmath>
19 #include "param.h"
20 #include "pred_transform.h"
21 
22 namespace {
23 
24 class NumericSplitCondition : public treelite::semantic::Condition {
25  public:
26  using NumericAdapter
27  = std::function<std::string(treelite::Operator, unsigned,
29  explicit NumericSplitCondition(const treelite::Tree::Node& node,
30  const NumericAdapter& numeric_adapter)
31  : split_index(node.split_index()), default_left(node.default_left()),
32  op(node.comparison_op()), threshold(node.threshold()),
33  numeric_adapter(numeric_adapter) {}
34  explicit NumericSplitCondition(const treelite::Tree::Node& node,
35  NumericAdapter&& numeric_adapter)
36  : split_index(node.split_index()), default_left(node.default_left()),
37  op(node.comparison_op()), threshold(node.threshold()),
38  numeric_adapter(std::move(numeric_adapter)) {}
39  CLONEABLE_BOILERPLATE(NumericSplitCondition)
40  inline std::string Compile() const override {
41  const std::string bitmap
42  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
43  return ((default_left) ? (std::string("!(") + bitmap + ") || ")
44  : (std::string(" (") + bitmap + ") && "))
45  + numeric_adapter(op, split_index, threshold);
46  }
47 
48  private:
49  unsigned split_index;
50  bool default_left;
52  treelite::tl_float threshold;
53  NumericAdapter numeric_adapter;
54 };
55 
56 class CategoricalSplitCondition : public treelite::semantic::Condition {
57  public:
58  explicit CategoricalSplitCondition(const treelite::Tree::Node& node)
59  : split_index(node.split_index()), default_left(node.default_left()),
60  categorical_bitmap(to_bitmap(node.left_categories())) {}
61  CLONEABLE_BOILERPLATE(CategoricalSplitCondition)
62  inline std::string Compile() const override {
63  const std::string bitmap
64  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
65  CHECK_GE(categorical_bitmap.size(), 1);
66  std::ostringstream comp;
67  comp << "(tmp = (unsigned int)(data[" << split_index << "].fvalue) ), "
68  << "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
69  << categorical_bitmap[0] << "U >> tmp) & 1) )";
70  for (size_t i = 1; i < categorical_bitmap.size(); ++i) {
71  comp << " || (tmp >= " << (i * 64)
72  << " && tmp < " << ((i + 1) * 64)
73  << " && (( (uint64_t)" << categorical_bitmap[i]
74  << "U >> (tmp - " << (i * 64) << ") ) & 1) )";
75  }
76  bool all_zeros = true;
77  for (uint64_t e : categorical_bitmap) {
78  all_zeros &= (e == 0);
79  }
80  return ((default_left) ? (std::string("!(") + bitmap + ") || (")
81  : (std::string(" (") + bitmap + ") && ("))
82  + (all_zeros ? std::string("0") : comp.str()) + ")";
83  }
84 
85  private:
86  unsigned split_index;
87  bool default_left;
88  std::vector<uint64_t> categorical_bitmap;
89 
90  inline std::vector<uint64_t> to_bitmap(const std::vector<uint32_t>& left_categories) const {
91  const size_t num_left_categories = left_categories.size();
92  const uint32_t max_left_category = left_categories[num_left_categories - 1];
93  std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
94  for (size_t i = 0; i < left_categories.size(); ++i) {
95  const uint32_t cat = left_categories[i];
96  const size_t idx = cat / 64;
97  const uint32_t offset = cat % 64;
98  bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
99  }
100  return bitmap;
101  }
102 };
103 
104 struct GroupPolicy {
105  void Init(const treelite::Model& model);
106 
107  std::string GroupQueryFunction() const;
108  std::string Accumulator() const;
109  std::string AccumulateTranslationUnit(size_t unit_id) const;
110  std::vector<std::string> AccumulateLeaf(const treelite::Tree::Node& node,
111  size_t tree_id) const;
112  std::vector<std::string> Return() const;
113  std::vector<std::string> FinalReturn(size_t num_tree, float global_bias) const;
114  std::string Prototype() const;
115  std::string PrototypeTranslationUnit(size_t unit_id) const;
116 
117  int num_output_group;
118  bool random_forest_flag;
119 };
120 
121 } // namespace anonymous
122 
123 namespace treelite {
124 namespace compiler {
125 
126 DMLC_REGISTRY_FILE_TAG(recursive);
127 
128 std::vector<std::vector<tl_float>> ExtractCutPoints(const Model& model);
129 
130 struct Metadata {
131  int num_feature;
132  std::vector<std::vector<tl_float>> cut_pts;
133  std::vector<bool> is_categorical;
134 
135  inline void Init(const Model& model, bool extract_cut_pts = false) {
136  num_feature = model.num_feature;
137  is_categorical.clear();
138  is_categorical.resize(num_feature, false);
139  for (const Tree& tree : model.trees) {
140  for (unsigned e : tree.GetCategoricalFeatures()) {
141  is_categorical[e] = true;
142  }
143  }
144  if (extract_cut_pts) {
145  cut_pts = std::move(ExtractCutPoints(model));
146  }
147  }
148 };
149 
150 template <typename QuantizePolicy>
151 class RecursiveCompiler : public Compiler, private QuantizePolicy {
152  public:
153  explicit RecursiveCompiler(const CompilerParam& param)
154  : param(param) {
155  if (param.verbose > 0) {
156  LOG(INFO) << "Using RecursiveCompiler";
157  }
158  }
159 
168 
169  SemanticModel Compile(const Model& model) override {
170  Metadata info;
171  info.Init(model, QuantizePolicy::QuantizeFlag());
172  QuantizePolicy::Init(std::move(info));
173  group_policy.Init(model);
174 
175  std::vector<std::vector<size_t>> annotation;
176  bool annotate = false;
177  if (param.annotate_in != "NULL") {
178  BranchAnnotator annotator;
179  std::unique_ptr<dmlc::Stream> fi(
180  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
181  annotator.Load(fi.get());
182  annotation = annotator.Get();
183  annotate = true;
184  if (param.verbose > 0) {
185  LOG(INFO) << "Using branch annotation file `"
186  << param.annotate_in << "'";
187  }
188  }
189 
190  SemanticModel semantic_model;
191  SequenceBlock sequence;
192  sequence.Reserve(model.trees.size() + 3);
193  sequence.PushBack(PlainBlock(group_policy.Accumulator()));
194  sequence.PushBack(PlainBlock(QuantizePolicy::Preprocessing()));
195  if (param.parallel_comp > 0) {
196  LOG(INFO) << "Parallel compilation enabled; member trees will be "
197  << "divided into " << param.parallel_comp
198  << " translation units.";
199  const size_t nunit = param.parallel_comp; // # translation units
200  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
201  sequence.PushBack(PlainBlock(
202  group_policy.AccumulateTranslationUnit(unit_id)));
203  }
204  } else {
205  LOG(INFO) << "Parallel compilation disabled; all member trees will be "
206  << "dump to a single source file. This may increase "
207  << "compilation time and memory usage.";
208  for (size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
209  const Tree& tree = model.trees[tree_id];
210  if (!annotation.empty()) {
211  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
212  annotation[tree_id])));
213  } else {
214  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
215  {})));
216  }
217  }
218  }
219  sequence.PushBack(PlainBlock(group_policy.FinalReturn(model.trees.size(),
220  model.param.global_bias)));
221 
222  FunctionBlock query_func("size_t get_num_output_group(void)",
223  PlainBlock(group_policy.GroupQueryFunction()),
224  &semantic_model.function_registry, true);
225  FunctionBlock query_func2("size_t get_num_feature(void)",
226  PlainBlock(std::string("return ") +
227  std::to_string(model.num_feature)+";"),
228  &semantic_model.function_registry, true);
229  FunctionBlock pred_transform_func(PredTransformPrototype(false),
230  PlainBlock(PredTransformFunction(model, false)),
231  &semantic_model.function_registry, true);
232  FunctionBlock pred_transform_batch_func(PredTransformPrototype(true),
233  PlainBlock(PredTransformFunction(model, true)),
234  &semantic_model.function_registry, true);
235  FunctionBlock main_func(group_policy.Prototype(),
236  std::move(sequence), &semantic_model.function_registry, true);
237  SequenceBlock main_file;
238  main_file.Reserve(5);
239  main_file.PushBack(std::move(query_func));
240  main_file.PushBack(std::move(query_func2));
241  main_file.PushBack(std::move(pred_transform_func));
242  main_file.PushBack(std::move(pred_transform_batch_func));
243  main_file.PushBack(std::move(main_func));
244  auto file_preamble = QuantizePolicy::ConstantsPreamble();
245  semantic_model.units.emplace_back(PlainBlock(file_preamble),
246  std::move(main_file));
247 
248  if (param.parallel_comp > 0) {
249  const size_t nunit = param.parallel_comp;
250  const size_t unit_size = (model.trees.size() + nunit - 1) / nunit;
251  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
252  const size_t tree_begin = unit_id * unit_size;
253  const size_t tree_end = std::min((unit_id + 1) * unit_size,
254  model.trees.size());
255  SequenceBlock unit_seq;
256  if (tree_begin < tree_end) {
257  unit_seq.Reserve(tree_end - tree_begin + 2);
258  }
259  unit_seq.PushBack(PlainBlock(group_policy.Accumulator()));
260  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
261  const Tree& tree = model.trees[tree_id];
262  if (!annotation.empty()) {
263  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
264  annotation[tree_id])));
265  } else {
266  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
267  {})));
268  }
269  }
270  unit_seq.PushBack(PlainBlock(group_policy.Return()));
271  FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
272  std::move(unit_seq),
273  &semantic_model.function_registry);
274  semantic_model.units.emplace_back(PlainBlock(), std::move(unit_func));
275  }
276  }
277  std::vector<std::string> header{"#include <stdlib.h>",
278  "#include <string.h>",
279  "#include <math.h>",
280  "#include <stdint.h>"};
281  common::TransformPushBack(&header, QuantizePolicy::CommonHeader(),
282  [] (std::string line) { return line; });
283  if (annotate) {
284  header.emplace_back();
285 #if defined(__clang__) || defined(__GNUC__)
286  // only gcc and clang support __builtin_expect intrinsic
287  header.emplace_back("#define LIKELY(x) __builtin_expect(!!(x), 1)");
288  header.emplace_back("#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
289 #else
290  header.emplace_back("#define LIKELY(x) (x)");
291  header.emplace_back("#define UNLIKELY(x) (x)");
292 #endif
293  }
294  semantic_model.common_header
295  = std::move(common::make_unique<PlainBlock>(header));
296  return semantic_model;
297  }
298 
299  private:
300  CompilerParam param;
301  GroupPolicy group_policy;
302 
303  std::unique_ptr<CodeBlock> WalkTree(const Tree& tree, size_t tree_id,
304  const std::vector<size_t>& counts) const {
305  return WalkTree_(tree, tree_id, counts, 0);
306  }
307 
308  std::unique_ptr<CodeBlock> WalkTree_(const Tree& tree, size_t tree_id,
309  const std::vector<size_t>& counts,
310  int nid) const {
311  using semantic::BranchHint;
312  const Tree::Node& node = tree[nid];
313  if (node.is_leaf()) {
314  return std::unique_ptr<CodeBlock>(new PlainBlock(
315  group_policy.AccumulateLeaf(node, tree_id)));
316  } else {
317  BranchHint branch_hint = BranchHint::kNone;
318  if (!counts.empty()) {
319  const size_t left_count = counts[node.cleft()];
320  const size_t right_count = counts[node.cright()];
321  branch_hint = (left_count > right_count) ? BranchHint::kLikely
322  : BranchHint::kUnlikely;
323  }
324  std::unique_ptr<Condition> condition(nullptr);
325  if (node.split_type() == SplitFeatureType::kNumerical) {
326  condition = common::make_unique<NumericSplitCondition>(node,
327  QuantizePolicy::NumericAdapter());
328  } else {
329  condition = common::make_unique<CategoricalSplitCondition>(node);
330  }
331  return std::unique_ptr<CodeBlock>(new IfElseBlock(
332  common::MoveUniquePtr(condition),
333  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cleft())),
334  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cright())),
335  branch_hint)
336  );
337  }
338  }
339 };
340 
342  protected:
343  void Init(const Metadata& info) {
344  this->info = info;
345  }
346  void Init(Metadata&& info) {
347  this->info = std::move(info);
348  }
349  const Metadata& GetInfo() const {
350  return info;
351  }
352  MetadataStore() = default;
353  MetadataStore(const MetadataStore& other) = default;
354  MetadataStore(MetadataStore&& other) = default;
355  private:
356  Metadata info;
357 };
358 
359 class NoQuantize : private MetadataStore {
360  protected:
361  template <typename... Args>
362  void Init(Args&&... args) {
363  MetadataStore::Init(std::forward<Args>(args)...);
364  }
365  NumericSplitCondition::NumericAdapter NumericAdapter() const {
366  return [] (Operator op, unsigned split_index, tl_float threshold) {
367  std::ostringstream oss;
368  if (!std::isfinite(threshold)) {
369  // According to IEEE 754, the result of comparison [lhs] < infinity
370  // must be identical for all finite [lhs]. Same goes for operator >.
371  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
372  } else {
373  // to restore default precision
374  const std::streamsize ss = std::cout.precision();
375  oss << "data[" << split_index << "].fvalue "
376  << semantic::OpName(op) << " "
377  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
378  << threshold
379  << std::setprecision(ss);
380  }
381  return oss.str();
382  };
383  }
384  std::vector<std::string> CommonHeader() const {
385  return {"",
386  "union Entry {",
387  " int missing;",
388  " float fvalue;",
389  "};"};
390  }
391  std::vector<std::string> ConstantsPreamble() const {
392  return {};
393  }
394  std::vector<std::string> Preprocessing() const {
395  return {};
396  }
397  bool QuantizeFlag() const {
398  return false;
399  }
400 };
401 
402 class Quantize : private MetadataStore {
403  protected:
404  template <typename... Args>
405  void Init(Args&&... args) {
406  MetadataStore::Init(std::forward<Args>(args)...);
407  quant_preamble = {
408  std::string("for (int i = 0; i < ")
409  + std::to_string(GetInfo().num_feature) + "; ++i) {",
410  " if (data[i].missing != -1 && !is_categorical[i]) {",
411  " data[i].qvalue = quantize(data[i].fvalue, i);",
412  " }",
413  "}"};
414  }
415  NumericSplitCondition::NumericAdapter NumericAdapter() const {
416  const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
417  return [&cut_pts] (Operator op, unsigned split_index,
418  tl_float threshold) {
419  std::ostringstream oss;
420  const auto& v = cut_pts[split_index];
421  if (!std::isfinite(threshold)) {
422  // According to IEEE 754, the result of comparison [lhs] < infinity
423  // must be identical for all finite [lhs]. Same goes for operator >.
424  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
425  } else {
426  auto loc = common::binary_search(v.begin(), v.end(), threshold);
427  CHECK(loc != v.end());
428  oss << "data[" << split_index << "].qvalue " << semantic::OpName(op)
429  << " " << static_cast<size_t>(loc - v.begin()) * 2;
430  }
431  return oss.str();
432  };
433  }
434  std::vector<std::string> CommonHeader() const {
435  return {"",
436  "union Entry {",
437  " int missing;",
438  " float fvalue;",
439  " int qvalue;",
440  "};"};
441  }
442  std::vector<std::string> ConstantsPreamble() const {
443  std::vector<std::string> ret;
444  ret.emplace_back("static const unsigned char is_categorical[] = {");
445  {
446  std::ostringstream oss, oss2;
447  size_t length = 2;
448  oss << " ";
449  const int num_feature = GetInfo().num_feature;
450  const auto& is_categorical = GetInfo().is_categorical;
451  for (int fid = 0; fid < num_feature; ++fid) {
452  if (is_categorical[fid]) {
453  common::WrapText(&oss, &length, "1", 80);
454  } else {
455  common::WrapText(&oss, &length, "0", 80);
456  }
457  }
458  ret.push_back(oss.str());
459  ret.emplace_back("};");
460  ret.emplace_back();
461  }
462  ret.emplace_back("static const float threshold[] = {");
463  {
464  std::ostringstream oss, oss2;
465  size_t length = 2;
466  oss << " ";
467  for (const auto& e : GetInfo().cut_pts) {
468  for (const auto& value : e) {
469  oss2.clear(); oss2.str(std::string()); oss2 << value;
470  common::WrapText(&oss, &length, oss2.str(), 80);
471  }
472  }
473  ret.push_back(oss.str());
474  ret.emplace_back("};");
475  ret.emplace_back();
476  }
477  ret.emplace_back("static const int th_begin[] = {");
478  {
479  std::ostringstream oss, oss2;
480  size_t length = 2;
481  size_t accum = 0;
482  oss << " ";
483  for (const auto& e : GetInfo().cut_pts) {
484  oss2.clear(); oss2.str(std::string()); oss2 << accum;
485  common::WrapText(&oss, &length, oss2.str(), 80);
486  accum += e.size();
487  }
488  ret.push_back(oss.str());
489  ret.emplace_back("};");
490  ret.emplace_back();
491  }
492  ret.emplace_back("static const int th_len[] = {");
493  {
494  std::ostringstream oss, oss2;
495  size_t length = 2;
496  oss << " ";
497  for (const auto& e : GetInfo().cut_pts) {
498  oss2.clear(); oss2.str(std::string()); oss2 << e.size();
499  common::WrapText(&oss, &length, oss2.str(), 80);
500  }
501  ret.push_back(oss.str());
502  ret.emplace_back("};");
503  ret.emplace_back();
504  }
505 
506  auto func = semantic::FunctionBlock(
507  "static inline int quantize(float val, unsigned fid)",
509  {"const float* array = &threshold[th_begin[fid]];",
510  "int len = th_len[fid];",
511  "int low = 0;",
512  "int high = len;",
513  "int mid;",
514  "float mval;",
515  "if (val < array[0]) {",
516  " return -10;",
517  "}",
518  "while (low + 1 < high) {",
519  " mid = (low + high) / 2;",
520  " mval = array[mid];",
521  " if (val == mval) {",
522  " return mid * 2;",
523  " } else if (val < mval) {",
524  " high = mid;",
525  " } else {",
526  " low = mid;",
527  " }",
528  "}",
529  "if (array[low] == val) {",
530  " return low * 2;",
531  "} else if (high == len) {",
532  " return len * 2;",
533  "} else {",
534  " return low * 2 + 1;",
535  "}"}), nullptr).Compile();
536  ret.insert(ret.end(), func.begin(), func.end());
537  return ret;
538  }
539  std::vector<std::string> Preprocessing() const {
540  return quant_preamble;
541  }
542  bool QuantizeFlag() const {
543  return true;
544  }
545  private:
546  std::vector<std::string> quant_preamble;
547 };
548 
549 inline std::vector<std::vector<tl_float>>
550 ExtractCutPoints(const Model& model) {
551  std::vector<std::vector<tl_float>> cut_pts;
552 
553  std::vector<std::set<tl_float>> thresh_;
554  cut_pts.resize(model.num_feature);
555  thresh_.resize(model.num_feature);
556  for (size_t i = 0; i < model.trees.size(); ++i) {
557  const Tree& tree = model.trees[i];
558  std::queue<int> Q;
559  Q.push(0);
560  while (!Q.empty()) {
561  const int nid = Q.front();
562  const Tree::Node& node = tree[nid];
563  Q.pop();
564  if (!node.is_leaf()) {
565  if (node.split_type() == SplitFeatureType::kNumerical) {
566  const tl_float threshold = node.threshold();
567  const unsigned split_index = node.split_index();
568  if (std::isfinite(threshold)) { // ignore infinity
569  thresh_[split_index].insert(threshold);
570  }
571  } else {
572  CHECK(node.split_type() == SplitFeatureType::kCategorical);
573  }
574  Q.push(node.cleft());
575  Q.push(node.cright());
576  }
577  }
578  }
579  for (int i = 0; i < model.num_feature; ++i) {
580  std::copy(thresh_[i].begin(), thresh_[i].end(),
581  std::back_inserter(cut_pts[i]));
582  }
583  return cut_pts;
584 }
585 
587 .describe("A compiler with a recursive approach")
588 .set_body([](const CompilerParam& param) -> Compiler* {
589  if (param.quantize > 0) {
590  return new RecursiveCompiler<Quantize>(param);
591  } else {
592  return new RecursiveCompiler<NoQuantize>(param);
593  }
594  });
595 } // namespace compiler
596 } // namespace treelite
597 
598 namespace {
599 
600 
601 inline void
602 GroupPolicy::Init(const treelite::Model& model) {
603  this->num_output_group = model.num_output_group;
604  this->random_forest_flag = model.random_forest_flag;
605 }
606 
607 inline std::string
608 GroupPolicy::GroupQueryFunction() const {
609  return "return " + std::to_string(num_output_group) + ";";
610 }
611 
612 inline std::string
613 GroupPolicy::Accumulator() const {
614  if (num_output_group > 1) {
615  return std::string("float sum[") + std::to_string(num_output_group)
616  + "] = {0.0f};\n unsigned int tmp;";
617  } else {
618  return "float sum = 0.0f;\n unsigned int tmp;";
619  }
620 }
621 
622 inline std::string
623 GroupPolicy::AccumulateTranslationUnit(size_t unit_id) const {
624  if (num_output_group > 1) {
625  return std::string("predict_margin_multiclass_unit")
626  + std::to_string(unit_id) + "(data, sum);";
627  } else {
628  return std::string("sum += predict_margin_unit")
629  + std::to_string(unit_id) + "(data);";
630  }
631 }
632 
633 inline std::vector<std::string>
634 GroupPolicy::AccumulateLeaf(const treelite::Tree::Node& node,
635  size_t tree_id) const {
636  if (num_output_group > 1) {
637  if (random_forest_flag) {
638  // multi-class classification with random forest
639  const std::vector<treelite::tl_float>& leaf_vector = node.leaf_vector();
640  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group))
641  << "Ill-formed model: leaf vector must be of length [num_output_group]";
642  std::vector<std::string> lines;
643  lines.reserve(num_output_group);
644  for (int group_id = 0; group_id < num_output_group; ++group_id) {
645  lines.push_back(std::string("sum[") + std::to_string(group_id)
646  + "] += (float)"
647  + treelite::common::ToString(leaf_vector[group_id]) + ";");
648  }
649  return lines;
650  } else {
651  // multi-class classification with gradient boosted trees
652  const treelite::tl_float leaf_value = node.leaf_value();
653  return { std::string("sum[") + std::to_string(tree_id % num_output_group)
654  + "] += (float)" + treelite::common::ToString(leaf_value) + ";" };
655  }
656  } else {
657  const treelite::tl_float leaf_value = node.leaf_value();
658  return {std::string("sum += (float)")
659  + treelite::common::ToString(leaf_value) + ";" };
660  }
661 }
662 
663 inline std::vector<std::string>
664 GroupPolicy::Return() const {
665  if (num_output_group > 1) {
666  return {std::string("for (int i = 0; i < ")
667  + std::to_string(num_output_group) + "; ++i) {",
668  " result[i] += sum[i];",
669  "}" };
670  } else {
671  return { "return sum;" };
672  }
673 }
674 
675 inline std::vector<std::string>
676 GroupPolicy::FinalReturn(size_t num_tree, float global_bias) const {
677  if (num_output_group > 1) {
678  if (random_forest_flag) {
679  // multi-class classification with random forest
680  return {std::string("for (int i = 0; i < ")
681  + std::to_string(num_output_group) + "; ++i) {",
682  std::string(" result[i] = sum[i] / ")
683  + std::to_string(num_tree) + " + ("
684  + treelite::common::ToString(global_bias) + ");",
685  "}"};
686  } else {
687  // multi-class classification with gradient boosted trees
688  return {std::string("for (int i = 0; i < ")
689  + std::to_string(num_output_group) + "; ++i) {",
690  " result[i] = sum[i] + ("
691  + treelite::common::ToString(global_bias) + ");",
692  "}"};
693  }
694  } else {
695  if (random_forest_flag) {
696  return { std::string("return sum / ") + std::to_string(num_tree) + " + ("
697  + treelite::common::ToString(global_bias) + ");" };
698  } else {
699  return { std::string("return sum + (")
700  + treelite::common::ToString(global_bias) + ");" };
701  }
702  }
703 }
704 
705 inline std::string
706 GroupPolicy::Prototype() const {
707  if (num_output_group > 1) {
708  return "void predict_margin_multiclass(union Entry* data, float* result)";
709  } else {
710  return "float predict_margin(union Entry* data)";
711  }
712 }
713 
714 inline std::string
715 GroupPolicy::PrototypeTranslationUnit(size_t unit_id) const {
716  if (num_output_group > 1) {
717  return std::string("void predict_margin_multiclass_unit")
718  + std::to_string(unit_id) + "(union Entry* data, float* result)";
719  } else {
720  return std::string("float predict_margin_unit")
721  + std::to_string(unit_id) + "(union Entry* data)";
722  }
723 }
724 
725 } // namespace anonymous
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:361
plain code block containing one or more lines of code
Definition: semantic.h:118
branch annotator class
Definition: annotator.h:16
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
fundamental block in semantic model. All code blocks should inherit from this class.
Definition: semantic.h:65
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:353
parameters for tree compiler
Definition: param.h:16
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:366
model structure for tree
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
Definition: recursive.cc:169
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:322
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:33
BranchHint
enum class to store branch annotation
Definition: semantic.h:17
tl_float threshold() const
Definition: tree.h:67
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:57
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
int cright() const
index of right child
Definition: tree.h:30
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:138
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
function block with a prototype and code body. Its prototype can optionally be registered with a func...
Definition: semantic.h:138
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:71
if-else statement with condition may store a branch hint (>50% or <50% likely)
Definition: semantic.h:194
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:364
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
Definition: common.h:33
a conditional expression
Definition: semantic.h:184
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
sequence of one or more code blocks
Definition: semantic.h:171
tl_float leaf_value() const
Definition: tree.h:50
Branch annotation tools.
Some useful utilities.
int cleft() const
index of left child
Definition: tree.h:26
tools to define prediction transform function
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:258
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
Building blocks for semantic model of tree prediction code.
translation unit is abstraction of a source file
Definition: semantic.h:72
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
semantic model consists of a header, function registry, and a list of translation units ...
Definition: semantic.h:90
int verbose
if >0, produce extra messages
Definition: param.h:34
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:358
Operator
comparison operators
Definition: base.h:23