10 #include <fmt/format.h> 11 #include <rapidjson/stringbuffer.h> 12 #include <rapidjson/writer.h> 15 #include <unordered_map> 21 #include "./pred_transform.h" 31 #if defined(_MSC_VER) || defined(_WIN32) 32 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 34 #define DLLEXPORT_KEYWORD "" 46 template <
typename ThresholdType,
typename LeafOutputType>
49 cm.backend =
"native";
51 TREELITE_CHECK(model.
task_type != TaskType::kMultiClfCategLeaf)
52 <<
"Model task type unsupported by ASTNativeCompiler";
53 TREELITE_CHECK(model.
task_param.output_type == TaskParam::OutputType::kFloat)
54 <<
"ASTNativeCompiler only supports models with float output";
59 pred_transform_ = model.
param.pred_transform;
60 sigmoid_alpha_ = model.
param.sigmoid_alpha;
61 ratio_c_ = model.
param.ratio_c;
62 global_bias_ = model.
param.global_bias;
66 builder.BuildAST(model);
67 if (builder.FoldCode(param_.code_folding_req) || param_.quantize > 0) {
70 = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
72 if (param_.annotate_in !=
"NULL") {
74 std::ifstream fi(param_.annotate_in.c_str());
76 const auto annotation = annotator.
Get();
77 builder.LoadDataCounts(annotation);
78 TREELITE_LOG(INFO) <<
"Loading node frequencies from `" 79 << param_.annotate_in <<
"'";
81 builder.Split(param_.parallel_comp);
82 if (param_.quantize > 0) {
83 builder.QuantizeThresholds();
87 const char* destfile = getenv(
"TREELITE_DUMP_AST");
89 std::ofstream os(destfile);
90 os << builder.GetDump() << std::endl;
94 WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(),
"main.c", 0);
95 if (files_.count(
"arrays.c") > 0) {
96 PrependToBuffer(
"arrays.c",
"#include \"header.h\"\n", 0);
101 rapidjson::StringBuffer os;
102 rapidjson::Writer<rapidjson::StringBuffer> writer(os);
104 writer.StartObject();
105 writer.Key(
"target");
106 writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
107 writer.Key(
"sources");
109 for (
const auto& kv : files_) {
110 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
111 const size_t line_count
112 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
113 writer.StartObject();
115 std::string name = kv.first.substr(0, kv.first.length() - 2);
116 writer.String(name.data(), name.size());
117 writer.Key(
"length");
118 writer.Uint64(line_count);
127 cm.files = std::move(files_);
132 TREELITE_CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
133 <<
"Integer leaf outputs not yet supported";
134 this->pred_tranform_func_ = PredTransformFunction(
"native", model);
135 return model.Dispatch([
this](
const auto& model_handle) {
136 return this->CompileImpl(model_handle);
149 std::string pred_transform_;
150 float sigmoid_alpha_;
153 std::string pred_tranform_func_;
154 std::string array_is_categorical_;
155 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
157 template <
typename ThresholdType,
typename LeafOutputType>
158 void WalkAST(
const ASTNode* node,
159 const std::string& dest,
168 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
169 HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
170 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
171 HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
172 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
173 HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
175 HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
176 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
177 HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
179 HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
180 }
else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
181 HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
183 TREELITE_LOG(FATAL) <<
"Unrecognized AST node type";
188 inline void AppendToBuffer(
const std::string& dest,
189 const std::string& content,
191 files_[dest].content += common_util::IndentMultiLineString(content, indent);
195 inline void PrependToBuffer(
const std::string& dest,
196 const std::string& content,
199 = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
202 template <
typename ThresholdType,
typename LeafOutputType>
203 void HandleMainNode(
const MainNode* node,
204 const std::string& dest,
206 const std::string threshold_type
207 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
208 const std::string leaf_output_type
209 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
210 const std::string predict_function_signature
212 fmt::format(
"size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
214 : fmt::format(
"{} predict(union Entry* data, int pred_margin)",
217 if (!array_is_categorical_.empty()) {
218 array_is_categorical_
219 = fmt::format(
"const unsigned char is_categorical[] = {{\n{}\n}}",
220 array_is_categorical_);
223 const std::string query_functions_definition
224 = fmt::format(native::query_functions_definition_template,
226 "num_feature"_a = num_feature_,
227 "pred_transform"_a = pred_transform_,
228 "sigmoid_alpha"_a = sigmoid_alpha_,
229 "ratio_c"_a = ratio_c_,
230 "global_bias"_a = global_bias_,
235 fmt::format(native::main_start_template,
236 "array_is_categorical"_a = array_is_categorical_,
237 "query_functions_definition"_a = query_functions_definition,
238 "pred_transform_function"_a = pred_tranform_func_,
239 "predict_function_signature"_a = predict_function_signature),
241 const std::string query_functions_prototype
242 = fmt::format(native::query_functions_prototype_template,
243 "dllexport"_a = DLLEXPORT_KEYWORD);
244 AppendToBuffer(
"header.h",
245 fmt::format(native::header_template,
246 "dllexport"_a = DLLEXPORT_KEYWORD,
247 "predict_function_signature"_a = predict_function_signature,
248 "query_functions_prototype"_a = query_functions_prototype,
249 "threshold_type"_a = threshold_type,
250 "threshold_type_Node"_a = (param_.
quantize > 0 ? std::string(
"int") : threshold_type)),
253 TREELITE_CHECK_EQ(node->children.size(), 1);
254 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
256 std::string optional_average_field;
257 if (node->average_result) {
258 if (task_type_ == TaskType::kMultiClfGrovePerClass) {
261 TREELITE_CHECK_GT(task_param_.
num_class, 1);
262 TREELITE_CHECK_EQ(node->num_tree % task_param_.
num_class, 0)
263 <<
"Expected the number of trees to be divisible by the number of classes";
264 int num_boosting_round = node->num_tree /
static_cast<int>(task_param_.
num_class);
265 optional_average_field = fmt::format(
" / {}", num_boosting_round);
267 TREELITE_CHECK(task_type_ == TaskType::kBinaryClfRegr
268 || task_type_ == TaskType::kMultiClfProbDistLeaf);
271 optional_average_field = fmt::format(
" / {}", node->num_tree);
276 fmt::format(native::main_end_multiclass_template,
278 "optional_average_field"_a = optional_average_field,
279 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
280 "leaf_output_type"_a = leaf_output_type),
284 fmt::format(native::main_end_template,
285 "optional_average_field"_a = optional_average_field,
286 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
287 "leaf_output_type"_a = leaf_output_type),
292 template <
typename ThresholdType,
typename LeafOutputType>
294 const std::string& dest,
296 const std::string leaf_output_type
297 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
300 fmt::format(
"{leaf_output_type} sum[{num_class}] = {{0}};\n" 301 "unsigned int tmp;\n" 302 "int nid, cond, fid; /* used for folded subtrees */\n",
304 "leaf_output_type"_a = leaf_output_type), indent);
307 fmt::format(
"{leaf_output_type} sum = ({leaf_output_type})0;\n" 308 "unsigned int tmp;\n" 309 "int nid, cond, fid; /* used for folded subtrees */\n",
310 "leaf_output_type"_a = leaf_output_type),
313 for (
ASTNode* child : node->children) {
314 WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
318 template <
typename ThresholdType,
typename LeafOutputType>
320 const std::string& dest,
323 std::string condition_with_na_check;
326 std::string condition = ExtractNumericalCondition(t);
327 const char* condition_with_na_check_template
328 = (node->default_left) ?
329 "!(data[{split_index}].missing != -1) || ({condition})" 330 :
" (data[{split_index}].missing != -1) && ({condition})";
331 condition_with_na_check
332 = fmt::format(condition_with_na_check_template,
333 "split_index"_a = node->split_index,
334 "condition"_a = condition);
338 condition_with_na_check = ExtractCategoricalCondition(t2);
340 if (node->children[0]->data_count && node->children[1]->data_count) {
341 const uint64_t left_freq = *node->children[0]->data_count;
342 const uint64_t right_freq = *node->children[1]->data_count;
343 condition_with_na_check
344 = fmt::format(
" {keyword}( {condition} ) ",
345 "keyword"_a = ((left_freq > right_freq) ?
"LIKELY" :
"UNLIKELY"),
346 "condition"_a = condition_with_na_check);
349 fmt::format(
"if ({}) {{\n", condition_with_na_check), indent);
350 TREELITE_CHECK_EQ(node->children.size(), 2);
351 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
352 AppendToBuffer(dest,
"} else {\n", indent);
353 WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
354 AppendToBuffer(dest,
"}\n", indent);
357 template <
typename ThresholdType,
typename LeafOutputType>
359 const std::string& dest,
361 AppendToBuffer(dest, RenderOutputStatement(node), indent);
362 TREELITE_CHECK_EQ(node->children.size(), 0);
365 template <
typename ThresholdType,
typename LeafOutputType>
367 const std::string& dest,
369 const int unit_id = node->unit_id;
370 const std::string new_file = fmt::format(
"tu{}.c", unit_id);
371 const std::string leaf_output_type
372 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
374 std::string unit_function_name, unit_function_signature,
375 unit_function_call_signature;
378 = fmt::format(
"predict_margin_multiclass_unit{}", unit_id);
379 unit_function_signature
380 = fmt::format(
"void {function_name}(union Entry* data, {leaf_output_type}* result)",
381 "function_name"_a = unit_function_name,
382 "leaf_output_type"_a = leaf_output_type);
383 unit_function_call_signature
384 = fmt::format(
"{}(data, sum);\n", unit_function_name);
387 = fmt::format(
"predict_margin_unit{}", unit_id);
388 unit_function_signature
389 = fmt::format(
"{leaf_output_type} {function_name}(union Entry* data)",
390 "function_name"_a = unit_function_name,
391 "leaf_output_type"_a = leaf_output_type);
392 unit_function_call_signature
393 = fmt::format(
"sum += {}(data);\n", unit_function_name);
395 AppendToBuffer(dest, unit_function_call_signature, indent);
396 AppendToBuffer(new_file,
397 fmt::format(
"#include \"header.h\"\n" 398 "{} {{\n", unit_function_signature), 0);
399 TREELITE_CHECK_EQ(node->children.size(), 1);
400 WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
402 AppendToBuffer(new_file,
403 fmt::format(
" for (int i = 0; i < {num_class}; ++i) {{\n" 404 " result[i] += sum[i];\n" 407 "num_class"_a = task_param_.
num_class), 0);
409 AppendToBuffer(new_file,
" return sum;\n}\n", 0);
411 AppendToBuffer(
"header.h", fmt::format(
"{};\n", unit_function_signature), 0);
414 template <
typename ThresholdType,
typename LeafOutputType>
416 const std::string& dest,
418 const std::string threshold_type
419 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
421 std::string array_threshold, array_th_begin, array_th_len;
426 size_t total_num_threshold;
430 for (
const auto& e : node->cut_pts) {
437 array_threshold = formatter.
str();
442 for (
const auto& e : node->cut_pts) {
446 total_num_threshold = accum;
447 array_th_begin = formatter.
str();
451 for (
const auto& e : node->cut_pts) {
452 formatter << e.size();
454 array_th_len = formatter.
str();
456 if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
457 PrependToBuffer(dest,
458 fmt::format(native::qnode_template,
459 "total_num_threshold"_a = total_num_threshold,
460 "threshold_type"_a = threshold_type),
463 fmt::format(native::quantize_loop_template,
464 "num_feature"_a = num_feature_), indent);
466 if (!array_threshold.empty()) {
467 PrependToBuffer(dest,
468 fmt::format(
"static const {threshold_type} threshold[] = {{\n" 469 "{array_threshold}\n" 471 "array_threshold"_a = array_threshold,
472 "threshold_type"_a = threshold_type),
475 if (!array_th_begin.empty()) {
476 PrependToBuffer(dest,
477 fmt::format(
"static const int th_begin[] = {{\n" 479 "}};\n",
"array_th_begin"_a = array_th_begin), 0);
481 if (!array_th_len.empty()) {
482 PrependToBuffer(dest,
483 fmt::format(
"static const int th_len[] = {{\n" 485 "}};\n",
"array_th_len"_a = array_th_len), 0);
487 TREELITE_CHECK_EQ(node->children.size(), 1);
488 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
491 template <
typename ThresholdType,
typename LeafOutputType>
493 const std::string& dest,
495 TREELITE_CHECK_EQ(node->children.size(), 1);
496 const int node_id = node->children[0]->node_id;
497 const int tree_id = node->children[0]->tree_id;
500 std::string array_nodes, array_cat_bitmap, array_cat_begin;
502 const std::string node_array_name
503 = fmt::format(
"node_tree{}_node{}", tree_id, node_id);
507 const std::string cat_bitmap_name
508 = fmt::format(
"cat_bitmap_tree{}_node{}", tree_id, node_id);
512 const std::string cat_begin_name
513 = fmt::format(
"cat_begin_tree{}_node{}", tree_id, node_id);
515 std::string output_switch_statement;
517 common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param_.
quantize,
518 false,
"{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
520 &array_nodes, &array_cat_bitmap, &array_cat_begin, &output_switch_statement,
522 if (!array_nodes.empty()) {
523 AppendToBuffer(
"header.h",
524 fmt::format(
"extern const struct Node {node_array_name}[];\n",
525 "node_array_name"_a = node_array_name), 0);
526 AppendToBuffer(
"arrays.c",
527 fmt::format(
"const struct Node {node_array_name}[] = {{\n" 530 "node_array_name"_a = node_array_name,
531 "array_nodes"_a = array_nodes), 0);
534 if (!array_cat_bitmap.empty()) {
535 AppendToBuffer(
"header.h",
536 fmt::format(
"extern const uint64_t {cat_bitmap_name}[];\n",
537 "cat_bitmap_name"_a = cat_bitmap_name), 0);
538 AppendToBuffer(
"arrays.c",
539 fmt::format(
"const uint64_t {cat_bitmap_name}[] = {{\n" 540 "{array_cat_bitmap}\n" 542 "cat_bitmap_name"_a = cat_bitmap_name,
543 "array_cat_bitmap"_a = array_cat_bitmap), 0);
546 if (!array_cat_begin.empty()) {
547 AppendToBuffer(
"header.h",
548 fmt::format(
"extern const size_t {cat_begin_name}[];\n",
549 "cat_begin_name"_a = cat_begin_name), 0);
550 AppendToBuffer(
"arrays.c",
551 fmt::format(
"const size_t {cat_begin_name}[] = {{\n" 552 "{array_cat_begin}\n" 554 "cat_begin_name"_a = cat_begin_name,
555 "array_cat_begin"_a = array_cat_begin), 0);
558 if (array_nodes.empty()) {
561 fmt::format(
"nid = -1;\n" 562 "{output_switch_statement}\n",
563 "output_switch_statement"_a
564 = output_switch_statement), indent);
565 }
else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
567 fmt::format(native::eval_loop_template,
568 "node_array_name"_a = node_array_name,
569 "cat_bitmap_name"_a = cat_bitmap_name,
570 "cat_begin_name"_a = cat_begin_name,
571 "data_field"_a = (param_.
quantize > 0 ?
"qvalue" :
"fvalue"),
572 "comp_op"_a =
OpName(common_comp_op),
573 "output_switch_statement"_a
574 = output_switch_statement), indent);
577 fmt::format(native::eval_loop_template_without_categorical_feature,
578 "node_array_name"_a = node_array_name,
579 "data_field"_a = (param_.
quantize > 0 ?
"qvalue" :
"fvalue"),
580 "comp_op"_a =
OpName(common_comp_op),
581 "output_switch_statement"_a
582 = output_switch_statement), indent);
586 template <
typename ThresholdType>
589 const std::string threshold_type
590 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
592 if (node->quantized) {
593 std::string lhs = fmt::format(
"data[{split_index}].qvalue",
594 "split_index"_a = node->split_index);
595 result = fmt::format(
"{lhs} {opname} {threshold}",
597 "opname"_a =
OpName(node->op),
598 "threshold"_a = node->threshold.int_val);
599 }
else if (std::isinf(node->threshold.float_val)) {
602 result = (
CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
605 std::string lhs = fmt::format(
"data[{split_index}].fvalue",
606 "split_index"_a = node->split_index);
608 = fmt::format(
"{lhs} {opname} ({threshold_type}){threshold}",
610 "opname"_a =
OpName(node->op),
611 "threshold_type"_a = threshold_type,
612 "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
620 std::vector<uint64_t> bitmap
621 = common_util::GetCategoricalBitmap(node->matching_categories);
622 TREELITE_CHECK_GE(bitmap.size(), 1);
623 bool all_zeros =
true;
624 for (uint64_t e : bitmap) {
625 all_zeros &= (e == 0);
630 std::ostringstream oss;
631 const std::string right_categories_flag = (node->categories_list_right_child ?
"!" :
"");
632 if (node->default_left) {
634 "data[{split_index}].missing == -1 || {right_categories_flag}(" 635 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
636 "split_index"_a = node->split_index,
637 "right_categories_flag"_a = right_categories_flag);
640 "data[{split_index}].missing != -1 && {right_categories_flag}(" 641 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
642 "split_index"_a = node->split_index,
643 "right_categories_flag"_a = right_categories_flag);
647 "((data[{split_index}].fvalue >= 0) && " 648 "(fabsf(data[{split_index}].fvalue) <= (float)(1U << FLT_MANT_DIG)) && (",
649 "split_index"_a = node->split_index);
650 oss <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 651 << bitmap[0] <<
"U >> tmp) & 1) )";
652 for (
size_t i = 1; i < bitmap.size(); ++i) {
653 oss <<
" || (tmp >= " << (i * 64)
654 <<
" && tmp < " << ((i + 1) * 64)
655 <<
" && (( (uint64_t)" << bitmap[i]
656 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
665 RenderIsCategoricalArray(
const std::vector<bool>& is_categorical) {
667 for (
int fid = 0; fid < num_feature_; ++fid) {
668 formatter << (is_categorical[fid] ? 1 : 0);
670 return formatter.
str();
673 template <
typename LeafOutputType>
675 const std::string leaf_output_type
676 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
677 std::string output_statement;
679 if (node->is_vector) {
681 TREELITE_CHECK_EQ(node->vector.size(),
static_cast<size_t>(task_param_.
num_class))
682 <<
"Ill-formed model: leaf vector must be of length [num_class]";
683 for (
size_t group_id = 0; group_id < task_param_.
num_class; ++group_id) {
685 += fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
686 "group_id"_a = group_id,
687 "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
688 "leaf_output_type"_a = leaf_output_type);
693 = fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
694 "group_id"_a = node->tree_id % task_param_.
num_class,
695 "output"_a = common_util::ToStringHighPrecision(node->scalar),
696 "leaf_output_type"_a = leaf_output_type);
700 = fmt::format(
"sum += ({leaf_output_type}){output};\n",
701 "output"_a = common_util::ToStringHighPrecision(node->scalar),
702 "leaf_output_type"_a = leaf_output_type);
704 return output_statement;
708 ASTNativeCompiler::ASTNativeCompiler(
const CompilerParam& param)
709 : pimpl_(std::make_unique<ASTNativeCompilerImpl>(param)) {
711 TREELITE_LOG(INFO) <<
"Using ASTNativeCompiler";
714 TREELITE_LOG(INFO) <<
"Warning: 'dump_array_as_elf' parameter is not applicable " 715 "for ASTNativeCompiler";
// Destructor: compiler-generated behavior, defined out of line via `= default`.
// NOTE(review): presumably defined here (rather than in the class body) so the
// pimpl_ member's implementation type is complete at this point -- confirm
// against the header.
719 ASTNativeCompiler::~ASTNativeCompiler() =
default;
722 ASTNativeCompiler::Compile(
const Model &model) {
723 return pimpl_->Compile(model);
727 ASTNativeCompiler::QueryParam()
const {
728 return pimpl_->QueryParam();
ModelParam param
extra parameters
Parameters for tree compiler.
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
std::string OpName(Operator op)
get string representation of comparison operator
TaskType
Enum type representing the task type.
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
Perform a comparison between two floats using a comparison operator. The comparison will be in the form...
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
unsigned int num_class
The number of classes in the target label.
TaskType task_type
Task type.
Interface of compiler that compiles a tree ensemble model.
int32_t num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
template for main function
std::vector< std::vector< uint64_t > > Get() const
fetch branch annotation. Usage example:
template for evaluation logic for folded code
code template for QuantizerNode
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
TaskParam task_param
Group of parameters that are specific to the particular task type.
thin wrapper for tree ensemble model
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
void Load(std::istream &fi)
load branch annotation from a JSON file
Operator
comparison operators