treelite
recursive.cc
Go to the documentation of this file.
1 
8 #include <treelite/common.h>
9 #include <treelite/annotator.h>
10 #include <treelite/compiler.h>
11 #include <treelite/tree.h>
12 #include <treelite/semantic.h>
13 #include <dmlc/registry.h>
14 #include <queue>
15 #include <algorithm>
16 #include <iterator>
17 #include <iomanip>
18 #include <cmath>
19 #include "param.h"
20 #include "pred_transform.h"
21 
22 namespace {
23 
24 class NumericSplitCondition : public treelite::semantic::Condition {
25  public:
26  using NumericAdapter
27  = std::function<std::string(treelite::Operator, unsigned,
29  explicit NumericSplitCondition(const treelite::Tree::Node& node,
30  const NumericAdapter& numeric_adapter)
31  : split_index(node.split_index()), default_left(node.default_left()),
32  op(node.comparison_op()), threshold(node.threshold()),
33  numeric_adapter(numeric_adapter) {}
34  explicit NumericSplitCondition(const treelite::Tree::Node& node,
35  NumericAdapter&& numeric_adapter)
36  : split_index(node.split_index()), default_left(node.default_left()),
37  op(node.comparison_op()), threshold(node.threshold()),
38  numeric_adapter(std::move(numeric_adapter)) {}
39  CLONEABLE_BOILERPLATE(NumericSplitCondition)
40  inline std::string Compile() const override {
41  const std::string bitmap
42  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
43  return ((default_left) ? (std::string("!(") + bitmap + ") || ")
44  : (std::string(" (") + bitmap + ") && "))
45  + numeric_adapter(op, split_index, threshold);
46  }
47 
48  private:
49  unsigned split_index;
50  bool default_left;
52  treelite::tl_float threshold;
53  NumericAdapter numeric_adapter;
54 };
55 
56 class CategoricalSplitCondition : public treelite::semantic::Condition {
57  public:
58  explicit CategoricalSplitCondition(const treelite::Tree::Node& node)
59  : split_index(node.split_index()), default_left(node.default_left()),
60  categorical_bitmap(to_bitmap(node.left_categories())) {}
61  CLONEABLE_BOILERPLATE(CategoricalSplitCondition)
62  inline std::string Compile() const override {
63  const std::string bitmap
64  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
65  const std::string comp
66  = std::string("((") + std::to_string(categorical_bitmap)
67  + "U >> (unsigned int)(data[" + std::to_string(split_index)
68  + "].fvalue)) & 1)";
69  return ((default_left) ? (std::string("!(") + bitmap + ") || ")
70  : (std::string(" (") + bitmap + ") && "))
71  + ((categorical_bitmap == 0) ? std::string("0") : comp);
72  }
73 
74  private:
75  unsigned split_index;
76  bool default_left;
77  uint64_t categorical_bitmap;
78 
79  inline uint64_t to_bitmap(const std::vector<uint8_t>& left_categories) const {
80  uint64_t result = 0;
81  for (uint8_t e : left_categories) {
82  CHECK_LT(e, 64) << "Cannot have more than 64 categories in a feature";
83  result |= (static_cast<uint64_t>(1) << e);
84  }
85  return result;
86  }
87 };
88 
89 struct GroupPolicy {
90  void Init(const treelite::Model& model);
91 
92  std::string GroupQueryFunction() const;
93  std::string Accumulator() const;
94  std::string AccumulateTranslationUnit(size_t unit_id) const;
95  std::vector<std::string> AccumulateLeaf(const treelite::Tree::Node& node,
96  size_t tree_id) const;
97  std::vector<std::string> Return() const;
98  std::vector<std::string> FinalReturn(size_t num_tree, float global_bias) const;
99  std::string Prototype() const;
100  std::string PrototypeTranslationUnit(size_t unit_id) const;
101 
102  int num_output_group;
103  bool random_forest_flag;
104 };
105 
106 } // namespace anonymous
107 
108 namespace treelite {
109 namespace compiler {
110 
111 DMLC_REGISTRY_FILE_TAG(recursive);
112 
113 std::vector<std::vector<tl_float>> ExtractCutPoints(const Model& model);
114 
115 struct Metadata {
116  int num_feature;
117  std::vector<std::vector<tl_float>> cut_pts;
118  std::vector<bool> is_categorical;
119 
120  inline void Init(const Model& model, bool extract_cut_pts = false) {
121  num_feature = model.num_feature;
122  is_categorical.clear();
123  is_categorical.resize(num_feature, false);
124  for (const Tree& tree : model.trees) {
125  for (unsigned e : tree.GetCategoricalFeatures()) {
126  is_categorical[e] = true;
127  }
128  }
129  if (extract_cut_pts) {
130  cut_pts = std::move(ExtractCutPoints(model));
131  }
132  }
133 };
134 
135 template <typename QuantizePolicy>
136 class RecursiveCompiler : public Compiler, private QuantizePolicy {
137  public:
138  explicit RecursiveCompiler(const CompilerParam& param)
139  : param(param) {
140  if (param.verbose > 0) {
141  LOG(INFO) << "Using RecursiveCompiler";
142  }
143  }
144 
153 
154  SemanticModel Compile(const Model& model) override {
155  Metadata info;
156  info.Init(model, QuantizePolicy::QuantizeFlag());
157  QuantizePolicy::Init(std::move(info));
158  group_policy.Init(model);
159 
160  std::vector<std::vector<size_t>> annotation;
161  bool annotate = false;
162  if (param.annotate_in != "NULL") {
163  BranchAnnotator annotator;
164  std::unique_ptr<dmlc::Stream> fi(
165  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
166  annotator.Load(fi.get());
167  annotation = annotator.Get();
168  annotate = true;
169  if (param.verbose > 0) {
170  LOG(INFO) << "Using branch annotation file `"
171  << param.annotate_in << "\"";
172  }
173  }
174 
175  SemanticModel semantic_model;
176  SequenceBlock sequence;
177  sequence.Reserve(model.trees.size() + 3);
178  sequence.PushBack(PlainBlock(group_policy.Accumulator()));
179  sequence.PushBack(PlainBlock(QuantizePolicy::Preprocessing()));
180  if (param.parallel_comp > 0) {
181  LOG(INFO) << "Parallel compilation enabled; member trees will be "
182  << "divided into " << param.parallel_comp
183  << " translation units.";
184  const size_t nunit = param.parallel_comp; // # translation units
185  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
186  sequence.PushBack(PlainBlock(
187  group_policy.AccumulateTranslationUnit(unit_id)));
188  }
189  } else {
190  LOG(INFO) << "Parallel compilation disabled; all member trees will be "
191  << "dump to a single source file. This may increase "
192  << "compilation time and memory usage.";
193  for (size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
194  const Tree& tree = model.trees[tree_id];
195  if (!annotation.empty()) {
196  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
197  annotation[tree_id])));
198  } else {
199  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
200  {})));
201  }
202  }
203  }
204  sequence.PushBack(PlainBlock(group_policy.FinalReturn(model.trees.size(),
205  model.param.global_bias)));
206 
207  FunctionBlock query_func("size_t get_num_output_group(void)",
208  PlainBlock(group_policy.GroupQueryFunction()),
209  &semantic_model.function_registry, true);
210  FunctionBlock query_func2("size_t get_num_feature(void)",
211  PlainBlock(std::string("return ") +
212  std::to_string(model.num_feature)+";"),
213  &semantic_model.function_registry, true);
214  FunctionBlock pred_transform_func(PredTransformPrototype(false),
215  PlainBlock(PredTransformFunction(model, false)),
216  &semantic_model.function_registry, true);
217  FunctionBlock pred_transform_batch_func(PredTransformPrototype(true),
218  PlainBlock(PredTransformFunction(model, true)),
219  &semantic_model.function_registry, true);
220  FunctionBlock main_func(group_policy.Prototype(),
221  std::move(sequence), &semantic_model.function_registry, true);
222  SequenceBlock main_file;
223  main_file.Reserve(5);
224  main_file.PushBack(std::move(query_func));
225  main_file.PushBack(std::move(query_func2));
226  main_file.PushBack(std::move(pred_transform_func));
227  main_file.PushBack(std::move(pred_transform_batch_func));
228  main_file.PushBack(std::move(main_func));
229  auto file_preamble = QuantizePolicy::ConstantsPreamble();
230  semantic_model.units.emplace_back(PlainBlock(file_preamble),
231  std::move(main_file));
232 
233  if (param.parallel_comp > 0) {
234  const size_t nunit = param.parallel_comp;
235  const size_t unit_size = (model.trees.size() + nunit - 1) / nunit;
236  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
237  const size_t tree_begin = unit_id * unit_size;
238  const size_t tree_end = std::min((unit_id + 1) * unit_size,
239  model.trees.size());
240  SequenceBlock unit_seq;
241  unit_seq.Reserve(tree_end - tree_begin + 2);
242  unit_seq.PushBack(PlainBlock(group_policy.Accumulator()));
243  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
244  const Tree& tree = model.trees[tree_id];
245  if (!annotation.empty()) {
246  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
247  annotation[tree_id])));
248  } else {
249  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
250  {})));
251  }
252  }
253  unit_seq.PushBack(PlainBlock(group_policy.Return()));
254  FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
255  std::move(unit_seq),
256  &semantic_model.function_registry);
257  semantic_model.units.emplace_back(PlainBlock(), std::move(unit_func));
258  }
259  }
260  auto header = QuantizePolicy::CommonHeader();
261  if (annotate) {
262  header.emplace_back();
263 #if defined(__clang__) || defined(__GNUC__)
264  // only gcc and clang support __builtin_expect intrinsic
265  header.emplace_back("#define LIKELY(x) __builtin_expect(!!(x), 1)");
266  header.emplace_back("#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
267 #else
268  header.emplace_back("#define LIKELY(x) (x)");
269  header.emplace_back("#define UNLIKELY(x) (x)");
270 #endif
271  }
272  semantic_model.common_header
273  = std::move(common::make_unique<PlainBlock>(header));
274  return semantic_model;
275  }
276 
277  private:
278  CompilerParam param;
279  GroupPolicy group_policy;
280 
281  std::unique_ptr<CodeBlock> WalkTree(const Tree& tree, size_t tree_id,
282  const std::vector<size_t>& counts) const {
283  return WalkTree_(tree, tree_id, counts, 0);
284  }
285 
286  std::unique_ptr<CodeBlock> WalkTree_(const Tree& tree, size_t tree_id,
287  const std::vector<size_t>& counts,
288  int nid) const {
289  using semantic::BranchHint;
290  const Tree::Node& node = tree[nid];
291  if (node.is_leaf()) {
292  return std::unique_ptr<CodeBlock>(new PlainBlock(
293  group_policy.AccumulateLeaf(node, tree_id)));
294  } else {
295  BranchHint branch_hint = BranchHint::kNone;
296  if (!counts.empty()) {
297  const size_t left_count = counts[node.cleft()];
298  const size_t right_count = counts[node.cright()];
299  branch_hint = (left_count > right_count) ? BranchHint::kLikely
300  : BranchHint::kUnlikely;
301  }
302  std::unique_ptr<Condition> condition(nullptr);
303  if (node.split_type() == SplitFeatureType::kNumerical) {
304  condition = common::make_unique<NumericSplitCondition>(node,
305  QuantizePolicy::NumericAdapter());
306  } else {
307  condition = common::make_unique<CategoricalSplitCondition>(node);
308  }
309  return std::unique_ptr<CodeBlock>(new IfElseBlock(
310  common::MoveUniquePtr(condition),
311  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cleft())),
312  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cright())),
313  branch_hint)
314  );
315  }
316  }
317 };
318 
320  protected:
321  void Init(const Metadata& info) {
322  this->info = info;
323  }
324  void Init(Metadata&& info) {
325  this->info = std::move(info);
326  }
327  const Metadata& GetInfo() const {
328  return info;
329  }
330  MetadataStore() = default;
331  MetadataStore(const MetadataStore& other) = default;
332  MetadataStore(MetadataStore&& other) = default;
333  private:
334  Metadata info;
335 };
336 
337 class NoQuantize : private MetadataStore {
338  protected:
339  template <typename... Args>
340  void Init(Args&&... args) {
341  MetadataStore::Init(std::forward<Args>(args)...);
342  }
343  NumericSplitCondition::NumericAdapter NumericAdapter() const {
344  return [] (Operator op, unsigned split_index, tl_float threshold) {
345  std::ostringstream oss;
346  if (!std::isfinite(threshold)) {
347  // According to IEEE 754, the result of comparison [lhs] < infinity
348  // must be identical for all finite [lhs]. Same goes for operator >.
349  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
350  } else {
351  // to restore default precision
352  const std::streamsize ss = std::cout.precision();
353  oss << "data[" << split_index << "].fvalue "
354  << semantic::OpName(op) << " "
355  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
356  << threshold
357  << std::setprecision(ss);
358  }
359  return oss.str();
360  };
361  }
362  std::vector<std::string> CommonHeader() const {
363  return {"#include <stdlib.h>",
364  "#include <string.h>",
365  "#include <math.h>",
366  "#include <stdint.h>",
367  "",
368  "union Entry {",
369  " int missing;",
370  " float fvalue;",
371  "};"};
372  }
373  std::vector<std::string> ConstantsPreamble() const {
374  return {};
375  }
376  std::vector<std::string> Preprocessing() const {
377  return {};
378  }
379  bool QuantizeFlag() const {
380  return false;
381  }
382 };
383 
384 class Quantize : private MetadataStore {
385  protected:
386  template <typename... Args>
387  void Init(Args&&... args) {
388  MetadataStore::Init(std::forward<Args>(args)...);
389  quant_preamble = {
390  std::string("for (int i = 0; i < ")
391  + std::to_string(GetInfo().num_feature) + "; ++i) {",
392  " if (data[i].missing != -1 && !is_categorical[i]) {",
393  " data[i].qvalue = quantize(data[i].fvalue, i);",
394  " }",
395  "}"};
396  }
397  NumericSplitCondition::NumericAdapter NumericAdapter() const {
398  const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
399  return [&cut_pts] (Operator op, unsigned split_index,
400  tl_float threshold) {
401  std::ostringstream oss;
402  const auto& v = cut_pts[split_index];
403  if (!std::isfinite(threshold)) {
404  // According to IEEE 754, the result of comparison [lhs] < infinity
405  // must be identical for all finite [lhs]. Same goes for operator >.
406  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
407  } else {
408  auto loc = common::binary_search(v.begin(), v.end(), threshold);
409  CHECK(loc != v.end());
410  oss << "data[" << split_index << "].qvalue " << semantic::OpName(op)
411  << " " << static_cast<size_t>(loc - v.begin()) * 2;
412  }
413  return oss.str();
414  };
415  }
416  std::vector<std::string> CommonHeader() const {
417  return {"#include <stdlib.h>",
418  "#include <string.h>",
419  "#include <math.h>",
420  "#include <stdint.h>",
421  "",
422  "union Entry {",
423  " int missing;",
424  " float fvalue;",
425  " int qvalue;",
426  "};"};
427  }
428  std::vector<std::string> ConstantsPreamble() const {
429  std::vector<std::string> ret;
430  ret.emplace_back("static const unsigned char is_categorical[] = {");
431  {
432  std::ostringstream oss, oss2;
433  size_t length = 2;
434  oss << " ";
435  const int num_feature = GetInfo().num_feature;
436  const auto& is_categorical = GetInfo().is_categorical;
437  for (int fid = 0; fid < num_feature; ++fid) {
438  if (is_categorical[fid]) {
439  common::WrapText(&oss, &length, "1", 80);
440  } else {
441  common::WrapText(&oss, &length, "0", 80);
442  }
443  }
444  ret.push_back(oss.str());
445  ret.emplace_back("};");
446  ret.emplace_back();
447  }
448  ret.emplace_back("static const float threshold[] = {");
449  {
450  std::ostringstream oss, oss2;
451  size_t length = 2;
452  oss << " ";
453  for (const auto& e : GetInfo().cut_pts) {
454  for (const auto& value : e) {
455  oss2.clear(); oss2.str(std::string()); oss2 << value;
456  common::WrapText(&oss, &length, oss2.str(), 80);
457  }
458  }
459  ret.push_back(oss.str());
460  ret.emplace_back("};");
461  ret.emplace_back();
462  }
463  ret.emplace_back("static const int th_begin[] = {");
464  {
465  std::ostringstream oss, oss2;
466  size_t length = 2;
467  size_t accum = 0;
468  oss << " ";
469  for (const auto& e : GetInfo().cut_pts) {
470  oss2.clear(); oss2.str(std::string()); oss2 << accum;
471  common::WrapText(&oss, &length, oss2.str(), 80);
472  accum += e.size();
473  }
474  ret.push_back(oss.str());
475  ret.emplace_back("};");
476  ret.emplace_back();
477  }
478  ret.emplace_back("static const int th_len[] = {");
479  {
480  std::ostringstream oss, oss2;
481  size_t length = 2;
482  oss << " ";
483  for (const auto& e : GetInfo().cut_pts) {
484  oss2.clear(); oss2.str(std::string()); oss2 << e.size();
485  common::WrapText(&oss, &length, oss2.str(), 80);
486  }
487  ret.push_back(oss.str());
488  ret.emplace_back("};");
489  ret.emplace_back();
490  }
491 
492  auto func = semantic::FunctionBlock(
493  "static inline int quantize(float val, unsigned fid)",
495  {"const float* array = &threshold[th_begin[fid]];",
496  "int len = th_len[fid];",
497  "int low = 0;",
498  "int high = len;",
499  "int mid;",
500  "float mval;",
501  "if (val < array[0]) {",
502  " return -10;",
503  "}",
504  "while (low + 1 < high) {",
505  " mid = (low + high) / 2;",
506  " mval = array[mid];",
507  " if (val == mval) {",
508  " return mid * 2;",
509  " } else if (val < mval) {",
510  " high = mid;",
511  " } else {",
512  " low = mid;",
513  " }",
514  "}",
515  "if (array[low] == val) {",
516  " return low * 2;",
517  "} else if (high == len) {",
518  " return len * 2;",
519  "} else {",
520  " return low * 2 + 1;",
521  "}"}), nullptr).Compile();
522  ret.insert(ret.end(), func.begin(), func.end());
523  return ret;
524  }
525  std::vector<std::string> Preprocessing() const {
526  return quant_preamble;
527  }
528  bool QuantizeFlag() const {
529  return true;
530  }
531  private:
532  std::vector<std::string> quant_preamble;
533 };
534 
535 inline std::vector<std::vector<tl_float>>
536 ExtractCutPoints(const Model& model) {
537  std::vector<std::vector<tl_float>> cut_pts;
538 
539  std::vector<std::set<tl_float>> thresh_;
540  cut_pts.resize(model.num_feature);
541  thresh_.resize(model.num_feature);
542  for (size_t i = 0; i < model.trees.size(); ++i) {
543  const Tree& tree = model.trees[i];
544  std::queue<int> Q;
545  Q.push(0);
546  while (!Q.empty()) {
547  const int nid = Q.front();
548  const Tree::Node& node = tree[nid];
549  Q.pop();
550  if (!node.is_leaf()) {
551  if (node.split_type() == SplitFeatureType::kNumerical) {
552  const tl_float threshold = node.threshold();
553  const unsigned split_index = node.split_index();
554  if (std::isfinite(threshold)) { // ignore infinity
555  thresh_[split_index].insert(threshold);
556  }
557  } else {
558  CHECK(node.split_type() == SplitFeatureType::kCategorical);
559  }
560  Q.push(node.cleft());
561  Q.push(node.cright());
562  }
563  }
564  }
565  for (int i = 0; i < model.num_feature; ++i) {
566  std::copy(thresh_[i].begin(), thresh_[i].end(),
567  std::back_inserter(cut_pts[i]));
568  }
569  return cut_pts;
570 }
571 
573 .describe("A compiler with a recursive approach")
574 .set_body([](const CompilerParam& param) -> Compiler* {
575  if (param.quantize > 0) {
576  return new RecursiveCompiler<Quantize>(param);
577  } else {
578  return new RecursiveCompiler<NoQuantize>(param);
579  }
580  });
581 } // namespace compiler
582 } // namespace treelite
583 
584 namespace {
585 
586 
587 inline void
588 GroupPolicy::Init(const treelite::Model& model) {
589  this->num_output_group = model.num_output_group;
590  this->random_forest_flag = model.random_forest_flag;
591 }
592 
593 inline std::string
594 GroupPolicy::GroupQueryFunction() const {
595  return "return " + std::to_string(num_output_group) + ";";
596 }
597 
598 inline std::string
599 GroupPolicy::Accumulator() const {
600  if (num_output_group > 1) {
601  return std::string("float sum[") + std::to_string(num_output_group)
602  + "] = {0.0f};";
603  } else {
604  return "float sum = 0.0f;";
605  }
606 }
607 
608 inline std::string
609 GroupPolicy::AccumulateTranslationUnit(size_t unit_id) const {
610  if (num_output_group > 1) {
611  return std::string("predict_margin_multiclass_unit")
612  + std::to_string(unit_id) + "(data, sum);";
613  } else {
614  return std::string("sum += predict_margin_unit")
615  + std::to_string(unit_id) + "(data);";
616  }
617 }
618 
619 inline std::vector<std::string>
620 GroupPolicy::AccumulateLeaf(const treelite::Tree::Node& node,
621  size_t tree_id) const {
622  if (num_output_group > 1) {
623  if (random_forest_flag) {
624  // multi-class classification with random forest
625  const std::vector<treelite::tl_float>& leaf_vector = node.leaf_vector();
626  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group))
627  << "Ill-formed model: leaf vector must be of length [num_output_group]";
628  std::vector<std::string> lines;
629  lines.reserve(num_output_group);
630  for (int group_id = 0; group_id < num_output_group; ++group_id) {
631  lines.push_back(std::string("sum[") + std::to_string(group_id)
632  + "] += (float)"
633  + treelite::common::ToString(leaf_vector[group_id]) + ";");
634  }
635  return lines;
636  } else {
637  // multi-class classification with gradient boosted trees
638  const treelite::tl_float leaf_value = node.leaf_value();
639  return { std::string("sum[") + std::to_string(tree_id % num_output_group)
640  + "] += (float)" + treelite::common::ToString(leaf_value) + ";" };
641  }
642  } else {
643  const treelite::tl_float leaf_value = node.leaf_value();
644  return {std::string("sum += (float)")
645  + treelite::common::ToString(leaf_value) + ";" };
646  }
647 }
648 
649 inline std::vector<std::string>
650 GroupPolicy::Return() const {
651  if (num_output_group > 1) {
652  return {std::string("for (int i = 0; i < ")
653  + std::to_string(num_output_group) + "; ++i) {",
654  " result[i] += sum[i];",
655  "}" };
656  } else {
657  return { "return sum;" };
658  }
659 }
660 
661 inline std::vector<std::string>
662 GroupPolicy::FinalReturn(size_t num_tree, float global_bias) const {
663  if (num_output_group > 1) {
664  if (random_forest_flag) {
665  // multi-class classification with random forest
666  return {std::string("for (int i = 0; i < ")
667  + std::to_string(num_output_group) + "; ++i) {",
668  std::string(" result[i] = sum[i] / ")
669  + std::to_string(num_tree) + " + ("
670  + treelite::common::ToString(global_bias) + ");",
671  "}"};
672  } else {
673  // multi-class classification with gradient boosted trees
674  return {std::string("for (int i = 0; i < ")
675  + std::to_string(num_output_group) + "; ++i) {",
676  " result[i] = sum[i] + ("
677  + treelite::common::ToString(global_bias) + ");",
678  "}"};
679  }
680  } else {
681  if (random_forest_flag) {
682  return { std::string("return sum / ") + std::to_string(num_tree) + " + ("
683  + treelite::common::ToString(global_bias) + ");" };
684  } else {
685  return { std::string("return sum + (")
686  + treelite::common::ToString(global_bias) + ");" };
687  }
688  }
689 }
690 
691 inline std::string
692 GroupPolicy::Prototype() const {
693  if (num_output_group > 1) {
694  return "void predict_margin_multiclass(union Entry* data, float* result)";
695  } else {
696  return "float predict_margin(union Entry* data)";
697  }
698 }
699 
700 inline std::string
701 GroupPolicy::PrototypeTranslationUnit(size_t unit_id) const {
702  if (num_output_group > 1) {
703  return std::string("void predict_margin_multiclass_unit")
704  + std::to_string(unit_id) + "(union Entry* data, float* result)";
705  } else {
706  return std::string("float predict_margin_unit")
707  + std::to_string(unit_id) + "(union Entry* data)";
708  }
709 }
710 
711 } // namespace anonymous
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:360
plain code block containing one or more lines of code
Definition: semantic.h:118
branch annotator class
Definition: annotator.h:16
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
fundamental block in semantic model. All code blocks should inherit from this class.
Definition: semantic.h:65
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:352
parameters for tree compiler
Definition: param.h:16
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:365
model structure for tree
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
Definition: recursive.cc:154
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:321
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:33
BranchHint
enum class to store branch annotation
Definition: semantic.h:17
tl_float threshold() const
Definition: tree.h:67
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:57
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
int cright() const
index of right child
Definition: tree.h:30
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:139
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
function block with a prototype and code body. Its prototype can optionally be registered with a func...
Definition: semantic.h:138
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:71
if-else statement with a condition; may store a branch hint (>50% or <50% likely)
Definition: semantic.h:194
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:363
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
Definition: common.h:33
a conditional expression
Definition: semantic.h:184
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
sequence of one or more code blocks
Definition: semantic.h:171
tl_float leaf_value() const
Definition: tree.h:50
Branch annotation tools.
Some useful utilities.
int cleft() const
index of left child
Definition: tree.h:26
tools to define prediction transform function
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:257
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
Building blocks for semantic model of tree prediction code.
translation unit is abstraction of a source file
Definition: semantic.h:72
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
semantic model consists of a header, function registry, and a list of translation units ...
Definition: semantic.h:90
int verbose
if >0, produce extra messages
Definition: param.h:34
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:357
Operator
comparison operators
Definition: base.h:23