13 #include <dmlc/registry.h> 30 const NumericAdapter& numeric_adapter)
31 : split_index(node.split_index()), default_left(node.default_left()),
32 op(node.comparison_op()), threshold(node.threshold()),
33 numeric_adapter(numeric_adapter) {}
35 NumericAdapter&& numeric_adapter)
36 : split_index(node.split_index()), default_left(node.default_left()),
37 op(node.comparison_op()), threshold(node.threshold()),
38 numeric_adapter(std::move(numeric_adapter)) {}
40 inline std::
string Compile()
const override {
41 const std::string bitmap
42 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
43 return ((default_left) ? (std::string(
"!(") + bitmap +
") || ")
44 : (std::string(
" (") + bitmap +
") && "))
45 + numeric_adapter(op, split_index, threshold);
53 NumericAdapter numeric_adapter;
59 : split_index(node.split_index()), default_left(node.default_left()),
60 categorical_bitmap(to_bitmap(node.left_categories())) {}
62 inline std::
string Compile()
const override {
63 const std::string bitmap
64 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
65 CHECK_GE(categorical_bitmap.size(), 1);
66 std::ostringstream comp;
67 comp <<
"(tmp = (unsigned int)(data[" << split_index <<
"].fvalue) ), " 68 <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 69 << categorical_bitmap[0] <<
"U >> tmp) & 1) )";
70 for (
size_t i = 1; i < categorical_bitmap.size(); ++i) {
71 comp <<
" || (tmp >= " << (i * 64)
72 <<
" && tmp < " << ((i + 1) * 64)
73 <<
" && (( (uint64_t)" << categorical_bitmap[i]
74 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
76 bool all_zeros =
true;
77 for (uint64_t e : categorical_bitmap) {
78 all_zeros &= (e == 0);
80 return ((default_left) ? (std::string(
"!(") + bitmap +
") || (")
81 : (std::string(
" (") + bitmap +
") && ("))
82 + (all_zeros ? std::string(
"0") : comp.str()) +
")";
88 std::vector<uint64_t> categorical_bitmap;
90 inline std::vector<uint64_t> to_bitmap(
const std::vector<uint32_t>& left_categories)
const {
91 const size_t num_left_categories = left_categories.size();
92 const uint32_t max_left_category = left_categories[num_left_categories - 1];
93 std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
94 for (
size_t i = 0; i < left_categories.size(); ++i) {
95 const uint32_t cat = left_categories[i];
96 const size_t idx = cat / 64;
97 const uint32_t offset = cat % 64;
98 bitmap[idx] |= (
static_cast<uint64_t
>(1) << offset);
107 std::string GroupQueryFunction()
const;
108 std::string Accumulator()
const;
109 std::string AccumulateTranslationUnit(
size_t unit_id)
const;
111 size_t tree_id)
const;
112 std::vector<std::string> Return()
const;
113 std::vector<std::string> FinalReturn(
size_t num_tree,
float global_bias)
const;
114 std::string Prototype()
const;
115 std::string PrototypeTranslationUnit(
size_t unit_id)
const;
117 int num_output_group;
118 bool random_forest_flag;
126 DMLC_REGISTRY_FILE_TAG(recursive);
128 std::vector<std::vector<tl_float>> ExtractCutPoints(
const Model& model);
132 std::vector<std::vector<tl_float>> cut_pts;
133 std::vector<bool> is_categorical;
135 inline void Init(
const Model& model,
bool extract_cut_pts =
false) {
137 is_categorical.clear();
138 is_categorical.resize(num_feature,
false);
141 is_categorical[e] =
true;
144 if (extract_cut_pts) {
145 cut_pts = std::move(ExtractCutPoints(model));
150 template <
typename QuantizePolicy>
156 LOG(INFO) <<
"Using RecursiveCompiler";
171 info.Init(model, QuantizePolicy::QuantizeFlag());
172 QuantizePolicy::Init(std::move(info));
173 group_policy.Init(model);
175 std::vector<std::vector<size_t>> annotation;
176 bool annotate =
false;
179 std::unique_ptr<dmlc::Stream> fi(
180 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
181 annotator.
Load(fi.get());
182 annotation = annotator.
Get();
185 LOG(INFO) <<
"Using branch annotation file `" 192 sequence.Reserve(model.
trees.size() + 3);
193 sequence.PushBack(
PlainBlock(group_policy.Accumulator()));
194 sequence.PushBack(
PlainBlock(QuantizePolicy::Preprocessing()));
196 LOG(INFO) <<
"Parallel compilation enabled; member trees will be " 198 <<
" translation units.";
200 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
202 group_policy.AccumulateTranslationUnit(unit_id)));
205 LOG(INFO) <<
"Parallel compilation disabled; all member trees will be " 206 <<
"dump to a single source file. This may increase " 207 <<
"compilation time and memory usage.";
208 for (
size_t tree_id = 0; tree_id < model.
trees.size(); ++tree_id) {
210 if (!annotation.empty()) {
211 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
212 annotation[tree_id])));
214 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
219 sequence.PushBack(
PlainBlock(group_policy.FinalReturn(model.
trees.size(),
222 FunctionBlock query_func(
"size_t get_num_output_group(void)",
223 PlainBlock(group_policy.GroupQueryFunction()),
224 &semantic_model.function_registry,
true);
228 &semantic_model.function_registry,
true);
229 FunctionBlock pred_transform_func(PredTransformPrototype(
false),
230 PlainBlock(PredTransformFunction(model,
false)),
231 &semantic_model.function_registry,
true);
232 FunctionBlock pred_transform_batch_func(PredTransformPrototype(
true),
233 PlainBlock(PredTransformFunction(model,
true)),
234 &semantic_model.function_registry,
true);
236 std::move(sequence), &semantic_model.function_registry,
true);
238 main_file.Reserve(5);
239 main_file.PushBack(std::move(query_func));
240 main_file.PushBack(std::move(query_func2));
241 main_file.PushBack(std::move(pred_transform_func));
242 main_file.PushBack(std::move(pred_transform_batch_func));
243 main_file.PushBack(std::move(main_func));
244 auto file_preamble = QuantizePolicy::ConstantsPreamble();
245 semantic_model.units.emplace_back(
PlainBlock(file_preamble),
246 std::move(main_file));
250 const size_t unit_size = (model.
trees.size() + nunit - 1) / nunit;
251 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
252 const size_t tree_begin = unit_id * unit_size;
253 const size_t tree_end = std::min((unit_id + 1) * unit_size,
256 if (tree_begin < tree_end) {
257 unit_seq.Reserve(tree_end - tree_begin + 2);
259 unit_seq.PushBack(
PlainBlock(group_policy.Accumulator()));
260 for (
size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
262 if (!annotation.empty()) {
263 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
264 annotation[tree_id])));
266 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
270 unit_seq.PushBack(
PlainBlock(group_policy.Return()));
271 FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
273 &semantic_model.function_registry);
274 semantic_model.units.emplace_back(
PlainBlock(), std::move(unit_func));
277 std::vector<std::string> header{
"#include <stdlib.h>",
278 "#include <string.h>",
280 "#include <stdint.h>"};
281 common::TransformPushBack(&header, QuantizePolicy::CommonHeader(),
282 [] (std::string line) {
return line; });
284 header.emplace_back();
285 #if defined(__clang__) || defined(__GNUC__) 287 header.emplace_back(
"#define LIKELY(x) __builtin_expect(!!(x), 1)");
288 header.emplace_back(
"#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
290 header.emplace_back(
"#define LIKELY(x) (x)");
291 header.emplace_back(
"#define UNLIKELY(x) (x)");
294 semantic_model.common_header
295 = std::move(common::make_unique<PlainBlock>(header));
296 return semantic_model;
301 GroupPolicy group_policy;
303 std::unique_ptr<CodeBlock> WalkTree(
const Tree& tree,
size_t tree_id,
304 const std::vector<size_t>& counts)
const {
305 return WalkTree_(tree, tree_id, counts, 0);
308 std::unique_ptr<CodeBlock> WalkTree_(
const Tree& tree,
size_t tree_id,
309 const std::vector<size_t>& counts,
314 return std::unique_ptr<CodeBlock>(
new PlainBlock(
315 group_policy.AccumulateLeaf(node, tree_id)));
317 BranchHint branch_hint = BranchHint::kNone;
318 if (!counts.empty()) {
319 const size_t left_count = counts[node.
cleft()];
320 const size_t right_count = counts[node.
cright()];
321 branch_hint = (left_count > right_count) ? BranchHint::kLikely
322 : BranchHint::kUnlikely;
324 std::unique_ptr<Condition> condition(
nullptr);
325 if (node.
split_type() == SplitFeatureType::kNumerical) {
326 condition = common::make_unique<NumericSplitCondition>(node,
327 QuantizePolicy::NumericAdapter());
329 condition = common::make_unique<CategoricalSplitCondition>(node);
332 common::MoveUniquePtr(condition),
333 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cleft())),
334 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cright())),
347 this->info = std::move(info);
361 template <
typename... Args>
362 void Init(Args&&... args) {
363 MetadataStore::Init(std::forward<Args>(args)...);
365 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
367 std::ostringstream oss;
368 if (!std::isfinite(threshold)) {
371 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
374 const std::streamsize ss = std::cout.precision();
375 oss <<
"data[" << split_index <<
"].fvalue " 376 << semantic::OpName(op) <<
" " 377 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
379 << std::setprecision(ss);
384 std::vector<std::string> CommonHeader()
const {
391 std::vector<std::string> ConstantsPreamble()
const {
394 std::vector<std::string> Preprocessing()
const {
397 bool QuantizeFlag()
const {
404 template <
typename... Args>
405 void Init(Args&&... args) {
406 MetadataStore::Init(std::forward<Args>(args)...);
408 std::string(
"for (int i = 0; i < ")
409 + std::to_string(GetInfo().num_feature) +
"; ++i) {",
410 " if (data[i].missing != -1 && !is_categorical[i]) {",
411 " data[i].qvalue = quantize(data[i].fvalue, i);",
415 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
416 const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
417 return [&cut_pts] (
Operator op,
unsigned split_index,
419 std::ostringstream oss;
420 const auto& v = cut_pts[split_index];
421 if (!std::isfinite(threshold)) {
424 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
426 auto loc = common::binary_search(v.begin(), v.end(), threshold);
427 CHECK(loc != v.end());
428 oss <<
"data[" << split_index <<
"].qvalue " << semantic::OpName(op)
429 <<
" " <<
static_cast<size_t>(loc - v.begin()) * 2;
434 std::vector<std::string> CommonHeader()
const {
442 std::vector<std::string> ConstantsPreamble()
const {
443 std::vector<std::string> ret;
444 ret.emplace_back(
"static const unsigned char is_categorical[] = {");
446 std::ostringstream oss, oss2;
449 const int num_feature = GetInfo().num_feature;
450 const auto& is_categorical = GetInfo().is_categorical;
451 for (
int fid = 0; fid < num_feature; ++fid) {
452 if (is_categorical[fid]) {
453 common::WrapText(&oss, &length,
"1", 80);
455 common::WrapText(&oss, &length,
"0", 80);
458 ret.push_back(oss.str());
459 ret.emplace_back(
"};");
462 ret.emplace_back(
"static const float threshold[] = {");
464 std::ostringstream oss, oss2;
467 for (
const auto& e : GetInfo().cut_pts) {
468 for (
const auto& value : e) {
469 oss2.clear(); oss2.str(std::string()); oss2 << value;
470 common::WrapText(&oss, &length, oss2.str(), 80);
473 ret.push_back(oss.str());
474 ret.emplace_back(
"};");
477 ret.emplace_back(
"static const int th_begin[] = {");
479 std::ostringstream oss, oss2;
483 for (
const auto& e : GetInfo().cut_pts) {
484 oss2.clear(); oss2.str(std::string()); oss2 << accum;
485 common::WrapText(&oss, &length, oss2.str(), 80);
488 ret.push_back(oss.str());
489 ret.emplace_back(
"};");
492 ret.emplace_back(
"static const int th_len[] = {");
494 std::ostringstream oss, oss2;
497 for (
const auto& e : GetInfo().cut_pts) {
498 oss2.clear(); oss2.str(std::string()); oss2 << e.size();
499 common::WrapText(&oss, &length, oss2.str(), 80);
501 ret.push_back(oss.str());
502 ret.emplace_back(
"};");
507 "static inline int quantize(float val, unsigned fid)",
509 {
"const float* array = &threshold[th_begin[fid]];",
510 "int len = th_len[fid];",
515 "if (val < array[0]) {",
518 "while (low + 1 < high) {",
519 " mid = (low + high) / 2;",
520 " mval = array[mid];",
521 " if (val == mval) {",
523 " } else if (val < mval) {",
529 "if (array[low] == val) {",
531 "} else if (high == len) {",
534 " return low * 2 + 1;",
535 "}"}),
nullptr).Compile();
536 ret.insert(ret.end(), func.begin(), func.end());
539 std::vector<std::string> Preprocessing()
const {
540 return quant_preamble;
542 bool QuantizeFlag()
const {
546 std::vector<std::string> quant_preamble;
549 inline std::vector<std::vector<tl_float>>
550 ExtractCutPoints(
const Model& model) {
551 std::vector<std::vector<tl_float>> cut_pts;
553 std::vector<std::set<tl_float>> thresh_;
556 for (
size_t i = 0; i < model.
trees.size(); ++i) {
561 const int nid = Q.front();
565 if (node.
split_type() == SplitFeatureType::kNumerical) {
568 if (std::isfinite(threshold)) {
569 thresh_[split_index].insert(threshold);
572 CHECK(node.
split_type() == SplitFeatureType::kCategorical);
574 Q.push(node.
cleft());
580 std::copy(thresh_[i].begin(), thresh_[i].end(),
581 std::back_inserter(cut_pts[i]));
587 .describe(
"A compiler with a recursive approach")
590 return new RecursiveCompiler<Quantize>(param);
592 return new RecursiveCompiler<NoQuantize>(param);
608 GroupPolicy::GroupQueryFunction()
const {
609 return "return " + std::to_string(num_output_group) +
";";
613 GroupPolicy::Accumulator()
const {
614 if (num_output_group > 1) {
615 return std::string(
"float sum[") + std::to_string(num_output_group)
616 +
"] = {0.0f};\n unsigned int tmp;";
618 return "float sum = 0.0f;\n unsigned int tmp;";
623 GroupPolicy::AccumulateTranslationUnit(
size_t unit_id)
const {
624 if (num_output_group > 1) {
625 return std::string(
"predict_margin_multiclass_unit")
626 + std::to_string(unit_id) +
"(data, sum);";
628 return std::string(
"sum += predict_margin_unit")
629 + std::to_string(unit_id) +
"(data);";
633 inline std::vector<std::string>
635 size_t tree_id)
const {
636 if (num_output_group > 1) {
637 if (random_forest_flag) {
639 const std::vector<treelite::tl_float>& leaf_vector = node.
leaf_vector();
640 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group))
641 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
642 std::vector<std::string> lines;
643 lines.reserve(num_output_group);
644 for (
int group_id = 0; group_id < num_output_group; ++group_id) {
645 lines.push_back(std::string(
"sum[") + std::to_string(group_id)
647 + treelite::common::ToString(leaf_vector[group_id]) +
";");
653 return { std::string(
"sum[") + std::to_string(tree_id % num_output_group)
654 +
"] += (float)" + treelite::common::ToString(leaf_value) +
";" };
658 return {std::string(
"sum += (float)")
659 + treelite::common::ToString(leaf_value) +
";" };
663 inline std::vector<std::string>
664 GroupPolicy::Return()
const {
665 if (num_output_group > 1) {
666 return {std::string(
"for (int i = 0; i < ")
667 + std::to_string(num_output_group) +
"; ++i) {",
668 " result[i] += sum[i];",
671 return {
"return sum;" };
675 inline std::vector<std::string>
676 GroupPolicy::FinalReturn(
size_t num_tree,
float global_bias)
const {
677 if (num_output_group > 1) {
678 if (random_forest_flag) {
680 return {std::string(
"for (int i = 0; i < ")
681 + std::to_string(num_output_group) +
"; ++i) {",
682 std::string(
" result[i] = sum[i] / ")
683 + std::to_string(num_tree) +
" + (" 684 + treelite::common::ToString(global_bias) +
");",
688 return {std::string(
"for (int i = 0; i < ")
689 + std::to_string(num_output_group) +
"; ++i) {",
690 " result[i] = sum[i] + (" 691 + treelite::common::ToString(global_bias) +
");",
695 if (random_forest_flag) {
696 return { std::string(
"return sum / ") + std::to_string(num_tree) +
" + (" 697 + treelite::common::ToString(global_bias) +
");" };
699 return { std::string(
"return sum + (")
700 + treelite::common::ToString(global_bias) +
");" };
706 GroupPolicy::Prototype()
const {
707 if (num_output_group > 1) {
708 return "void predict_margin_multiclass(union Entry* data, float* result)";
710 return "float predict_margin(union Entry* data)";
715 GroupPolicy::PrototypeTranslationUnit(
size_t unit_id)
const {
716 if (num_output_group > 1) {
717 return std::string(
"void predict_margin_multiclass_unit")
718 + std::to_string(unit_id) +
"(union Entry* data, float* result)";
720 return std::string(
"float predict_margin_unit")
721 + std::to_string(unit_id) +
"(union Entry* data)";
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
plain code block containing one or more lines of code
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
fundamental block in semantic model. All code blocks should inherit from this class.
std::vector< Tree > trees
member trees
parameters for tree compiler
unsigned split_index() const
feature index of split condition
ModelParam param
extra parameters
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
in-memory representation of a decision tree
float global_bias
global bias of the model
Parameters for tree compiler.
BranchHint
enum class to store branch annotation
tl_float threshold() const
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
int cright() const
index of right child
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
function block with a prototype and code body. Its prototype can optionally be registered with a func...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
if-else statement with a condition; may store a branch hint (>50% or <50% likely)
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
sequence of one or more code blocks
tl_float leaf_value() const
int cleft() const
index of left child
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
int quantize
whether to quantize threshold points (0: no, >0: yes)
Building blocks for semantic model of tree prediction code.
translation unit is abstraction of a source file
SplitFeatureType split_type() const
get feature split type
semantic model consists of a header, function registry, and a list of translation units ...
int verbose
if >0, produce extra messages
bool is_leaf() const
whether current node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators