10 #include <fmt/format.h> 11 #include <rapidjson/stringbuffer.h> 12 #include <rapidjson/writer.h> 15 #include <unordered_map> 21 #include "./pred_transform.h" 31 #if defined(_MSC_VER) || defined(_WIN32) 32 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 34 #define DLLEXPORT_KEYWORD "" 46 template <
typename ThresholdType,
typename LeafOutputType>
49 cm.backend =
"native";
51 TREELITE_CHECK(model.
task_type != TaskType::kMultiClfCategLeaf)
52 <<
"Model task type unsupported by ASTNativeCompiler";
53 TREELITE_CHECK(model.
task_param.output_type == TaskParam::OutputType::kFloat)
54 <<
"ASTNativeCompiler only supports models with float output";
59 pred_transform_ = model.
param.pred_transform;
60 sigmoid_alpha_ = model.
param.sigmoid_alpha;
61 global_bias_ = model.
param.global_bias;
65 builder.BuildAST(model);
66 if (builder.FoldCode(param_.code_folding_req) || param_.quantize > 0) {
69 = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
71 if (param_.annotate_in !=
"NULL") {
73 std::ifstream fi(param_.annotate_in.c_str());
75 const auto annotation = annotator.
Get();
76 builder.LoadDataCounts(annotation);
77 TREELITE_LOG(INFO) <<
"Loading node frequencies from `" 78 << param_.annotate_in <<
"'";
80 builder.Split(param_.parallel_comp);
81 if (param_.quantize > 0) {
82 builder.QuantizeThresholds();
86 const char* destfile = getenv(
"TREELITE_DUMP_AST");
88 std::ofstream os(destfile);
89 os << builder.GetDump() << std::endl;
93 WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(),
"main.c", 0);
94 if (files_.count(
"arrays.c") > 0) {
95 PrependToBuffer(
"arrays.c",
"#include \"header.h\"\n", 0);
100 rapidjson::StringBuffer os;
101 rapidjson::Writer<rapidjson::StringBuffer> writer(os);
103 writer.StartObject();
104 writer.Key(
"target");
105 writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
106 writer.Key(
"sources");
108 for (
const auto& kv : files_) {
109 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
110 const size_t line_count
111 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
112 writer.StartObject();
114 std::string name = kv.first.substr(0, kv.first.length() - 2);
115 writer.String(name.data(), name.size());
116 writer.Key(
"length");
117 writer.Uint64(line_count);
126 cm.files = std::move(files_);
131 TREELITE_CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
132 <<
"Integer leaf outputs not yet supported";
133 this->pred_tranform_func_ = PredTransformFunction(
"native", model);
134 return model.Dispatch([
this](
const auto& model_handle) {
135 return this->CompileImpl(model_handle);
148 std::string pred_transform_;
149 float sigmoid_alpha_;
151 std::string pred_tranform_func_;
152 std::string array_is_categorical_;
153 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
155 template <
typename ThresholdType,
typename LeafOutputType>
156 void WalkAST(
const ASTNode* node,
157 const std::string& dest,
166 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
167 HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
168 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
169 HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
170 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
171 HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
173 HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
174 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
175 HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
177 HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
178 }
else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
179 HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
181 TREELITE_LOG(FATAL) <<
"Unrecognized AST node type";
186 inline void AppendToBuffer(
const std::string& dest,
187 const std::string& content,
189 files_[dest].content += common_util::IndentMultiLineString(content, indent);
193 inline void PrependToBuffer(
const std::string& dest,
194 const std::string& content,
197 = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
200 template <
typename ThresholdType,
typename LeafOutputType>
201 void HandleMainNode(
const MainNode* node,
202 const std::string& dest,
204 const std::string threshold_type
205 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
206 const std::string leaf_output_type
207 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
208 const std::string predict_function_signature
210 fmt::format(
"size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
212 : fmt::format(
"{} predict(union Entry* data, int pred_margin)",
215 if (!array_is_categorical_.empty()) {
216 array_is_categorical_
217 = fmt::format(
"const unsigned char is_categorical[] = {{\n{}\n}}",
218 array_is_categorical_);
221 const std::string query_functions_definition
222 = fmt::format(native::query_functions_definition_template,
224 "num_feature"_a = num_feature_,
225 "pred_transform"_a = pred_transform_,
226 "sigmoid_alpha"_a = sigmoid_alpha_,
227 "global_bias"_a = global_bias_,
232 fmt::format(native::main_start_template,
233 "array_is_categorical"_a = array_is_categorical_,
234 "query_functions_definition"_a = query_functions_definition,
235 "pred_transform_function"_a = pred_tranform_func_,
236 "predict_function_signature"_a = predict_function_signature),
238 const std::string query_functions_prototype
239 = fmt::format(native::query_functions_prototype_template,
240 "dllexport"_a = DLLEXPORT_KEYWORD);
241 AppendToBuffer(
"header.h",
242 fmt::format(native::header_template,
243 "dllexport"_a = DLLEXPORT_KEYWORD,
244 "predict_function_signature"_a = predict_function_signature,
245 "query_functions_prototype"_a = query_functions_prototype,
246 "threshold_type"_a = threshold_type,
247 "threshold_type_Node"_a = (param_.
quantize > 0 ? std::string(
"int") : threshold_type)),
250 TREELITE_CHECK_EQ(node->children.size(), 1);
251 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
253 std::string optional_average_field;
254 if (node->average_result) {
255 if (task_type_ == TaskType::kMultiClfGrovePerClass) {
258 TREELITE_CHECK_GT(task_param_.
num_class, 1);
259 TREELITE_CHECK_EQ(node->num_tree % task_param_.
num_class, 0)
260 <<
"Expected the number of trees to be divisible by the number of classes";
261 int num_boosting_round = node->num_tree /
static_cast<int>(task_param_.
num_class);
262 optional_average_field = fmt::format(
" / {}", num_boosting_round);
264 TREELITE_CHECK(task_type_ == TaskType::kBinaryClfRegr
265 || task_type_ == TaskType::kMultiClfProbDistLeaf);
268 optional_average_field = fmt::format(
" / {}", node->num_tree);
273 fmt::format(native::main_end_multiclass_template,
275 "optional_average_field"_a = optional_average_field,
276 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
277 "leaf_output_type"_a = leaf_output_type),
281 fmt::format(native::main_end_template,
282 "optional_average_field"_a = optional_average_field,
283 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
284 "leaf_output_type"_a = leaf_output_type),
289 template <
typename ThresholdType,
typename LeafOutputType>
291 const std::string& dest,
293 const std::string leaf_output_type
294 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
297 fmt::format(
"{leaf_output_type} sum[{num_class}] = {{0}};\n" 298 "unsigned int tmp;\n" 299 "int nid, cond, fid; /* used for folded subtrees */\n",
301 "leaf_output_type"_a = leaf_output_type), indent);
304 fmt::format(
"{leaf_output_type} sum = ({leaf_output_type})0;\n" 305 "unsigned int tmp;\n" 306 "int nid, cond, fid; /* used for folded subtrees */\n",
307 "leaf_output_type"_a = leaf_output_type),
310 for (
ASTNode* child : node->children) {
311 WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
315 template <
typename ThresholdType,
typename LeafOutputType>
317 const std::string& dest,
320 std::string condition_with_na_check;
323 std::string condition = ExtractNumericalCondition(t);
324 const char* condition_with_na_check_template
325 = (node->default_left) ?
326 "!(data[{split_index}].missing != -1) || ({condition})" 327 :
" (data[{split_index}].missing != -1) && ({condition})";
328 condition_with_na_check
329 = fmt::format(condition_with_na_check_template,
330 "split_index"_a = node->split_index,
331 "condition"_a = condition);
335 condition_with_na_check = ExtractCategoricalCondition(t2);
337 if (node->children[0]->data_count && node->children[1]->data_count) {
338 const uint64_t left_freq = *node->children[0]->data_count;
339 const uint64_t right_freq = *node->children[1]->data_count;
340 condition_with_na_check
341 = fmt::format(
" {keyword}( {condition} ) ",
342 "keyword"_a = ((left_freq > right_freq) ?
"LIKELY" :
"UNLIKELY"),
343 "condition"_a = condition_with_na_check);
346 fmt::format(
"if ({}) {{\n", condition_with_na_check), indent);
347 TREELITE_CHECK_EQ(node->children.size(), 2);
348 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
349 AppendToBuffer(dest,
"} else {\n", indent);
350 WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
351 AppendToBuffer(dest,
"}\n", indent);
354 template <
typename ThresholdType,
typename LeafOutputType>
356 const std::string& dest,
358 AppendToBuffer(dest, RenderOutputStatement(node), indent);
359 TREELITE_CHECK_EQ(node->children.size(), 0);
362 template <
typename ThresholdType,
typename LeafOutputType>
364 const std::string& dest,
366 const int unit_id = node->unit_id;
367 const std::string new_file = fmt::format(
"tu{}.c", unit_id);
368 const std::string leaf_output_type
369 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
371 std::string unit_function_name, unit_function_signature,
372 unit_function_call_signature;
375 = fmt::format(
"predict_margin_multiclass_unit{}", unit_id);
376 unit_function_signature
377 = fmt::format(
"void {function_name}(union Entry* data, {leaf_output_type}* result)",
378 "function_name"_a = unit_function_name,
379 "leaf_output_type"_a = leaf_output_type);
380 unit_function_call_signature
381 = fmt::format(
"{}(data, sum);\n", unit_function_name);
384 = fmt::format(
"predict_margin_unit{}", unit_id);
385 unit_function_signature
386 = fmt::format(
"{leaf_output_type} {function_name}(union Entry* data)",
387 "function_name"_a = unit_function_name,
388 "leaf_output_type"_a = leaf_output_type);
389 unit_function_call_signature
390 = fmt::format(
"sum += {}(data);\n", unit_function_name);
392 AppendToBuffer(dest, unit_function_call_signature, indent);
393 AppendToBuffer(new_file,
394 fmt::format(
"#include \"header.h\"\n" 395 "{} {{\n", unit_function_signature), 0);
396 TREELITE_CHECK_EQ(node->children.size(), 1);
397 WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
399 AppendToBuffer(new_file,
400 fmt::format(
" for (int i = 0; i < {num_class}; ++i) {{\n" 401 " result[i] += sum[i];\n" 404 "num_class"_a = task_param_.
num_class), 0);
406 AppendToBuffer(new_file,
" return sum;\n}\n", 0);
408 AppendToBuffer(
"header.h", fmt::format(
"{};\n", unit_function_signature), 0);
411 template <
typename ThresholdType,
typename LeafOutputType>
413 const std::string& dest,
415 const std::string threshold_type
416 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
418 std::string array_threshold, array_th_begin, array_th_len;
423 size_t total_num_threshold;
427 for (
const auto& e : node->cut_pts) {
434 array_threshold = formatter.
str();
439 for (
const auto& e : node->cut_pts) {
443 total_num_threshold = accum;
444 array_th_begin = formatter.
str();
448 for (
const auto& e : node->cut_pts) {
449 formatter << e.size();
451 array_th_len = formatter.
str();
453 if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
454 PrependToBuffer(dest,
455 fmt::format(native::qnode_template,
456 "total_num_threshold"_a = total_num_threshold,
457 "threshold_type"_a = threshold_type),
460 fmt::format(native::quantize_loop_template,
461 "num_feature"_a = num_feature_), indent);
463 if (!array_threshold.empty()) {
464 PrependToBuffer(dest,
465 fmt::format(
"static const {threshold_type} threshold[] = {{\n" 466 "{array_threshold}\n" 468 "array_threshold"_a = array_threshold,
469 "threshold_type"_a = threshold_type),
472 if (!array_th_begin.empty()) {
473 PrependToBuffer(dest,
474 fmt::format(
"static const int th_begin[] = {{\n" 476 "}};\n",
"array_th_begin"_a = array_th_begin), 0);
478 if (!array_th_len.empty()) {
479 PrependToBuffer(dest,
480 fmt::format(
"static const int th_len[] = {{\n" 482 "}};\n",
"array_th_len"_a = array_th_len), 0);
484 TREELITE_CHECK_EQ(node->children.size(), 1);
485 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
488 template <
typename ThresholdType,
typename LeafOutputType>
490 const std::string& dest,
492 TREELITE_CHECK_EQ(node->children.size(), 1);
493 const int node_id = node->children[0]->node_id;
494 const int tree_id = node->children[0]->tree_id;
497 std::string array_nodes, array_cat_bitmap, array_cat_begin;
499 const std::string node_array_name
500 = fmt::format(
"node_tree{}_node{}", tree_id, node_id);
504 const std::string cat_bitmap_name
505 = fmt::format(
"cat_bitmap_tree{}_node{}", tree_id, node_id);
509 const std::string cat_begin_name
510 = fmt::format(
"cat_begin_tree{}_node{}", tree_id, node_id);
512 std::string output_switch_statement;
514 common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param_.
quantize,
515 false,
"{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
517 &array_nodes, &array_cat_bitmap, &array_cat_begin, &output_switch_statement,
519 if (!array_nodes.empty()) {
520 AppendToBuffer(
"header.h",
521 fmt::format(
"extern const struct Node {node_array_name}[];\n",
522 "node_array_name"_a = node_array_name), 0);
523 AppendToBuffer(
"arrays.c",
524 fmt::format(
"const struct Node {node_array_name}[] = {{\n" 527 "node_array_name"_a = node_array_name,
528 "array_nodes"_a = array_nodes), 0);
531 if (!array_cat_bitmap.empty()) {
532 AppendToBuffer(
"header.h",
533 fmt::format(
"extern const uint64_t {cat_bitmap_name}[];\n",
534 "cat_bitmap_name"_a = cat_bitmap_name), 0);
535 AppendToBuffer(
"arrays.c",
536 fmt::format(
"const uint64_t {cat_bitmap_name}[] = {{\n" 537 "{array_cat_bitmap}\n" 539 "cat_bitmap_name"_a = cat_bitmap_name,
540 "array_cat_bitmap"_a = array_cat_bitmap), 0);
543 if (!array_cat_begin.empty()) {
544 AppendToBuffer(
"header.h",
545 fmt::format(
"extern const size_t {cat_begin_name}[];\n",
546 "cat_begin_name"_a = cat_begin_name), 0);
547 AppendToBuffer(
"arrays.c",
548 fmt::format(
"const size_t {cat_begin_name}[] = {{\n" 549 "{array_cat_begin}\n" 551 "cat_begin_name"_a = cat_begin_name,
552 "array_cat_begin"_a = array_cat_begin), 0);
555 if (array_nodes.empty()) {
558 fmt::format(
"nid = -1;\n" 559 "{output_switch_statement}\n",
560 "output_switch_statement"_a
561 = output_switch_statement), indent);
562 }
else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
564 fmt::format(native::eval_loop_template,
565 "node_array_name"_a = node_array_name,
566 "cat_bitmap_name"_a = cat_bitmap_name,
567 "cat_begin_name"_a = cat_begin_name,
568 "data_field"_a = (param_.
quantize > 0 ?
"qvalue" :
"fvalue"),
569 "comp_op"_a =
OpName(common_comp_op),
570 "output_switch_statement"_a
571 = output_switch_statement), indent);
574 fmt::format(native::eval_loop_template_without_categorical_feature,
575 "node_array_name"_a = node_array_name,
576 "data_field"_a = (param_.
quantize > 0 ?
"qvalue" :
"fvalue"),
577 "comp_op"_a =
OpName(common_comp_op),
578 "output_switch_statement"_a
579 = output_switch_statement), indent);
583 template <
typename ThresholdType>
586 const std::string threshold_type
587 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
589 if (node->quantized) {
590 std::string lhs = fmt::format(
"data[{split_index}].qvalue",
591 "split_index"_a = node->split_index);
592 result = fmt::format(
"{lhs} {opname} {threshold}",
594 "opname"_a =
OpName(node->op),
595 "threshold"_a = node->threshold.int_val);
596 }
else if (std::isinf(node->threshold.float_val)) {
599 result = (
CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
602 std::string lhs = fmt::format(
"data[{split_index}].fvalue",
603 "split_index"_a = node->split_index);
605 = fmt::format(
"{lhs} {opname} ({threshold_type}){threshold}",
607 "opname"_a =
OpName(node->op),
608 "threshold_type"_a = threshold_type,
609 "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
617 std::vector<uint64_t> bitmap
618 = common_util::GetCategoricalBitmap(node->matching_categories);
619 TREELITE_CHECK_GE(bitmap.size(), 1);
620 bool all_zeros =
true;
621 for (uint64_t e : bitmap) {
622 all_zeros &= (e == 0);
627 std::ostringstream oss;
628 const std::string right_categories_flag = (node->categories_list_right_child ?
"!" :
"");
629 if (node->default_left) {
631 "data[{split_index}].missing == -1 || {right_categories_flag}(" 632 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
633 "split_index"_a = node->split_index,
634 "right_categories_flag"_a = right_categories_flag);
637 "data[{split_index}].missing != -1 && {right_categories_flag}(" 638 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
639 "split_index"_a = node->split_index,
640 "right_categories_flag"_a = right_categories_flag);
642 oss <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 643 << bitmap[0] <<
"U >> tmp) & 1) )";
644 for (
size_t i = 1; i < bitmap.size(); ++i) {
645 oss <<
" || (tmp >= " << (i * 64)
646 <<
" && tmp < " << ((i + 1) * 64)
647 <<
" && (( (uint64_t)" << bitmap[i]
648 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
657 RenderIsCategoricalArray(
const std::vector<bool>& is_categorical) {
659 for (
int fid = 0; fid < num_feature_; ++fid) {
660 formatter << (is_categorical[fid] ? 1 : 0);
662 return formatter.
str();
665 template <
typename LeafOutputType>
667 const std::string leaf_output_type
668 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
669 std::string output_statement;
671 if (node->is_vector) {
673 TREELITE_CHECK_EQ(node->vector.size(),
static_cast<size_t>(task_param_.
num_class))
674 <<
"Ill-formed model: leaf vector must be of length [num_class]";
675 for (
size_t group_id = 0; group_id < task_param_.
num_class; ++group_id) {
677 += fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
678 "group_id"_a = group_id,
679 "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
680 "leaf_output_type"_a = leaf_output_type);
685 = fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
686 "group_id"_a = node->tree_id % task_param_.
num_class,
687 "output"_a = common_util::ToStringHighPrecision(node->scalar),
688 "leaf_output_type"_a = leaf_output_type);
692 = fmt::format(
"sum += ({leaf_output_type}){output};\n",
693 "output"_a = common_util::ToStringHighPrecision(node->scalar),
694 "leaf_output_type"_a = leaf_output_type);
696 return output_statement;
700 ASTNativeCompiler::ASTNativeCompiler(
const CompilerParam& param)
701 : pimpl_(std::make_unique<ASTNativeCompilerImpl>(param)) {
703 TREELITE_LOG(INFO) <<
"Using ASTNativeCompiler";
706 TREELITE_LOG(INFO) <<
"Warning: 'dump_array_as_elf' parameter is not applicable " 707 "for ASTNativeCompiler";
// Destructor, defined out of line and defaulted.
// NOTE(review): presumably kept in this translation unit so that the type held
// by pimpl_ (ASTNativeCompilerImpl) is complete where it is destroyed — confirm
// against the class declaration in the header.
711 ASTNativeCompiler::~ASTNativeCompiler() =
default;
714 ASTNativeCompiler::Compile(
const Model &model) {
715 return pimpl_->Compile(model);
719 ASTNativeCompiler::QueryParam()
const {
720 return pimpl_->QueryParam();
ModelParam param
extra parameters
Parameters for tree compiler.
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
std::string OpName(Operator op)
get string representation of comparison operator
TaskType
Enum type representing the task type.
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
unsigned int num_class
The number of classes in the target label.
TaskType task_type
Task type.
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< std::vector< uint64_t > > Get() const
fetch branch annotation. Usage example:
template for evaluation logic for folded code
code template for QuantizerNode
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
TaskParam task_param
Group of parameters that are specific to the particular task type.
thin wrapper for tree ensemble model
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
void Load(std::istream &fi)
load branch annotation from a JSON file
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators