10 #include <fmt/format.h> 13 #include <unordered_map> 16 #include "./pred_transform.h" 27 #if defined(_MSC_VER) || defined(_WIN32) 28 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 30 #define DLLEXPORT_KEYWORD "" 38 DMLC_REGISTRY_FILE_TAG(ast_native);
45 LOG(INFO) <<
"Using ASTNativeCompiler";
48 LOG(INFO) <<
"Warning: 'dump_array_as_elf' parameter is not applicable " 49 "for ASTNativeCompiler";
53 template <
typename ThresholdType,
typename LeafOutputType>
56 cm.backend =
"native";
58 CHECK(model.
task_type != TaskType::kMultiClfCategLeaf)
59 <<
"Model task type unsupported by ASTNativeCompiler";
60 CHECK(model.
task_param.output_type == TaskParameter::OutputType::kFloat)
61 <<
"ASTNativeCompiler only supports models with float output";
66 pred_transform_ = model.
param.pred_transform;
67 sigmoid_alpha_ = model.
param.sigmoid_alpha;
68 global_bias_ = model.
param.global_bias;
72 builder.BuildAST(model);
76 = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
80 std::unique_ptr<dmlc::Stream> fi(
81 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
82 annotator.
Load(fi.get());
83 const auto annotation = annotator.
Get();
84 builder.LoadDataCounts(annotation);
85 LOG(INFO) <<
"Loading node frequencies from `" 90 builder.QuantizeThresholds();
94 const char* destfile = getenv(
"TREELITE_DUMP_AST");
96 std::ofstream os(destfile);
97 os << builder.GetDump() << std::endl;
101 WalkAST<ThresholdType, LeafOutputType>(builder.GetRootNode(),
"main.c", 0);
102 if (files_.count(
"arrays.c") > 0) {
103 PrependToBuffer(
"arrays.c",
"#include \"header.h\"\n", 0);
108 std::vector<std::unordered_map<std::string, std::string>> source_list;
109 for (
const auto& kv : files_) {
110 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
111 const size_t line_count
112 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
113 source_list.push_back({ {
"name",
114 kv.first.substr(0, kv.first.length() - 2)},
115 {
"length", std::to_string(line_count)} });
118 std::ostringstream oss;
119 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&oss));
120 writer->BeginObject();
122 writer->WriteObjectKeyValue(
"sources", source_list);
126 cm.files = std::move(files_);
131 CHECK(model.GetLeafOutputType() != TypeInfo::kUInt32)
132 <<
"Integer leaf outputs not yet supported";
133 this->pred_tranform_func_ = PredTransformFunction(
"native", model);
134 return model.Dispatch([
this](
const auto& model_handle) {
135 return this->CompileImpl(model_handle);
144 std::string pred_transform_;
145 float sigmoid_alpha_;
147 std::string pred_tranform_func_;
148 std::string array_is_categorical_;
149 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
151 template <
typename ThresholdType,
typename LeafOutputType>
152 void WalkAST(
const ASTNode* node,
153 const std::string& dest,
162 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
163 HandleMainNode<ThresholdType, LeafOutputType>(t1, dest, indent);
164 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
165 HandleACNode<ThresholdType, LeafOutputType>(t2, dest, indent);
166 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
167 HandleCondNode<ThresholdType, LeafOutputType>(t3, dest, indent);
169 HandleOutputNode<ThresholdType, LeafOutputType>(t4, dest, indent);
170 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
171 HandleTUNode<ThresholdType, LeafOutputType>(t5, dest, indent);
173 HandleQNode<ThresholdType, LeafOutputType>(t6, dest, indent);
174 }
else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
175 HandleCodeFolderNode<ThresholdType, LeafOutputType>(t7, dest, indent);
177 LOG(FATAL) <<
"Unrecognized AST node type";
182 inline void AppendToBuffer(
const std::string& dest,
183 const std::string& content,
185 files_[dest].content += common_util::IndentMultiLineString(content, indent);
189 inline void PrependToBuffer(
const std::string& dest,
190 const std::string& content,
193 = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
196 template <
typename ThresholdType,
typename LeafOutputType>
197 void HandleMainNode(
const MainNode* node,
198 const std::string& dest,
200 const std::string threshold_type
201 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
202 const std::string leaf_output_type
203 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
204 const std::string predict_function_signature
206 fmt::format(
"size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)",
208 : fmt::format(
"{} predict(union Entry* data, int pred_margin)",
211 if (!array_is_categorical_.empty()) {
212 array_is_categorical_
213 = fmt::format(
"const unsigned char is_categorical[] = {{\n{}\n}}",
214 array_is_categorical_);
217 const std::string query_functions_definition
218 = fmt::format(native::query_functions_definition_template,
220 "num_feature"_a = num_feature_,
221 "pred_transform"_a = pred_transform_,
222 "sigmoid_alpha"_a = sigmoid_alpha_,
223 "global_bias"_a = global_bias_,
228 fmt::format(native::main_start_template,
229 "array_is_categorical"_a = array_is_categorical_,
230 "query_functions_definition"_a = query_functions_definition,
231 "pred_transform_function"_a = pred_tranform_func_,
232 "predict_function_signature"_a = predict_function_signature),
234 const std::string query_functions_prototype
235 = fmt::format(native::query_functions_prototype_template,
236 "dllexport"_a = DLLEXPORT_KEYWORD);
237 AppendToBuffer(
"header.h",
238 fmt::format(native::header_template,
239 "dllexport"_a = DLLEXPORT_KEYWORD,
240 "predict_function_signature"_a = predict_function_signature,
241 "query_functions_prototype"_a = query_functions_prototype,
242 "threshold_type"_a = threshold_type,
243 "threshold_type_Node"_a = (param.
quantize > 0 ? std::string(
"int") : threshold_type)),
246 CHECK_EQ(node->children.size(), 1);
247 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
249 std::string optional_average_field;
250 if (node->average_result) {
251 if (task_type_ == TaskType::kMultiClfGrovePerClass) {
255 CHECK_EQ(node->num_tree % task_param_.
num_class, 0)
256 <<
"Expected the number of trees to be divisible by the number of classes";
257 int num_boosting_round = node->num_tree /
static_cast<int>(task_param_.
num_class);
258 optional_average_field = fmt::format(
" / {}", num_boosting_round);
260 CHECK(task_type_ == TaskType::kBinaryClfRegr
261 || task_type_ == TaskType::kMultiClfProbDistLeaf);
264 optional_average_field = fmt::format(
" / {}", node->num_tree);
269 fmt::format(native::main_end_multiclass_template,
271 "optional_average_field"_a = optional_average_field,
272 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
273 "leaf_output_type"_a = leaf_output_type),
277 fmt::format(native::main_end_template,
278 "optional_average_field"_a = optional_average_field,
279 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias),
280 "leaf_output_type"_a = leaf_output_type),
285 template <
typename ThresholdType,
typename LeafOutputType>
287 const std::string& dest,
289 const std::string leaf_output_type
290 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
293 fmt::format(
"{leaf_output_type} sum[{num_class}] = {{0}};\n" 294 "unsigned int tmp;\n" 295 "int nid, cond, fid; /* used for folded subtrees */\n",
297 "leaf_output_type"_a = leaf_output_type), indent);
300 fmt::format(
"{leaf_output_type} sum = ({leaf_output_type})0;\n" 301 "unsigned int tmp;\n" 302 "int nid, cond, fid; /* used for folded subtrees */\n",
303 "leaf_output_type"_a = leaf_output_type),
306 for (
ASTNode* child : node->children) {
307 WalkAST<ThresholdType, LeafOutputType>(child, dest, indent);
311 template <
typename ThresholdType,
typename LeafOutputType>
313 const std::string& dest,
316 std::string condition_with_na_check;
319 std::string condition = ExtractNumericalCondition(t);
320 const char* condition_with_na_check_template
321 = (node->default_left) ?
322 "!(data[{split_index}].missing != -1) || ({condition})" 323 :
" (data[{split_index}].missing != -1) && ({condition})";
324 condition_with_na_check
325 = fmt::format(condition_with_na_check_template,
326 "split_index"_a = node->split_index,
327 "condition"_a = condition);
331 condition_with_na_check = ExtractCategoricalCondition(t2);
333 if (node->children[0]->data_count && node->children[1]->data_count) {
334 const size_t left_freq = node->children[0]->data_count.value();
335 const size_t right_freq = node->children[1]->data_count.value();
336 condition_with_na_check
337 = fmt::format(
" {keyword}( {condition} ) ",
338 "keyword"_a = ((left_freq > right_freq) ?
"LIKELY" :
"UNLIKELY"),
339 "condition"_a = condition_with_na_check);
342 fmt::format(
"if ({}) {{\n", condition_with_na_check), indent);
343 CHECK_EQ(node->children.size(), 2);
344 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent + 2);
345 AppendToBuffer(dest,
"} else {\n", indent);
346 WalkAST<ThresholdType, LeafOutputType>(node->children[1], dest, indent + 2);
347 AppendToBuffer(dest,
"}\n", indent);
350 template <
typename ThresholdType,
typename LeafOutputType>
352 const std::string& dest,
354 AppendToBuffer(dest, RenderOutputStatement(node), indent);
355 CHECK_EQ(node->children.size(), 0);
358 template <
typename ThresholdType,
typename LeafOutputType>
360 const std::string& dest,
362 const int unit_id = node->unit_id;
363 const std::string new_file = fmt::format(
"tu{}.c", unit_id);
364 const std::string leaf_output_type
365 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
367 std::string unit_function_name, unit_function_signature,
368 unit_function_call_signature;
371 = fmt::format(
"predict_margin_multiclass_unit{}", unit_id);
372 unit_function_signature
373 = fmt::format(
"void {function_name}(union Entry* data, {leaf_output_type}* result)",
374 "function_name"_a = unit_function_name,
375 "leaf_output_type"_a = leaf_output_type);
376 unit_function_call_signature
377 = fmt::format(
"{}(data, sum);\n", unit_function_name);
380 = fmt::format(
"predict_margin_unit{}", unit_id);
381 unit_function_signature
382 = fmt::format(
"{leaf_output_type} {function_name}(union Entry* data)",
383 "function_name"_a = unit_function_name,
384 "leaf_output_type"_a = leaf_output_type);
385 unit_function_call_signature
386 = fmt::format(
"sum += {}(data);\n", unit_function_name);
388 AppendToBuffer(dest, unit_function_call_signature, indent);
389 AppendToBuffer(new_file,
390 fmt::format(
"#include \"header.h\"\n" 391 "{} {{\n", unit_function_signature), 0);
392 CHECK_EQ(node->children.size(), 1);
393 WalkAST<ThresholdType, LeafOutputType>(node->children[0], new_file, 2);
395 AppendToBuffer(new_file,
396 fmt::format(
" for (int i = 0; i < {num_class}; ++i) {{\n" 397 " result[i] += sum[i];\n" 400 "num_class"_a = task_param_.
num_class), 0);
402 AppendToBuffer(new_file,
" return sum;\n}\n", 0);
404 AppendToBuffer(
"header.h", fmt::format(
"{};\n", unit_function_signature), 0);
407 template <
typename ThresholdType,
typename LeafOutputType>
409 const std::string& dest,
411 const std::string threshold_type
412 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
414 std::string array_threshold, array_th_begin, array_th_len;
419 size_t total_num_threshold;
423 for (
const auto& e : node->cut_pts) {
430 array_threshold = formatter.
str();
435 for (
const auto& e : node->cut_pts) {
439 total_num_threshold = accum;
440 array_th_begin = formatter.
str();
444 for (
const auto& e : node->cut_pts) {
445 formatter << e.size();
447 array_th_len = formatter.
str();
449 if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
450 PrependToBuffer(dest,
451 fmt::format(native::qnode_template,
452 "total_num_threshold"_a = total_num_threshold,
453 "threshold_type"_a = threshold_type),
456 fmt::format(native::quantize_loop_template,
457 "num_feature"_a = num_feature_), indent);
459 if (!array_threshold.empty()) {
460 PrependToBuffer(dest,
461 fmt::format(
"static const {threshold_type} threshold[] = {{\n" 462 "{array_threshold}\n" 464 "array_threshold"_a = array_threshold,
465 "threshold_type"_a = threshold_type),
468 if (!array_th_begin.empty()) {
469 PrependToBuffer(dest,
470 fmt::format(
"static const int th_begin[] = {{\n" 472 "}};\n",
"array_th_begin"_a = array_th_begin), 0);
474 if (!array_th_len.empty()) {
475 PrependToBuffer(dest,
476 fmt::format(
"static const int th_len[] = {{\n" 478 "}};\n",
"array_th_len"_a = array_th_len), 0);
480 CHECK_EQ(node->children.size(), 1);
481 WalkAST<ThresholdType, LeafOutputType>(node->children[0], dest, indent);
484 template <
typename ThresholdType,
typename LeafOutputType>
486 const std::string& dest,
488 CHECK_EQ(node->children.size(), 1);
489 const int node_id = node->children[0]->node_id;
490 const int tree_id = node->children[0]->tree_id;
493 std::string array_nodes, array_cat_bitmap, array_cat_begin;
495 const std::string node_array_name
496 = fmt::format(
"node_tree{}_node{}", tree_id, node_id);
500 const std::string cat_bitmap_name
501 = fmt::format(
"cat_bitmap_tree{}_node{}", tree_id, node_id);
505 const std::string cat_begin_name
506 = fmt::format(
"cat_begin_tree{}_node{}", tree_id, node_id);
508 std::string output_switch_statement;
510 common_util::RenderCodeFolderArrays<ThresholdType, LeafOutputType>(node, param.
quantize,
false,
511 "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
513 &array_nodes, &array_cat_bitmap, &array_cat_begin,
514 &output_switch_statement, &common_comp_op);
515 if (!array_nodes.empty()) {
516 AppendToBuffer(
"header.h",
517 fmt::format(
"extern const struct Node {node_array_name}[];\n",
518 "node_array_name"_a = node_array_name), 0);
519 AppendToBuffer(
"arrays.c",
520 fmt::format(
"const struct Node {node_array_name}[] = {{\n" 523 "node_array_name"_a = node_array_name,
524 "array_nodes"_a = array_nodes), 0);
527 if (!array_cat_bitmap.empty()) {
528 AppendToBuffer(
"header.h",
529 fmt::format(
"extern const uint64_t {cat_bitmap_name}[];\n",
530 "cat_bitmap_name"_a = cat_bitmap_name), 0);
531 AppendToBuffer(
"arrays.c",
532 fmt::format(
"const uint64_t {cat_bitmap_name}[] = {{\n" 533 "{array_cat_bitmap}\n" 535 "cat_bitmap_name"_a = cat_bitmap_name,
536 "array_cat_bitmap"_a = array_cat_bitmap), 0);
539 if (!array_cat_begin.empty()) {
540 AppendToBuffer(
"header.h",
541 fmt::format(
"extern const size_t {cat_begin_name}[];\n",
542 "cat_begin_name"_a = cat_begin_name), 0);
543 AppendToBuffer(
"arrays.c",
544 fmt::format(
"const size_t {cat_begin_name}[] = {{\n" 545 "{array_cat_begin}\n" 547 "cat_begin_name"_a = cat_begin_name,
548 "array_cat_begin"_a = array_cat_begin), 0);
551 if (array_nodes.empty()) {
554 fmt::format(
"nid = -1;\n" 555 "{output_switch_statement}\n",
556 "output_switch_statement"_a
557 = output_switch_statement), indent);
558 }
else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
560 fmt::format(native::eval_loop_template,
561 "node_array_name"_a = node_array_name,
562 "cat_bitmap_name"_a = cat_bitmap_name,
563 "cat_begin_name"_a = cat_begin_name,
564 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
565 "comp_op"_a =
OpName(common_comp_op),
566 "output_switch_statement"_a
567 = output_switch_statement), indent);
570 fmt::format(native::eval_loop_template_without_categorical_feature,
571 "node_array_name"_a = node_array_name,
572 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
573 "comp_op"_a =
OpName(common_comp_op),
574 "output_switch_statement"_a
575 = output_switch_statement), indent);
579 template <
typename ThresholdType>
582 const std::string threshold_type
583 = native::TypeInfoToCTypeString(TypeToInfo<ThresholdType>());
585 if (node->quantized) {
586 std::string lhs = fmt::format(
"data[{split_index}].qvalue",
587 "split_index"_a = node->split_index);
588 result = fmt::format(
"{lhs} {opname} {threshold}",
590 "opname"_a =
OpName(node->op),
591 "threshold"_a = node->threshold.int_val);
592 }
else if (std::isinf(node->threshold.float_val)) {
595 result = (
CompareWithOp(static_cast<ThresholdType>(0), node->op, node->threshold.float_val)
598 std::string lhs = fmt::format(
"data[{split_index}].fvalue",
599 "split_index"_a = node->split_index);
601 = fmt::format(
"{lhs} {opname} ({threshold_type}){threshold}",
603 "opname"_a =
OpName(node->op),
604 "threshold_type"_a = threshold_type,
605 "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
613 std::vector<uint64_t> bitmap
614 = common_util::GetCategoricalBitmap(node->matching_categories);
615 CHECK_GE(bitmap.size(), 1);
616 bool all_zeros =
true;
617 for (uint64_t e : bitmap) {
618 all_zeros &= (e == 0);
623 std::ostringstream oss;
624 const std::string right_categories_flag = (node->categories_list_right_child ?
"!" :
"");
625 if (node->default_left) {
627 "data[{split_index}].missing == -1 || {right_categories_flag}(" 628 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
629 "split_index"_a = node->split_index,
630 "right_categories_flag"_a = right_categories_flag);
633 "data[{split_index}].missing != -1 && {right_categories_flag}(" 634 "(tmp = (unsigned int)(data[{split_index}].fvalue) ), ",
635 "split_index"_a = node->split_index,
636 "right_categories_flag"_a = right_categories_flag);
638 oss <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 639 << bitmap[0] <<
"U >> tmp) & 1) )";
640 for (
size_t i = 1; i < bitmap.size(); ++i) {
641 oss <<
" || (tmp >= " << (i * 64)
642 <<
" && tmp < " << ((i + 1) * 64)
643 <<
" && (( (uint64_t)" << bitmap[i]
644 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
653 RenderIsCategoricalArray(
const std::vector<bool>& is_categorical) {
655 for (
int fid = 0; fid < num_feature_; ++fid) {
656 formatter << (is_categorical[fid] ? 1 : 0);
658 return formatter.
str();
661 template <
typename LeafOutputType>
663 const std::string leaf_output_type
664 = native::TypeInfoToCTypeString(TypeToInfo<LeafOutputType>());
665 std::string output_statement;
667 if (node->is_vector) {
669 CHECK_EQ(node->vector.size(),
static_cast<size_t>(task_param_.
num_class))
670 <<
"Ill-formed model: leaf vector must be of length [num_class]";
671 for (
size_t group_id = 0; group_id < task_param_.
num_class; ++group_id) {
673 += fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
674 "group_id"_a = group_id,
675 "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]),
676 "leaf_output_type"_a = leaf_output_type);
681 = fmt::format(
"sum[{group_id}] += ({leaf_output_type}){output};\n",
682 "group_id"_a = node->tree_id % task_param_.
num_class,
683 "output"_a = common_util::ToStringHighPrecision(node->scalar),
684 "leaf_output_type"_a = leaf_output_type);
688 = fmt::format(
"sum += ({leaf_output_type}){output};\n",
689 "output"_a = common_util::ToStringHighPrecision(node->scalar),
690 "leaf_output_type"_a = leaf_output_type);
692 return output_statement;
697 .describe(
"AST-based compiler that produces C code")
ModelParam param
extra parameters
Parameters for tree compiler.
std::string OpName(Operator op)
get string representation of comparison operator
TaskType
Enum type representing the task type.
parameters for tree compiler
bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
TaskType task_type
Task type.
Interface of compiler that compiles a tree ensemble model.
template for main function
Group of parameters that are dependent on the choice of the task type.
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Look up C symbols corresponding to TypeInfo.
Utilities for code folding.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Function to generate bitmaps for categorical splits.
TaskParameter task_param
Group of parameters that are specific to the particular task type.
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
unsigned int num_class
The number of classes in the target label.
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
thin wrapper for tree ensemble model
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Operator
comparison operators