13 #include <dmlc/registry.h> 29 const NumericAdapter& numeric_adapter)
30 : split_index(node.split_index()), default_left(node.default_left()),
31 op(node.comparison_op()), threshold(node.threshold()),
32 numeric_adapter(numeric_adapter) {}
34 NumericAdapter&& numeric_adapter)
35 : split_index(node.split_index()), default_left(node.default_left()),
36 op(node.comparison_op()), threshold(node.threshold()),
37 numeric_adapter(std::move(numeric_adapter)) {}
39 inline std::
string Compile()
const override {
40 const std::string bitmap
41 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
42 return ((default_left) ? (std::string(
"!(") + bitmap +
") || ")
43 : (std::string(
" (") + bitmap +
") && "))
44 + numeric_adapter(op, split_index, threshold);
52 NumericAdapter numeric_adapter;
58 : split_index(node.split_index()), default_left(node.default_left()),
59 categorical_bitmap(to_bitmap(node.left_categories())) {}
61 inline std::
string Compile()
const override {
62 const std::string bitmap
63 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
64 const std::string comp
65 = std::string(
"((") + std::to_string(categorical_bitmap)
66 +
"U >> (unsigned int)(data[" + std::to_string(split_index)
68 return ((default_left) ? (std::string(
"!(") + bitmap +
") || ")
69 : (std::string(
" (") + bitmap +
") && "))
70 + ((categorical_bitmap == 0) ? std::string(
"0") : comp);
76 uint64_t categorical_bitmap;
78 inline uint64_t to_bitmap(
const std::vector<uint8_t>& left_categories)
const {
80 for (uint8_t e : left_categories) {
81 CHECK_LT(e, 64) <<
"Cannot have more than 64 categories in a feature";
82 result |= (
static_cast<uint64_t
>(1) << e);
91 std::string GroupQueryFunction()
const;
92 std::string Accumulator()
const;
93 std::string AccumulateTranslationUnit(
size_t unit_id)
const;
95 size_t tree_id)
const;
96 std::vector<std::string> Return()
const;
97 std::vector<std::string> FinalReturn(
size_t num_tree,
float global_bias)
const;
98 std::string Prototype()
const;
99 std::string PrototypeTranslationUnit(
size_t unit_id)
const;
101 int num_output_group;
102 bool random_forest_flag;
110 DMLC_REGISTRY_FILE_TAG(recursive);
112 std::vector<std::vector<tl_float>> ExtractCutPoints(
const Model& model);
116 std::vector<std::vector<tl_float>> cut_pts;
117 std::vector<bool> is_categorical;
119 inline void Init(
const Model& model,
bool extract_cut_pts =
false) {
121 is_categorical.clear();
122 is_categorical.resize(num_feature,
false);
125 is_categorical[e] =
true;
128 if (extract_cut_pts) {
129 cut_pts = std::move(ExtractCutPoints(model));
134 template <
typename QuantizePolicy>
140 LOG(INFO) <<
"Using RecursiveCompiler";
155 info.Init(model, QuantizePolicy::QuantizeFlag());
156 QuantizePolicy::Init(std::move(info));
157 group_policy.Init(model);
159 std::vector<std::vector<size_t>> annotation;
160 bool annotate =
false;
163 std::unique_ptr<dmlc::Stream> fi(
164 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
165 annotator.
Load(fi.get());
166 annotation = annotator.
Get();
169 LOG(INFO) <<
"Using branch annotation file `" 176 sequence.Reserve(model.
trees.size() + 3);
177 sequence.PushBack(
PlainBlock(group_policy.Accumulator()));
178 sequence.PushBack(
PlainBlock(QuantizePolicy::Preprocessing()));
180 LOG(INFO) <<
"Parallel compilation enabled; member trees will be " 182 <<
" translation units.";
184 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
186 group_policy.AccumulateTranslationUnit(unit_id)));
189 LOG(INFO) <<
"Parallel compilation disabled; all member trees will be " 190 <<
"dump to a single source file. This may increase " 191 <<
"compilation time and memory usage.";
192 for (
size_t tree_id = 0; tree_id < model.
trees.size(); ++tree_id) {
194 if (!annotation.empty()) {
195 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
196 annotation[tree_id])));
198 sequence.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
203 sequence.PushBack(
PlainBlock(group_policy.FinalReturn(model.
trees.size(),
206 FunctionBlock query_func(
"size_t get_num_output_group(void)",
207 PlainBlock(group_policy.GroupQueryFunction()),
208 &semantic_model.function_registry,
true);
212 &semantic_model.function_registry,
true);
213 FunctionBlock pred_transform_func(PredTransformPrototype(
false),
214 PlainBlock(PredTransformFunction(model,
false)),
215 &semantic_model.function_registry,
true);
216 FunctionBlock pred_transform_batch_func(PredTransformPrototype(
true),
217 PlainBlock(PredTransformFunction(model,
true)),
218 &semantic_model.function_registry,
true);
220 std::move(sequence), &semantic_model.function_registry,
true);
222 main_file.Reserve(5);
223 main_file.PushBack(std::move(query_func));
224 main_file.PushBack(std::move(query_func2));
225 main_file.PushBack(std::move(pred_transform_func));
226 main_file.PushBack(std::move(pred_transform_batch_func));
227 main_file.PushBack(std::move(main_func));
228 auto file_preamble = QuantizePolicy::ConstantsPreamble();
229 semantic_model.units.emplace_back(
PlainBlock(file_preamble),
230 std::move(main_file));
234 const size_t unit_size = (model.
trees.size() + nunit - 1) / nunit;
235 for (
size_t unit_id = 0; unit_id < nunit; ++unit_id) {
236 const size_t tree_begin = unit_id * unit_size;
237 const size_t tree_end = std::min((unit_id + 1) * unit_size,
240 unit_seq.Reserve(tree_end - tree_begin + 2);
241 unit_seq.PushBack(
PlainBlock(group_policy.Accumulator()));
242 for (
size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
244 if (!annotation.empty()) {
245 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
246 annotation[tree_id])));
248 unit_seq.PushBack(common::MoveUniquePtr(WalkTree(tree, tree_id,
252 unit_seq.PushBack(
PlainBlock(group_policy.Return()));
253 FunctionBlock unit_func(group_policy.PrototypeTranslationUnit(unit_id),
255 &semantic_model.function_registry);
256 semantic_model.units.emplace_back(
PlainBlock(), std::move(unit_func));
259 auto header = QuantizePolicy::CommonHeader();
261 header.emplace_back();
262 #if defined(__clang__) || defined(__GNUC__) 264 header.emplace_back(
"#define LIKELY(x) __builtin_expect(!!(x), 1)");
265 header.emplace_back(
"#define UNLIKELY(x) __builtin_expect(!!(x), 0)");
267 header.emplace_back(
"#define LIKELY(x) (x)");
268 header.emplace_back(
"#define UNLIKELY(x) (x)");
271 semantic_model.common_header
272 = std::move(common::make_unique<PlainBlock>(header));
273 return semantic_model;
278 GroupPolicy group_policy;
280 std::unique_ptr<CodeBlock> WalkTree(
const Tree& tree,
size_t tree_id,
281 const std::vector<size_t>& counts)
const {
282 return WalkTree_(tree, tree_id, counts, 0);
285 std::unique_ptr<CodeBlock> WalkTree_(
const Tree& tree,
size_t tree_id,
286 const std::vector<size_t>& counts,
291 return std::unique_ptr<CodeBlock>(
new PlainBlock(
292 group_policy.AccumulateLeaf(node, tree_id)));
294 BranchHint branch_hint = BranchHint::kNone;
295 if (!counts.empty()) {
296 const size_t left_count = counts[node.
cleft()];
297 const size_t right_count = counts[node.
cright()];
298 branch_hint = (left_count > right_count) ? BranchHint::kLikely
299 : BranchHint::kUnlikely;
301 std::unique_ptr<Condition> condition(
nullptr);
302 if (node.
split_type() == SplitFeatureType::kNumerical) {
303 condition = common::make_unique<NumericSplitCondition>(node,
304 QuantizePolicy::NumericAdapter());
306 condition = common::make_unique<CategoricalSplitCondition>(node);
309 common::MoveUniquePtr(condition),
310 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cleft())),
311 common::MoveUniquePtr(WalkTree_(tree, tree_id, counts, node.
cright())),
324 this->info = std::move(info);
338 template <
typename... Args>
339 void Init(Args&&... args) {
340 MetadataStore::Init(std::forward<Args>(args)...);
342 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
344 std::ostringstream oss;
345 if (!std::isfinite(threshold)) {
348 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
350 oss <<
"data[" << split_index <<
"].fvalue " 351 << semantic::OpName(op) <<
" " << threshold;
356 std::vector<std::string> CommonHeader()
const {
357 return {
"#include <stdlib.h>",
358 "#include <string.h>",
360 "#include <stdint.h>",
367 std::vector<std::string> ConstantsPreamble()
const {
370 std::vector<std::string> Preprocessing()
const {
373 bool QuantizeFlag()
const {
380 template <
typename... Args>
381 void Init(Args&&... args) {
382 MetadataStore::Init(std::forward<Args>(args)...);
384 std::string(
"for (int i = 0; i < ")
385 + std::to_string(GetInfo().num_feature) +
"; ++i) {",
386 " if (data[i].missing != -1 && !is_categorical[i]) {",
387 " data[i].qvalue = quantize(data[i].fvalue, i);",
391 NumericSplitCondition::NumericAdapter NumericAdapter()
const {
392 const std::vector<std::vector<tl_float>>& cut_pts = GetInfo().cut_pts;
393 return [&cut_pts] (
Operator op,
unsigned split_index,
395 std::ostringstream oss;
396 const auto& v = cut_pts[split_index];
397 if (!std::isfinite(threshold)) {
400 oss << (semantic::CompareWithOp(0.0, op, threshold) ?
"1" :
"0");
402 auto loc = common::binary_search(v.begin(), v.end(), threshold);
403 CHECK(loc != v.end());
404 oss <<
"data[" << split_index <<
"].qvalue " << semantic::OpName(op)
405 <<
" " <<
static_cast<size_t>(loc - v.begin()) * 2;
410 std::vector<std::string> CommonHeader()
const {
411 return {
"#include <stdlib.h>",
412 "#include <string.h>",
414 "#include <stdint.h>",
422 std::vector<std::string> ConstantsPreamble()
const {
423 std::vector<std::string> ret;
424 ret.emplace_back(
"static const unsigned char is_categorical[] = {");
426 std::ostringstream oss, oss2;
429 const int num_feature = GetInfo().num_feature;
430 const auto& is_categorical = GetInfo().is_categorical;
431 for (
int fid = 0; fid < num_feature; ++fid) {
432 if (is_categorical[fid]) {
433 common::WrapText(&oss, &length,
"1", 80);
435 common::WrapText(&oss, &length,
"0", 80);
438 ret.push_back(oss.str());
439 ret.emplace_back(
"};");
442 ret.emplace_back(
"static const float threshold[] = {");
444 std::ostringstream oss, oss2;
447 for (
const auto& e : GetInfo().cut_pts) {
448 for (
const auto& value : e) {
449 oss2.clear(); oss2.str(std::string()); oss2 << value;
450 common::WrapText(&oss, &length, oss2.str(), 80);
453 ret.push_back(oss.str());
454 ret.emplace_back(
"};");
457 ret.emplace_back(
"static const int th_begin[] = {");
459 std::ostringstream oss, oss2;
463 for (
const auto& e : GetInfo().cut_pts) {
464 oss2.clear(); oss2.str(std::string()); oss2 << accum;
465 common::WrapText(&oss, &length, oss2.str(), 80);
468 ret.push_back(oss.str());
469 ret.emplace_back(
"};");
472 ret.emplace_back(
"static const int th_len[] = {");
474 std::ostringstream oss, oss2;
477 for (
const auto& e : GetInfo().cut_pts) {
478 oss2.clear(); oss2.str(std::string()); oss2 << e.size();
479 common::WrapText(&oss, &length, oss2.str(), 80);
481 ret.push_back(oss.str());
482 ret.emplace_back(
"};");
487 "static inline int quantize(float val, unsigned fid)",
489 {
"const float* array = &threshold[th_begin[fid]];",
490 "int len = th_len[fid];",
495 "if (val < array[0]) {",
498 "while (low + 1 < high) {",
499 " mid = (low + high) / 2;",
500 " mval = array[mid];",
501 " if (val == mval) {",
503 " } else if (val < mval) {",
509 "if (array[low] == val) {",
511 "} else if (high == len) {",
514 " return low * 2 + 1;",
515 "}"}),
nullptr).Compile();
516 ret.insert(ret.end(), func.begin(), func.end());
519 std::vector<std::string> Preprocessing()
const {
520 return quant_preamble;
522 bool QuantizeFlag()
const {
526 std::vector<std::string> quant_preamble;
529 inline std::vector<std::vector<tl_float>>
530 ExtractCutPoints(
const Model& model) {
531 std::vector<std::vector<tl_float>> cut_pts;
533 std::vector<std::set<tl_float>> thresh_;
536 for (
size_t i = 0; i < model.
trees.size(); ++i) {
541 const int nid = Q.front();
545 if (node.
split_type() == SplitFeatureType::kNumerical) {
548 if (std::isfinite(threshold)) {
549 thresh_[split_index].insert(threshold);
552 CHECK(node.
split_type() == SplitFeatureType::kCategorical);
554 Q.push(node.
cleft());
560 std::copy(thresh_[i].begin(), thresh_[i].end(),
561 std::back_inserter(cut_pts[i]));
567 .describe(
"A compiler with a recursive approach")
570 return new RecursiveCompiler<Quantize>(param);
572 return new RecursiveCompiler<NoQuantize>(param);
588 GroupPolicy::GroupQueryFunction()
const {
589 return "return " + std::to_string(num_output_group) +
";";
593 GroupPolicy::Accumulator()
const {
594 if (num_output_group > 1) {
595 return std::string(
"float sum[") + std::to_string(num_output_group)
598 return "float sum = 0.0f;";
603 GroupPolicy::AccumulateTranslationUnit(
size_t unit_id)
const {
604 if (num_output_group > 1) {
605 return std::string(
"predict_margin_multiclass_unit")
606 + std::to_string(unit_id) +
"(data, sum);";
608 return std::string(
"sum += predict_margin_unit")
609 + std::to_string(unit_id) +
"(data);";
613 inline std::vector<std::string>
615 size_t tree_id)
const {
616 if (num_output_group > 1) {
617 if (random_forest_flag) {
619 const std::vector<treelite::tl_float>& leaf_vector = node.
leaf_vector();
620 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group))
621 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
622 std::vector<std::string> lines;
623 lines.reserve(num_output_group);
624 for (
int group_id = 0; group_id < num_output_group; ++group_id) {
625 lines.push_back(std::string(
"sum[") + std::to_string(group_id)
627 + treelite::common::ToString(leaf_vector[group_id]) +
";");
633 return { std::string(
"sum[") + std::to_string(tree_id % num_output_group)
634 +
"] += (float)" + treelite::common::ToString(leaf_value) +
";" };
638 return {std::string(
"sum += (float)")
639 + treelite::common::ToString(leaf_value) +
";" };
643 inline std::vector<std::string>
644 GroupPolicy::Return()
const {
645 if (num_output_group > 1) {
646 return {std::string(
"for (int i = 0; i < ")
647 + std::to_string(num_output_group) +
"; ++i) {",
648 " result[i] += sum[i];",
651 return {
"return sum;" };
655 inline std::vector<std::string>
656 GroupPolicy::FinalReturn(
size_t num_tree,
float global_bias)
const {
657 if (num_output_group > 1) {
658 if (random_forest_flag) {
660 return {std::string(
"for (int i = 0; i < ")
661 + std::to_string(num_output_group) +
"; ++i) {",
662 std::string(
" result[i] = sum[i] / ")
663 + std::to_string(num_tree) +
" + (" 664 + treelite::common::ToString(global_bias) +
");",
668 return {std::string(
"for (int i = 0; i < ")
669 + std::to_string(num_output_group) +
"; ++i) {",
670 " result[i] = sum[i] + (" 671 + treelite::common::ToString(global_bias) +
");",
675 if (random_forest_flag) {
676 return { std::string(
"return sum / ") + std::to_string(num_tree) +
" + (" 677 + treelite::common::ToString(global_bias) +
");" };
679 return { std::string(
"return sum + (")
680 + treelite::common::ToString(global_bias) +
");" };
686 GroupPolicy::Prototype()
const {
687 if (num_output_group > 1) {
688 return "void predict_margin_multiclass(union Entry* data, float* result)";
690 return "float predict_margin(union Entry* data)";
695 GroupPolicy::PrototypeTranslationUnit(
size_t unit_id)
const {
696 if (num_output_group > 1) {
697 return std::string(
"void predict_margin_multiclass_unit")
698 + std::to_string(unit_id) +
"(union Entry* data, float* result)";
700 return std::string(
"float predict_margin_unit")
701 + std::to_string(unit_id) +
"(union Entry* data)";
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
plain code block containing one or more lines of code
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
fundamental block in semantic model. All code blocks should inherit from this class.
std::vector< Tree > trees
member trees
parameters for tree compiler
unsigned split_index() const
feature index of split condition
ModelParam param
extra parameters
SemanticModel Compile(const Model &model) override
convert tree ensemble model into semantic model
in-memory representation of a decision tree
float global_bias
global bias of the model
Parameters for tree compiler.
BranchHint
enum class to store branch annotation
tl_float threshold() const
Interface of compiler that translates a tree ensemble model into a semantic model.
const std::vector< tl_float > & leaf_vector() const
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
int cright() const
index of right child
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
function block with a prototype and code body. Its prototype can optionally be registered with a func...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
if-else statement with condition may store a branch hint (>50% or <50% likely)
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
#define CLONEABLE_BOILERPLATE(className)
macro to define boilerplate code for Cloneable classes
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
sequence of one or more code blocks
tl_float leaf_value() const
int cleft() const
index of left child
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
int quantize
whether to quantize threshold points (0: no, >0: yes)
Building blocks for semantic model of tree prediction code.
translation unit is abstraction of a source file
SplitFeatureType split_type() const
get feature split type
semantic model consists of a header, function registry, and a list of translation units ...
int verbose
if >0, produce extra messages
bool is_leaf() const
whether current node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators