13 #include <dmlc/registry.h> 30 const NumericAdapter& numeric_adapter)
31 : split_index(node.split_index()), default_left(node.default_left()),
32 op(node.comparison_op()), threshold(node.threshold()),
33 numeric_adapter(numeric_adapter) {}
35 NumericAdapter&& numeric_adapter)
36 : split_index(node.split_index()), default_left(node.default_left()),
37 op(node.comparison_op()), threshold(node.threshold()),
38 numeric_adapter(std::move(numeric_adapter)) {}
40 inline std::
string Compile()
const override {
41 const std::string bitmap
42 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
43 return ((default_left) ? (std::string(
"!(") + bitmap +
") || ")
44 : (std::string(
" (") + bitmap +
") && "))
45 + numeric_adapter(op, split_index, threshold);
53 NumericAdapter numeric_adapter;
59 : split_index(node.split_index()), default_left(node.default_left()),
60 categorical_bitmap(to_bitmap(node.left_categories())) {}
62 inline std::
string Compile()
const override {
63 const std::string bitmap
64 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
65 const std::string comp
66 = std::string(
"((") + std::to_string(categorical_bitmap)
67 +
"U >> (unsigned int)(data[" + std::to_string(split_index)
69 return ((default_left) ? (std::string(
"!(") + bitmap +
") || ")
70 : (std::string(
" (") + bitmap +
") && "))
71 + ((categorical_bitmap == 0) ? std::string(
"0") : comp);
77 uint64_t categorical_bitmap;
79 inline uint64_t to_bitmap(
const std::vector<uint8_t>& left_categories)
const {
81 for (uint8_t e : left_categories) {
82 CHECK_LT(e, 64) <<
"Cannot have more than 64 categories in a feature";
83 result |= (
static_cast<uint64_t
>(1) << e);
92 std::string GroupQueryFunction()
const;
93 std::string Accumulator()
const;
94 std::string AccumulateTranslationUnit(
size_t unit_id)
const;
96 size_t tree_id)
const;
97 std::vector<std::string> Return()
const;
98 std::vector<std::string> FinalReturn(
size_t num_tree,
float global_bias)
const;
99 std::string Prototype()
const;
100 std::string PrototypeTranslationUnit(
size_t unit_id)
const;
102 int num_output_group;
103 bool random_forest_flag;
111 DMLC_REGISTRY_FILE_TAG(recursive);
113 std::vector<std::vector<tl_float>> ExtractCutPoints(
const Model& model);
117 std::vector<std::vector<tl_float>> cut_pts;
118 std::vector<bool> is_categorical;
120 inline void Init(
const Model& model,
bool extract_cut_pts =
false) {
122 is_categorical.clear();
123 is_categorical.resize(num_feature,
false);
126 is_categorical[e] =
true;
129 if (extract_cut_pts) {
130 cut_pts = std::move(ExtractCutPoints(model));
135 template <
typename QuantizePolicy>
141 LOG(INFO) <<
"Using RecursiveCompiler";
156 info.Init(model, QuantizePolicy::QuantizeFlag());
157 QuantizePolicy::Init(std::move(info));
158 group_policy.Init(model);
160 std::vector<std::vector<size_t>> annotation;
161 bool annotate =
false;
164 std::unique_ptr<dmlc::Stream> fi(
165 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
166 annotator.
Load(fi.get());
167 annotation = annotator.
Get();
170 LOG(INFO) <<
"Using branch annotation file `" 177 sequence.Reserve(model.
trees.size() + 3);
178 sequence.PushBack(
PlainBlock(group_policy.Accumulator()));
179 sequence.PushBack(
PlainBlock(QuantizePolicy::Preprocessing()));
181 LOG(INFO) <<
"Parallel compilation enabled; member trees will be " 183 <<
" translation units.";
185 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
187 group_policy.AccumulateTranslationUnit(unit_id)));
190 LOG(INFO) <<
"Parallel compilation disabled; all member trees will be " 191 <<
"dump to a single source file. This may increase " 192 <<
"compilation time and memory usage.";
193 for (
size_t tree_id = 0; tree_id < model.
trees.size(); ++tree_id) {
195 if (!annotation.empty()) {
196 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
197 annotation[tree_id])));
199 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
204 sequence.PushBack(
PlainBlock(group_policy.FinalReturn(model.
trees.size(),
207 FunctionBlock query_func(
"size_t get_num_output_group(void)",
208 PlainBlock(group_policy.GroupQueryFunction()),
209 &semantic_model.function_registry,
true);
213 &semantic_model.function_registry,
true);
214 FunctionBlock pred_transform_func(PredTransformPrototype(
false),
215 PlainBlock(PredTransformFunction(model,
false)),
216 &semantic_model.function_registry,
true);
217 FunctionBlock pred_transform_batch_func(PredTransformPrototype(
true),
218 PlainBlock(PredTransformFunction(model,
true)),
219 &semantic_model.function_registry,
true);
221 std::move(sequence), &semantic_model.function_registry,
true);
223 main_file.Reserve(5);
224 main_file.PushBack(std::move(query_func));
225 main_file.PushBack(std::move(query_func2));
226 main_file.PushBack(std::move(pred_transform_func));
227 main_file.PushBack(std::move(pred_transform_batch_func));
228 main_file.PushBack(std::move(main_func));
229 auto file_preamble = QuantizePolicy::ConstantsPreamble();
230 semantic_model.units.emplace_back(
PlainBlock(file_preamble),
231 std::move(main_file));
235 const size_t unit_size = (model.
trees.size() + nunit - 1) / nunit;
236 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
237 const size_t tree_begin = unit_id * unit_size;
238 const size_t tree_end = std::min((unit_id + 1) * unit_size,
241 unit_seq.Reserve(tree_end - tree_begin + 2);
242 unit_seq.PushBack(
PlainBlock(group_policy.Accumulator()));
243 for (
size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
245 if (!annotation.empty()) {
246 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
247 annotation[tree_id])));
249 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
253 unit_seq.PushBack(
PlainBlock(group_policy.Return()));
254 FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
256 &semantic_model.function_registry);
257 semantic_model.units.emplace_back(
PlainBlock(), std::move(unit_func));
260 auto header = QuantizePolicy::CommonHeader();
262 header.emplace_back();
263 #if defined(__clang__) || defined(__GNUC__) 265 header.emplace_back(
"#define LIKELY(x) __builtin_expect(!!(x), 1)");
266 header.emplace_back(
"#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
268 header.emplace_back(
"#define LIKELY(x) (x)");
269 header.emplace_back(
"#define UNLIKELY(x) (x)");
272 semantic_model.common_header
273 = std::move(common::make_unique<PlainBlock>(header));
274 return semantic_model;
279 GroupPolicy group_policy;
281 std::unique_ptr<CodeBlock> WalkTree(
const Tree& tree,
size_t tree_id,
282 const std::vector<size_t>& counts)
const {
283 return WalkTree_(tree, tree_id, counts, 0);
286 std::unique_ptr<CodeBlock> WalkTree_(
const Tree& tree,
size_t tree_id,
287 const std::vector<size_t>& counts,
292 return std::unique_ptr<CodeBlock>(
new PlainBlock(
293 group_policy.AccumulateLeaf(node, tree_id)));
295 BranchHint branch_hint = BranchHint::kNone;
296 if (!counts.empty()) {
297 const size_t left_count = counts[node.
cleft()];
298 const size_t right_count = counts[node.
cright()];
299 branch_hint = (left_count > right_count) ? BranchHint::kLikely
300 : BranchHint::kUnlikely;
302 std::unique_ptr<Condition> condition(
nullptr);
303 if (node.
split_type() == SplitFeatureType::kNumerical) {
304 condition = common::make_unique<NumericSplitCondition>(node,
305 QuantizePolicy::NumericAdapter());
307 condition = common::make_unique<CategoricalSplitCondition>(node);
310 common::MoveUniquePtr(condition),
311 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cleft())),
312 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cright())),
325 this->info = std::move(info);
339 template <
typename... Args>
340 void Init(Args&&... args) {
341 MetadataStore::Init(std::forward<Args>(args)...);
343 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
345 std::ostringstream oss;
346 if (!std::isfinite(threshold)) {
349 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
352 const std::streamsize ss = std::cout.precision();
353 oss <<
"data[" << split_index <<
"].fvalue " 354 << semantic::OpName(op) <<
" " 355 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
357 << std::setprecision(ss);
362 std::vector<std::string> CommonHeader()
const {
363 return {
"#include <stdlib.h>",
364 "#include <string.h>",
366 "#include <stdint.h>",
373 std::vector<std::string> ConstantsPreamble()
const {
376 std::vector<std::string> Preprocessing()
const {
379 bool QuantizeFlag()
const {
386 template <
typename... Args>
387 void Init(Args&&... args) {
388 MetadataStore::Init(std::forward<Args>(args)...);
390 std::string(
"for (int i = 0; i < ")
391 + std::to_string(GetInfo().num_feature) +
"; ++i) {",
392 " if (data[i].missing != -1 && !is_categorical[i]) {",
393 " data[i].qvalue = quantize(data[i].fvalue, i);",
397 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
398 const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
399 return [&cut_pts] (
Operator op,
unsigned split_index,
401 std::ostringstream oss;
402 const auto& v = cut_pts[split_index];
403 if (!std::isfinite(threshold)) {
406 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
408 auto loc = common::binary_search(v.begin(), v.end(), threshold);
409 CHECK(loc != v.end());
410 oss <<
"data[" << split_index <<
"].qvalue " << semantic::OpName(op)
411 <<
" " <<
static_cast<size_t>(loc - v.begin()) * 2;
416 std::vector<std::string> CommonHeader()
const {
417 return {
"#include <stdlib.h>",
418 "#include <string.h>",
420 "#include <stdint.h>",
428 std::vector<std::string> ConstantsPreamble()
const {
429 std::vector<std::string> ret;
430 ret.emplace_back(
"static const unsigned char is_categorical[] = {");
432 std::ostringstream oss, oss2;
435 const int num_feature = GetInfo().num_feature;
436 const auto& is_categorical = GetInfo().is_categorical;
437 for (
int fid = 0; fid < num_feature; ++fid) {
438 if (is_categorical[fid]) {
439 common::WrapText(&oss, &length,
"1", 80);
441 common::WrapText(&oss, &length,
"0", 80);
444 ret.push_back(oss.str());
445 ret.emplace_back(
"};");
448 ret.emplace_back(
"static const float threshold[] = {");
450 std::ostringstream oss, oss2;
453 for (
const auto& e : GetInfo().cut_pts) {
454 for (
const auto& value : e) {
455 oss2.clear(); oss2.str(std::string()); oss2 << value;
456 common::WrapText(&oss, &length, oss2.str(), 80);
459 ret.push_back(oss.str());
460 ret.emplace_back(
"};");
463 ret.emplace_back(
"static const int th_begin[] = {");
465 std::ostringstream oss, oss2;
469 for (
const auto& e : GetInfo().cut_pts) {
470 oss2.clear(); oss2.str(std::string()); oss2 << accum;
471 common::WrapText(&oss, &length, oss2.str(), 80);
474 ret.push_back(oss.str());
475 ret.emplace_back(
"};");
478 ret.emplace_back(
"static const int th_len[] = {");
480 std::ostringstream oss, oss2;
483 for (
const auto& e : GetInfo().cut_pts) {
484 oss2.clear(); oss2.str(std::string()); oss2 << e.size();
485 common::WrapText(&oss, &length, oss2.str(), 80);
487 ret.push_back(oss.str());
488 ret.emplace_back(
"};");
493 "static inline int quantize(float val, unsigned fid)",
495 {
"const float* array = &threshold[th_begin[fid]];",
496 "int len = th_len[fid];",
501 "if (val < array[0]) {",
504 "while (low + 1 < high) {",
505 " mid = (low + high) / 2;",
506 " mval = array[mid];",
507 " if (val == mval) {",
509 " } else if (val < mval) {",
515 "if (array[low] == val) {",
517 "} else if (high == len) {",
520 " return low * 2 + 1;",
521 "}"}),
nullptr).Compile();
522 ret.insert(ret.end(), func.begin(), func.end());
525 std::vector<std::string> Preprocessing()
const {
526 return quant_preamble;
528 bool QuantizeFlag()
const {
532 std::vector<std::string> quant_preamble;
535 inline std::vector<std::vector<tl_float>>
536 ExtractCutPoints(
const Model& model) {
537 std::vector<std::vector<tl_float>> cut_pts;
539 std::vector<std::set<tl_float>> thresh_;
542 for (
size_t i = 0; i < model.
trees.size(); ++i) {
547 const int nid = Q.front();
551 if (node.
split_type() == SplitFeatureType::kNumerical) {
554 if (std::isfinite(threshold)) {
555 thresh_[split_index].insert(threshold);
558 CHECK(node.
split_type() == SplitFeatureType::kCategorical);
560 Q.push(node.
cleft());
566 std::copy(thresh_[i].begin(), thresh_[i].end(),
567 std::back_inserter(cut_pts[i]));
573 .describe(
"A compiler with a recursive approach")
576 return new RecursiveCompiler<Quantize>(param);
578 return new RecursiveCompiler<NoQuantize>(param);
594 GroupPolicy::GroupQueryFunction()
const {
595 return "return " + std::to_string(num_output_group) +
";";
599 GroupPolicy::Accumulator()
const {
600 if (num_output_group > 1) {
601 return std::string(
"float sum[") + std::to_string(num_output_group)
604 return "float sum = 0.0f;";
609 GroupPolicy::AccumulateTranslationUnit(
size_t unit_id)
const {
610 if (num_output_group > 1) {
611 return std::string(
"predict_margin_multiclass_unit")
612 + std::to_string(unit_id) +
"(data, sum);";
614 return std::string(
"sum += predict_margin_unit")
615 + std::to_string(unit_id) +
"(data);";
619 inline std::vector<std::string>
621 size_t tree_id)
const {
622 if (num_output_group > 1) {
623 if (random_forest_flag) {
625 const std::vector<treelite::tl_float>& leaf_vector = node.
leaf_vector();
626 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group))
627 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
628 std::vector<std::string> lines;
629 lines.reserve(num_output_group);
630 for (
int group_id = 0; group_id < num_output_group; ++group_id) {
631 lines.push_back(std::string(
"sum[") + std::to_string(group_id)
633 + treelite::common::ToString(leaf_vector[group_id]) +
";");
639 return { std::string(
"sum[") + std::to_string(tree_id % num_output_group)
640 +
"] += (float)" + treelite::common::ToString(leaf_value) +
";" };
644 return {std::string(
"sum += (float)")
645 + treelite::common::ToString(leaf_value) +
";" };
649 inline std::vector<std::string>
650 GroupPolicy::Return()
const {
651 if (num_output_group > 1) {
652 return {std::string(
"for (int i = 0; i < ")
653 + std::to_string(num_output_group) +
"; ++i) {",
654 " result[i] += sum[i];",
657 return {
"return sum;" };
661 inline std::vector<std::string>
662 GroupPolicy::FinalReturn(
size_t num_tree,
float global_bias)
const {
663 if (num_output_group > 1) {
664 if (random_forest_flag) {
666 return {std::string(
"for (int i = 0; i < ")
667 + std::to_string(num_output_group) +
"; ++i) {",
668 std::string(
" result[i] = sum[i] / ")
669 + std::to_string(num_tree) +
" + (" 670 + treelite::common::ToString(global_bias) +
");",
674 return {std::string(
"for (int i = 0; i < ")
675 + std::to_string(num_output_group) +
"; ++i) {",
676 " result[i] = sum[i] + (" 677 + treelite::common::ToString(global_bias) +
");",
681 if (random_forest_flag) {
682 return { std::string(
"return sum / ") + std::to_string(num_tree) +
" + (" 683 + treelite::common::ToString(global_bias) +
");" };
685 return { std::string(
"return sum + (")
686 + treelite::common::ToString(global_bias) +
");" };
692 GroupPolicy::Prototype()
const {
693 if (num_output_group > 1) {
694 return "void predict_margin_multiclass(union Entry* data, float* result)";
696 return "float predict_margin(union Entry* data)";
701 GroupPolicy::PrototypeTranslationUnit(
size_t unit_id)
const {
702 if (num_output_group > 1) {
703 return std::string(
"void predict_margin_multiclass_unit")
704 + std::to_string(unit_id) +
"(union Entry* data, float* result)";
706 return std::string(
"float predict_margin_unit")
707 + std::to_string(unit_id) +
"(union Entry* data)";
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
plain code block containing one or more lines of code
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
fundamental block in semantic model. All code blocks should inherit from this class.
std::vector< Tree > trees
member trees
parameters for tree compiler
unsigned split_index() const
feature index of split condition
ModelParam param
extra parameters
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
in-memory representation of a decision tree
float global_bias
global bias of the model
Parameters for tree compiler.
BranchHint
enum class to store branch annotation
tl_float threshold() const
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
int cright() const
index of right child
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
function block with a prototype and code body. Its prototype can optionally be registered with a func...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
if-else statement with a condition; may store a branch hint (>50% or <50% likely)
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
sequence of one or more code blocks
tl_float leaf_value() const
int cleft() const
index of left child
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
int quantize
whether to quantize threshold points (0: no, >0: yes)
Building blocks for semantic model of tree prediction code.
a translation unit is an abstraction of a source file
SplitFeatureType split_type() const
get feature split type
semantic model consists of a header, function registry, and a list of translation units ...
int verbose
if >0, produce extra messages
bool is_leaf() const
whether current node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators