treelite
recursive.cc
Go to the documentation of this file.
1 
#include <treelite/common.h>
#include <treelite/annotator.h>
#include <treelite/compiler.h>
#include <treelite/tree.h>
#include <treelite/semantic.h>
#include <dmlc/registry.h>
#include <queue>
#include <algorithm>
#include <iterator>
#include <limits>
#include <cmath>
#include "param.h"
#include "pred_transform.h"
20 
21 namespace {
22 
23 class NumericSplitCondition : public treelite::semantic::Condition {
24  public:
25  using NumericAdapter
26  = std::function<std::string(treelite::Operator, unsigned,
28  explicit NumericSplitCondition(const treelite::Tree::Node& node,
29  const NumericAdapter& numeric_adapter)
30  : split_index(node.split_index()), default_left(node.default_left()),
31  op(node.comparison_op()), threshold(node.threshold()),
32  numeric_adapter(numeric_adapter) {}
33  explicit NumericSplitCondition(const treelite::Tree::Node& node,
34  NumericAdapter&& numeric_adapter)
35  : split_index(node.split_index()), default_left(node.default_left()),
36  op(node.comparison_op()), threshold(node.threshold()),
37  numeric_adapter(std::move(numeric_adapter)) {}
38  CLONEABLE_BOILERPLATE(NumericSplitCondition)
39  inline std::string Compile() const override {
40  const std::string bitmap
41  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
42  return ((default_left) ? (std::string("!(") + bitmap + ") || ")
43  : (std::string(" (") + bitmap + ") && "))
44  + numeric_adapter(op, split_index, threshold);
45  }
46 
47  private:
48  unsigned split_index;
49  bool default_left;
51  treelite::tl_float threshold;
52  NumericAdapter numeric_adapter;
53 };
54 
55 class CategoricalSplitCondition : public treelite::semantic::Condition {
56  public:
57  explicit CategoricalSplitCondition(const treelite::Tree::Node& node)
58  : split_index(node.split_index()), default_left(node.default_left()),
59  categorical_bitmap(to_bitmap(node.left_categories())) {}
60  CLONEABLE_BOILERPLATE(CategoricalSplitCondition)
61  inline std::string Compile() const override {
62  const std::string bitmap
63  = std::string("data[") + std::to_string(split_index) + "].missing != -1";
64  const std::string comp
65  = std::string("((") + std::to_string(categorical_bitmap)
66  + "U >> (unsigned int)(data[" + std::to_string(split_index)
67  + "].fvalue)) & 1)";
68  return ((default_left) ? (std::string("!(") + bitmap + ") || ")
69  : (std::string(" (") + bitmap + ") && "))
70  + ((categorical_bitmap == 0) ? std::string("0") : comp);
71  }
72 
73  private:
74  unsigned split_index;
75  bool default_left;
76  uint64_t categorical_bitmap;
77 
78  inline uint64_t to_bitmap(const std::vector<uint8_t>& left_categories) const {
79  uint64_t result = 0;
80  for (uint8_t e : left_categories) {
81  CHECK_LT(e, 64) << "Cannot have more than 64 categories in a feature";
82  result |= (static_cast<uint64_t>(1) << e);
83  }
84  return result;
85  }
86 };
87 
88 struct GroupPolicy {
89  void Init(const treelite::Model& model);
90 
91  std::string GroupQueryFunction() const;
92  std::string Accumulator() const;
93  std::string AccumulateTranslationUnit(size_t unit_id) const;
94  std::vector<std::string> AccumulateLeaf(const treelite::Tree::Node& node,
95  size_t tree_id) const;
96  std::vector<std::string> Return() const;
97  std::vector<std::string> FinalReturn(size_t num_tree, float global_bias) const;
98  std::string Prototype() const;
99  std::string PrototypeTranslationUnit(size_t unit_id) const;
100 
101  int num_output_group;
102  bool random_forest_flag;
103 };
104 
105 } // namespace anonymous
106 
107 namespace treelite {
108 namespace compiler {
109 
110 DMLC_REGISTRY_FILE_TAG(recursive);
111 
112 std::vector<std::vector<tl_float>> ExtractCutPoints(const Model& model);
113 
114 struct Metadata {
115  int num_feature;
116  std::vector<std::vector<tl_float>> cut_pts;
117  std::vector<bool> is_categorical;
118 
119  inline void Init(const Model& model, bool extract_cut_pts = false) {
120  num_feature = model.num_feature;
121  is_categorical.clear();
122  is_categorical.resize(num_feature, false);
123  for (const Tree& tree : model.trees) {
124  for (unsigned e : tree.GetCategoricalFeatures()) {
125  is_categorical[e] = true;
126  }
127  }
128  if (extract_cut_pts) {
129  cut_pts = std::move(ExtractCutPoints(model));
130  }
131  }
132 };
133 
134 template <typename QuantizePolicy>
135 class RecursiveCompiler : public Compiler, private QuantizePolicy {
136  public:
137  explicit RecursiveCompiler(const CompilerParam& param)
138  : param(param) {
139  if (param.verbose > 0) {
140  LOG(INFO) << "Using RecursiveCompiler";
141  }
142  }
143 
152 
153  SemanticModel Compile(const Model& model) override {
154  Metadata info;
155  info.Init(model, QuantizePolicy::QuantizeFlag());
156  QuantizePolicy::Init(std::move(info));
157  group_policy.Init(model);
158 
159  std::vector<std::vector<size_t>> annotation;
160  bool annotate = false;
161  if (param.annotate_in != "NULL") {
162  BranchAnnotator annotator;
163  std::unique_ptr<dmlc::Stream> fi(
164  dmlc::Stream::Create(param.annotate_in.c_str(), "r"));
165  annotator.Load(fi.get());
166  annotation = annotator.Get();
167  annotate = true;
168  if (param.verbose > 0) {
169  LOG(INFO) << "Using branch annotation file `"
170  << param.annotate_in << "\"";
171  }
172  }
173 
174  SemanticModel semantic_model;
175  SequenceBlock sequence;
176  sequence.Reserve(model.trees.size() + 3);
177  sequence.PushBack(PlainBlock(group_policy.Accumulator()));
178  sequence.PushBack(PlainBlock(QuantizePolicy::Preprocessing()));
179  if (param.parallel_comp > 0) {
180  LOG(INFO) << "Parallel compilation enabled; member trees will be "
181  << "divided into " << param.parallel_comp
182  << " translation units.";
183  const size_t nunit = param.parallel_comp; // # translation units
184  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
185  sequence.PushBack(PlainBlock(
186  group_policy.AccumulateTranslationUnit(unit_id)));
187  }
188  } else {
189  LOG(INFO) << "Parallel compilation disabled; all member trees will be "
190  << "dump to a single source file. This may increase "
191  << "compilation time and memory usage.";
192  for (size_t tree_id = 0; tree_id < model.trees.size(); ++tree_id) {
193  const Tree& tree = model.trees[tree_id];
194  if (!annotation.empty()) {
195  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
196  annotation[tree_id])));
197  } else {
198  sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
199  {})));
200  }
201  }
202  }
203  sequence.PushBack(PlainBlock(group_policy.FinalReturn(model.trees.size(),
204  model.param.global_bias)));
205 
206  FunctionBlock query_func("size_t get_num_output_group(void)",
207  PlainBlock(group_policy.GroupQueryFunction()),
208  &semantic_model.function_registry, true);
209  FunctionBlock query_func2("size_t get_num_feature(void)",
210  PlainBlock(std::string("return ") +
211  std::to_string(model.num_feature)+";"),
212  &semantic_model.function_registry, true);
213  FunctionBlock pred_transform_func(PredTransformPrototype(false),
214  PlainBlock(PredTransformFunction(model, false)),
215  &semantic_model.function_registry, true);
216  FunctionBlock pred_transform_batch_func(PredTransformPrototype(true),
217  PlainBlock(PredTransformFunction(model, true)),
218  &semantic_model.function_registry, true);
219  FunctionBlock main_func(group_policy.Prototype(),
220  std::move(sequence), &semantic_model.function_registry, true);
221  SequenceBlock main_file;
222  main_file.Reserve(5);
223  main_file.PushBack(std::move(query_func));
224  main_file.PushBack(std::move(query_func2));
225  main_file.PushBack(std::move(pred_transform_func));
226  main_file.PushBack(std::move(pred_transform_batch_func));
227  main_file.PushBack(std::move(main_func));
228  auto file_preamble = QuantizePolicy::ConstantsPreamble();
229  semantic_model.units.emplace_back(PlainBlock(file_preamble),
230  std::move(main_file));
231 
232  if (param.parallel_comp > 0) {
233  const size_t nunit = param.parallel_comp;
234  const size_t unit_size = (model.trees.size() + nunit - 1) / nunit;
235  for (size_t unit_id = 0; unit_id < nunit; ++unit_id) {
236  const size_t tree_begin = unit_id * unit_size;
237  const size_t tree_end = std::min((unit_id + 1) * unit_size,
238  model.trees.size());
239  SequenceBlock unit_seq;
240  unit_seq.Reserve(tree_end - tree_begin + 2);
241  unit_seq.PushBack(PlainBlock(group_policy.Accumulator()));
242  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
243  const Tree& tree = model.trees[tree_id];
244  if (!annotation.empty()) {
245  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
246  annotation[tree_id])));
247  } else {
248  unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
249  {})));
250  }
251  }
252  unit_seq.PushBack(PlainBlock(group_policy.Return()));
253  FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
254  std::move(unit_seq),
255  &semantic_model.function_registry);
256  semantic_model.units.emplace_back(PlainBlock(), std::move(unit_func));
257  }
258  }
259  auto header = QuantizePolicy::CommonHeader();
260  if (annotate) {
261  header.emplace_back();
262 #if defined(__clang__) || defined(__GNUC__)
263  // only gcc and clang support __builtin_expect intrinsic
264  header.emplace_back("#define LIKELY(x) __builtin_expect(!!(x), 1)");
265  header.emplace_back("#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
266 #else
267  header.emplace_back("#define LIKELY(x) (x)");
268  header.emplace_back("#define UNLIKELY(x) (x)");
269 #endif
270  }
271  semantic_model.common_header
272  = std::move(common::make_unique<PlainBlock>(header));
273  return semantic_model;
274  }
275 
276  private:
277  CompilerParam param;
278  GroupPolicy group_policy;
279 
280  std::unique_ptr<CodeBlock> WalkTree(const Tree& tree, size_t tree_id,
281  const std::vector<size_t>& counts) const {
282  return WalkTree_(tree, tree_id, counts, 0);
283  }
284 
285  std::unique_ptr<CodeBlock> WalkTree_(const Tree& tree, size_t tree_id,
286  const std::vector<size_t>& counts,
287  int nid) const {
288  using semantic::BranchHint;
289  const Tree::Node& node = tree[nid];
290  if (node.is_leaf()) {
291  return std::unique_ptr<CodeBlock>(new PlainBlock(
292  group_policy.AccumulateLeaf(node, tree_id)));
293  } else {
294  BranchHint branch_hint = BranchHint::kNone;
295  if (!counts.empty()) {
296  const size_t left_count = counts[node.cleft()];
297  const size_t right_count = counts[node.cright()];
298  branch_hint = (left_count > right_count) ? BranchHint::kLikely
299  : BranchHint::kUnlikely;
300  }
301  std::unique_ptr<Condition> condition(nullptr);
302  if (node.split_type() == SplitFeatureType::kNumerical) {
303  condition = common::make_unique<NumericSplitCondition>(node,
304  QuantizePolicy::NumericAdapter());
305  } else {
306  condition = common::make_unique<CategoricalSplitCondition>(node);
307  }
308  return std::unique_ptr<CodeBlock>(new IfElseBlock(
309  common::MoveUniquePtr(condition),
310  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cleft())),
311  common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.cright())),
312  branch_hint)
313  );
314  }
315  }
316 };
317 
319  protected:
320  void Init(const Metadata& info) {
321  this->info = info;
322  }
323  void Init(Metadata&& info) {
324  this->info = std::move(info);
325  }
326  const Metadata& GetInfo() const {
327  return info;
328  }
329  MetadataStore() = default;
330  MetadataStore(const MetadataStore& other) = default;
331  MetadataStore(MetadataStore&& other) = default;
332  private:
333  Metadata info;
334 };
335 
336 class NoQuantize : private MetadataStore {
337  protected:
338  template <typename... Args>
339  void Init(Args&&... args) {
340  MetadataStore::Init(std::forward<Args>(args)...);
341  }
342  NumericSplitCondition::NumericAdapter NumericAdapter() const {
343  return [] (Operator op, unsigned split_index, tl_float threshold) {
344  std::ostringstream oss;
345  if (!std::isfinite(threshold)) {
346  // According to IEEE 754, the result of comparison [lhs] < infinity
347  // must be identical for all finite [lhs]. Same goes for operator >.
348  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
349  } else {
350  oss << "data[" << split_index << "].fvalue "
351  << semantic::OpName(op) << " " << threshold;
352  }
353  return oss.str();
354  };
355  }
356  std::vector<std::string> CommonHeader() const {
357  return {"#include <stdlib.h>",
358  "#include <string.h>",
359  "#include <math.h>",
360  "#include <stdint.h>",
361  "",
362  "union Entry {",
363  " int missing;",
364  " float fvalue;",
365  "};"};
366  }
367  std::vector<std::string> ConstantsPreamble() const {
368  return {};
369  }
370  std::vector<std::string> Preprocessing() const {
371  return {};
372  }
373  bool QuantizeFlag() const {
374  return false;
375  }
376 };
377 
378 class Quantize : private MetadataStore {
379  protected:
380  template <typename... Args>
381  void Init(Args&&... args) {
382  MetadataStore::Init(std::forward<Args>(args)...);
383  quant_preamble = {
384  std::string("for (int i = 0; i < ")
385  + std::to_string(GetInfo().num_feature) + "; ++i) {",
386  " if (data[i].missing != -1 && !is_categorical[i]) {",
387  " data[i].qvalue = quantize(data[i].fvalue, i);",
388  " }",
389  "}"};
390  }
391  NumericSplitCondition::NumericAdapter NumericAdapter() const {
392  const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
393  return [&cut_pts] (Operator op, unsigned split_index,
394  tl_float threshold) {
395  std::ostringstream oss;
396  const auto& v = cut_pts[split_index];
397  if (!std::isfinite(threshold)) {
398  // According to IEEE 754, the result of comparison [lhs] < infinity
399  // must be identical for all finite [lhs]. Same goes for operator >.
400  oss << (semantic::CompareWithOp(0.0, op, threshold) ? "1" : "0");
401  } else {
402  auto loc = common::binary_search(v.begin(), v.end(), threshold);
403  CHECK(loc != v.end());
404  oss << "data[" << split_index << "].qvalue " << semantic::OpName(op)
405  << " " << static_cast<size_t>(loc - v.begin()) * 2;
406  }
407  return oss.str();
408  };
409  }
410  std::vector<std::string> CommonHeader() const {
411  return {"#include <stdlib.h>",
412  "#include <string.h>",
413  "#include <math.h>",
414  "#include <stdint.h>",
415  "",
416  "union Entry {",
417  " int missing;",
418  " float fvalue;",
419  " int qvalue;",
420  "};"};
421  }
422  std::vector<std::string> ConstantsPreamble() const {
423  std::vector<std::string> ret;
424  ret.emplace_back("static const unsigned char is_categorical[] = {");
425  {
426  std::ostringstream oss, oss2;
427  size_t length = 2;
428  oss << " ";
429  const int num_feature = GetInfo().num_feature;
430  const auto& is_categorical = GetInfo().is_categorical;
431  for (int fid = 0; fid < num_feature; ++fid) {
432  if (is_categorical[fid]) {
433  common::WrapText(&oss, &length, "1", 80);
434  } else {
435  common::WrapText(&oss, &length, "0", 80);
436  }
437  }
438  ret.push_back(oss.str());
439  ret.emplace_back("};");
440  ret.emplace_back();
441  }
442  ret.emplace_back("static const float threshold[] = {");
443  {
444  std::ostringstream oss, oss2;
445  size_t length = 2;
446  oss << " ";
447  for (const auto& e : GetInfo().cut_pts) {
448  for (const auto& value : e) {
449  oss2.clear(); oss2.str(std::string()); oss2 << value;
450  common::WrapText(&oss, &length, oss2.str(), 80);
451  }
452  }
453  ret.push_back(oss.str());
454  ret.emplace_back("};");
455  ret.emplace_back();
456  }
457  ret.emplace_back("static const int th_begin[] = {");
458  {
459  std::ostringstream oss, oss2;
460  size_t length = 2;
461  size_t accum = 0;
462  oss << " ";
463  for (const auto& e : GetInfo().cut_pts) {
464  oss2.clear(); oss2.str(std::string()); oss2 << accum;
465  common::WrapText(&oss, &length, oss2.str(), 80);
466  accum += e.size();
467  }
468  ret.push_back(oss.str());
469  ret.emplace_back("};");
470  ret.emplace_back();
471  }
472  ret.emplace_back("static const int th_len[] = {");
473  {
474  std::ostringstream oss, oss2;
475  size_t length = 2;
476  oss << " ";
477  for (const auto& e : GetInfo().cut_pts) {
478  oss2.clear(); oss2.str(std::string()); oss2 << e.size();
479  common::WrapText(&oss, &length, oss2.str(), 80);
480  }
481  ret.push_back(oss.str());
482  ret.emplace_back("};");
483  ret.emplace_back();
484  }
485 
486  auto func = semantic::FunctionBlock(
487  "static inline int quantize(float val, unsigned fid)",
489  {"const float* array = &threshold[th_begin[fid]];",
490  "int len = th_len[fid];",
491  "int low = 0;",
492  "int high = len;",
493  "int mid;",
494  "float mval;",
495  "if (val < array[0]) {",
496  " return -10;",
497  "}",
498  "while (low + 1 < high) {",
499  " mid = (low + high) / 2;",
500  " mval = array[mid];",
501  " if (val == mval) {",
502  " return mid * 2;",
503  " } else if (val < mval) {",
504  " high = mid;",
505  " } else {",
506  " low = mid;",
507  " }",
508  "}",
509  "if (array[low] == val) {",
510  " return low * 2;",
511  "} else if (high == len) {",
512  " return len * 2;",
513  "} else {",
514  " return low * 2 + 1;",
515  "}"}), nullptr).Compile();
516  ret.insert(ret.end(), func.begin(), func.end());
517  return ret;
518  }
519  std::vector<std::string> Preprocessing() const {
520  return quant_preamble;
521  }
522  bool QuantizeFlag() const {
523  return true;
524  }
525  private:
526  std::vector<std::string> quant_preamble;
527 };
528 
529 inline std::vector<std::vector<tl_float>>
530 ExtractCutPoints(const Model& model) {
531  std::vector<std::vector<tl_float>> cut_pts;
532 
533  std::vector<std::set<tl_float>> thresh_;
534  cut_pts.resize(model.num_feature);
535  thresh_.resize(model.num_feature);
536  for (size_t i = 0; i < model.trees.size(); ++i) {
537  const Tree& tree = model.trees[i];
538  std::queue<int> Q;
539  Q.push(0);
540  while (!Q.empty()) {
541  const int nid = Q.front();
542  const Tree::Node& node = tree[nid];
543  Q.pop();
544  if (!node.is_leaf()) {
545  if (node.split_type() == SplitFeatureType::kNumerical) {
546  const tl_float threshold = node.threshold();
547  const unsigned split_index = node.split_index();
548  if (std::isfinite(threshold)) { // ignore infinity
549  thresh_[split_index].insert(threshold);
550  }
551  } else {
552  CHECK(node.split_type() == SplitFeatureType::kCategorical);
553  }
554  Q.push(node.cleft());
555  Q.push(node.cright());
556  }
557  }
558  }
559  for (int i = 0; i < model.num_feature; ++i) {
560  std::copy(thresh_[i].begin(), thresh_[i].end(),
561  std::back_inserter(cut_pts[i]));
562  }
563  return cut_pts;
564 }
565 
567 .describe("A compiler with a recursive approach")
568 .set_body([](const CompilerParam& param) -> Compiler* {
569  if (param.quantize > 0) {
570  return new RecursiveCompiler<Quantize>(param);
571  } else {
572  return new RecursiveCompiler<NoQuantize>(param);
573  }
574  });
575 } // namespace compiler
576 } // namespace treelite
577 
578 namespace {
579 
580 
581 inline void
582 GroupPolicy::Init(const treelite::Model& model) {
583  this->num_output_group = model.num_output_group;
584  this->random_forest_flag = model.random_forest_flag;
585 }
586 
587 inline std::string
588 GroupPolicy::GroupQueryFunction() const {
589  return "return " + std::to_string(num_output_group) + ";";
590 }
591 
592 inline std::string
593 GroupPolicy::Accumulator() const {
594  if (num_output_group > 1) {
595  return std::string("float sum[") + std::to_string(num_output_group)
596  + "] = {0.0f};";
597  } else {
598  return "float sum = 0.0f;";
599  }
600 }
601 
602 inline std::string
603 GroupPolicy::AccumulateTranslationUnit(size_t unit_id) const {
604  if (num_output_group > 1) {
605  return std::string("predict_margin_multiclass_unit")
606  + std::to_string(unit_id) + "(data, sum);";
607  } else {
608  return std::string("sum += predict_margin_unit")
609  + std::to_string(unit_id) + "(data);";
610  }
611 }
612 
613 inline std::vector<std::string>
614 GroupPolicy::AccumulateLeaf(const treelite::Tree::Node& node,
615  size_t tree_id) const {
616  if (num_output_group > 1) {
617  if (random_forest_flag) {
618  // multi-class classification with random forest
619  const std::vector<treelite::tl_float>& leaf_vector = node.leaf_vector();
620  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group))
621  << "Ill-formed model: leaf vector must be of length [num_output_group]";
622  std::vector<std::string> lines;
623  lines.reserve(num_output_group);
624  for (int group_id = 0; group_id < num_output_group; ++group_id) {
625  lines.push_back(std::string("sum[") + std::to_string(group_id)
626  + "] += (float)"
627  + treelite::common::ToString(leaf_vector[group_id]) + ";");
628  }
629  return lines;
630  } else {
631  // multi-class classification with gradient boosted trees
632  const treelite::tl_float leaf_value = node.leaf_value();
633  return { std::string("sum[") + std::to_string(tree_id % num_output_group)
634  + "] += (float)" + treelite::common::ToString(leaf_value) + ";" };
635  }
636  } else {
637  const treelite::tl_float leaf_value = node.leaf_value();
638  return {std::string("sum += (float)")
639  + treelite::common::ToString(leaf_value) + ";" };
640  }
641 }
642 
643 inline std::vector<std::string>
644 GroupPolicy::Return() const {
645  if (num_output_group > 1) {
646  return {std::string("for (int i = 0; i < ")
647  + std::to_string(num_output_group) + "; ++i) {",
648  " result[i] += sum[i];",
649  "}" };
650  } else {
651  return { "return sum;" };
652  }
653 }
654 
655 inline std::vector<std::string>
656 GroupPolicy::FinalReturn(size_t num_tree, float global_bias) const {
657  if (num_output_group > 1) {
658  if (random_forest_flag) {
659  // multi-class classification with random forest
660  return {std::string("for (int i = 0; i < ")
661  + std::to_string(num_output_group) + "; ++i) {",
662  std::string(" result[i] = sum[i] / ")
663  + std::to_string(num_tree) + " + ("
664  + treelite::common::ToString(global_bias) + ");",
665  "}"};
666  } else {
667  // multi-class classification with gradient boosted trees
668  return {std::string("for (int i = 0; i < ")
669  + std::to_string(num_output_group) + "; ++i) {",
670  " result[i] = sum[i] + ("
671  + treelite::common::ToString(global_bias) + ");",
672  "}"};
673  }
674  } else {
675  if (random_forest_flag) {
676  return { std::string("return sum / ") + std::to_string(num_tree) + " + ("
677  + treelite::common::ToString(global_bias) + ");" };
678  } else {
679  return { std::string("return sum + (")
680  + treelite::common::ToString(global_bias) + ");" };
681  }
682  }
683 }
684 
685 inline std::string
686 GroupPolicy::Prototype() const {
687  if (num_output_group > 1) {
688  return "void predict_margin_multiclass(union Entry* data, float* result)";
689  } else {
690  return "float predict_margin(union Entry* data)";
691  }
692 }
693 
694 inline std::string
695 GroupPolicy::PrototypeTranslationUnit(size_t unit_id) const {
696  if (num_output_group > 1) {
697  return std::string("void predict_margin_multiclass_unit")
698  + std::to_string(unit_id) + "(union Entry* data, float* result)";
699  } else {
700  return std::string("float predict_margin_unit")
701  + std::to_string(unit_id) + "(union Entry* data)";
702  }
703 }
704 
705 } // namespace anonymous
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:360
plain code block containing one or more lines of code
Definition: semantic.h:118
branch annotator class
Definition: annotator.h:16
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
fundamental block in semantic model. All code blocks should inherit from this class.
Definition: semantic.h:65
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:352
parameters for tree compiler
Definition: param.h:16
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:365
model structure for tree
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
Definition: recursive.cc:153
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:321
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:33
BranchHint
enum class to store branch annotation
Definition: semantic.h:17
tl_float threshold() const
Definition: tree.h:67
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
Definition: tree.h:57
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
Definition: annotator.h:51
int cright() const
index of right child
Definition: tree.h:30
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
Definition: annotator.cc:139
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
function block with a prototype and code body. Its prototype can optionally be registered with a func...
Definition: semantic.h:138
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:71
if-else statement with condition may store a branch hint (>50% or <50% likely)
Definition: semantic.h:194
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:363
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
Definition: common.h:32
a conditional expression
Definition: semantic.h:184
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
sequence of one or more code blocks
Definition: semantic.h:171
tl_float leaf_value() const
Definition: tree.h:50
Branch annotation tools.
Some useful utilities.
int cleft() const
index of left child
Definition: tree.h:26
tools to define prediction transform function
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
Definition: tree.h:257
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
Building blocks for semantic model of tree prediction code.
translation unit is abstraction of a source file
Definition: semantic.h:72
SplitFeatureType split_type() const
get feature split type
Definition: tree.h:91
semantic model consists of a header, function registry, and a list of translation units ...
Definition: semantic.h:90
int verbose
if >0, produce extra messages
Definition: param.h:34
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:357
Operator
comparison operators
Definition: base.h:23