9 #include <unordered_map> 13 #include <treelite/common.h> 15 #include <fmt/format.h> 17 #include "./pred_transform.h" 26 #if defined(_MSC_VER) || defined(_WIN32) 27 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 29 #define DLLEXPORT_KEYWORD "" 37 DMLC_REGISTRY_FILE_TAG(ast_native);
44 LOG(INFO) <<
"Using ASTNativeCompiler";
47 LOG(INFO) <<
"Warning: 'dump_array_as_elf' parameter is not applicable " 48 "for ASTNativeCompiler";
54 cm.backend =
"native";
61 pred_tranform_func_ = PredTransformFunction(
"native", model);
65 builder.BuildAST(model);
70 = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
74 std::unique_ptr<dmlc::Stream> fi(
75 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
76 annotator.
Load(fi.get());
77 const auto annotation = annotator.
Get();
78 builder.LoadDataCounts(annotation);
79 LOG(INFO) <<
"Loading node frequencies from `" 84 builder.QuantizeThresholds();
88 const char* destfile = getenv(
"TREELITE_DUMP_AST");
90 std::ofstream os(destfile);
91 os << builder.GetDump() << std::endl;
95 WalkAST(builder.GetRootNode(),
"main.c", 0);
96 if (files_.count(
"arrays.c") > 0) {
97 PrependToBuffer(
"arrays.c",
"#include \"header.h\"\n", 0);
102 std::vector<std::unordered_map<std::string, std::string>> source_list;
103 for (
const auto& kv : files_) {
104 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
105 const size_t line_count
106 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
107 source_list.push_back({ {
"name",
108 kv.first.substr(0, kv.first.length() - 2)},
109 {
"length", std::to_string(line_count)} });
112 std::ostringstream oss;
113 auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
114 writer->BeginObject();
116 writer->WriteObjectKeyValue(
"sources", source_list);
120 cm.files = std::move(files_);
// Number of output groups; values > 1 select the multi-class code paths
// (predict_multiclass, per-group `sum[]` accumulators).
127 int num_output_group_;
// Name of the prediction transform, passed to main_start_template as `pred_transform`.
128 std::string pred_transform_;
// Sigmoid scaling parameter, passed to main_start_template as `sigmoid_alpha`.
129 float sigmoid_alpha_;
// C source of the prediction-transform function, produced by
// PredTransformFunction("native", model). The "tranform" spelling is the
// existing identifier used throughout; left unchanged.
131 std::string pred_tranform_func_;
// Rendered per-feature 0/1 categorical flags; HandleMainNode wraps them into a
// C array definition when non-empty.
132 std::string array_is_categorical_;
// Generated files keyed by file name ("main.c", "header.h", "arrays.c",
// "tu<N>.c", ...); moved into the CompiledModel at the end of Compile.
133 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
// Recursively emit C code for `node` into the in-memory file `dest`.
// Dispatches on the node's dynamic type (via dynamic_cast) to the matching
// Handle*Node method, forwarding `dest` and the current `indent`; the handlers
// call back into WalkAST for their children. A node type not listed below is
// a fatal error (LOG(FATAL)).
// NOTE(review): the `indent` parameter and the t1..t7 declarations are not
// shown in this listing.
135 void WalkAST(
const ASTNode* node,
136 const std::string& dest,
// First matching cast wins; each branch hands off to a type-specific handler.
145 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
146 HandleMainNode(t1, dest, indent);
147 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
148 HandleACNode(t2, dest, indent);
149 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
150 HandleCondNode(t3, dest, indent);
151 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
152 HandleOutputNode(t4, dest, indent);
153 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
154 HandleTUNode(t5, dest, indent);
155 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
156 HandleQNode(t6, dest, indent);
157 }
else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
158 HandleCodeFolderNode(t7, dest, indent);
// No handler matched: the AST contains a node type this compiler cannot emit.
160 LOG(FATAL) <<
"Unrecognized AST node type";
// Append `content`, re-indented by `indent` through
// common::IndentMultiLineString, to the end of the file buffer `dest`.
// files_[dest] default-constructs the entry on first use, so this also
// creates new files (e.g. a fresh "tu<N>.c").
165 inline void AppendToBuffer(
const std::string& dest,
166 const std::string& content,
168 files_[dest].content += common::IndentMultiLineString(content, indent);
// Insert `content`, re-indented by `indent`, at the start of the file buffer
// `dest` (creating the entry if absent). Used for text that must precede
// already-emitted code, such as the `#include "header.h"` line of arrays.c or
// the static lookup tables emitted for a QuantizerNode.
// NOTE(review): rebuilds the whole buffer string on every call.
172 inline void PrependToBuffer(
const std::string& dest,
173 const std::string& content,
175 files_[dest].content = common::IndentMultiLineString(content, indent) + files_[dest].content;
// Emit the top-level scaffolding of the generated library for the root MainNode:
//  * render native::main_start_template with the accessor-function signatures,
//    model metadata, the prediction-transform source and the predict-function
//    signature (the buffer this is appended to is not visible in this
//    listing -- presumably `dest`);
//  * append the matching declarations to "header.h" via native::header_template
//    (threshold type is "int" when param.quantize > 0, otherwise "double");
//  * recurse into the single child at `indent + 2`;
//  * render the closing template (main_end_multiclass_template when
//    num_output_group_ > 1, otherwise main_end_template).
178 void HandleMainNode(
const MainNode* node,
179 const std::string& dest,
181 const char* get_num_output_group_function_signature
182 =
"size_t get_num_output_group(void)";
183 const char* get_num_feature_function_signature
184 =
"size_t get_num_feature(void)";
185 const char* get_pred_transform_function_signature
186 =
"const char* get_pred_transform(void)";
187 const char* get_sigmoid_alpha_function_signature
188 =
"float get_sigmoid_alpha(void)";
189 const char* get_global_bias_function_signature
190 =
"float get_global_bias(void)";
// Multi-class models expose predict_multiclass; single-group models expose predict.
191 const char* predict_function_signature
192 = (num_output_group_ > 1) ?
193 "size_t predict_multiclass(union Entry* data, int pred_margin, " 195 :
"float predict(union Entry* data, int pred_margin)";
// Wrap the rendered per-feature flags into a C array definition.
// NOTE(review): this overwrites the member in place, so a second call on the
// same compiler instance would wrap the already-wrapped text again.
197 if (!array_is_categorical_.empty()) {
198 array_is_categorical_
199 = fmt::format(
"const unsigned char is_categorical[] = {{\n{}\n}}",
200 array_is_categorical_);
204 fmt::format(native::main_start_template,
205 "array_is_categorical"_a = array_is_categorical_,
206 "get_num_output_group_function_signature"_a
207 = get_num_output_group_function_signature,
208 "get_num_feature_function_signature"_a
209 = get_num_feature_function_signature,
210 "get_pred_transform_function_signature"_a
211 = get_pred_transform_function_signature,
212 "get_sigmoid_alpha_function_signature"_a
213 = get_sigmoid_alpha_function_signature,
214 "get_global_bias_function_signature"_a
215 = get_global_bias_function_signature,
216 "pred_transform_function"_a = pred_tranform_func_,
217 "predict_function_signature"_a = predict_function_signature,
218 "num_output_group"_a = num_output_group_,
219 "num_feature"_a = num_feature_,
220 "pred_transform"_a = pred_transform_,
221 "sigmoid_alpha"_a = sigmoid_alpha_,
222 "global_bias"_a = global_bias_),
// Declarations for the generated functions go into the shared header.
224 AppendToBuffer(
"header.h",
225 fmt::format(native::header_template,
226 "dllexport"_a = DLLEXPORT_KEYWORD,
227 "get_num_output_group_function_signature"_a
228 = get_num_output_group_function_signature,
229 "get_num_feature_function_signature"_a
230 = get_num_feature_function_signature,
231 "get_pred_transform_function_signature"_a
232 = get_pred_transform_function_signature,
233 "get_sigmoid_alpha_function_signature"_a
234 = get_sigmoid_alpha_function_signature,
235 "get_global_bias_function_signature"_a
236 = get_global_bias_function_signature,
237 "predict_function_signature"_a = predict_function_signature,
238 "threshold_type"_a = (param.
quantize > 0 ?
"int" :
"double")),
// Exactly one child is expected; its code is emitted at indent + 2
// (presumably the body of the predict function).
241 CHECK_EQ(node->children.size(), 1);
242 WalkAST(node->children[0], dest, indent + 2);
// " / <num_tree>" when node->average_result is set (presumably so the
// template divides the summed tree outputs); the other branch is not shown
// in this listing.
244 const std::string optional_average_field
245 = (node->average_result) ? fmt::format(
" / {}", node->num_tree)
// Render the closing template; the multi-class variant also needs num_output_group.
// NOTE(review): main_start_template above receives the member global_bias_,
// while these templates use node->global_bias -- confirm the two always agree.
247 if (num_output_group_ > 1) {
249 fmt::format(native::main_end_multiclass_template,
250 "num_output_group"_a = num_output_group_,
251 "optional_average_field"_a = optional_average_field,
252 "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
256 fmt::format(native::main_end_template,
257 "optional_average_field"_a = optional_average_field,
258 "global_bias"_a = common::ToStringHighPrecision(node->global_bias)),
264 const std::string& dest,
266 if (num_output_group_ > 1) {
268 fmt::format(
"float sum[{num_output_group}] = {{0.0f}};\n" 269 "unsigned int tmp;\n" 270 "int nid, cond, fid; /* used for folded subtrees */\n",
271 "num_output_group"_a = num_output_group_), indent);
274 "float sum = 0.0f;\n" 275 "unsigned int tmp;\n" 276 "int nid, cond, fid; /* used for folded subtrees */\n", indent);
278 for (
ASTNode* child : node->children) {
279 WalkAST(child, dest, indent);
284 const std::string& dest,
287 std::string condition, condition_with_na_check;
288 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
290 condition = ExtractNumericalCondition(t);
291 const char* condition_with_na_check_template
292 = (node->default_left) ?
293 "!(data[{split_index}].missing != -1) || ({condition})" 294 :
" (data[{split_index}].missing != -1) && ({condition})";
295 condition_with_na_check
296 = fmt::format(condition_with_na_check_template,
297 "split_index"_a = node->split_index,
298 "condition"_a = condition);
303 condition_with_na_check = ExtractCategoricalCondition(t2);
305 if (node->children[0]->data_count && node->children[1]->data_count) {
306 const int left_freq = node->children[0]->data_count.value();
307 const int right_freq = node->children[1]->data_count.value();
308 condition_with_na_check
309 = fmt::format(
" {keyword}( {condition} ) ",
310 "keyword"_a = ((left_freq > right_freq) ?
"LIKELY" :
"UNLIKELY"),
311 "condition"_a = condition_with_na_check);
314 fmt::format(
"if ({}) {{\n", condition_with_na_check), indent);
315 CHECK_EQ(node->children.size(), 2);
316 WalkAST(node->children[0], dest, indent + 2);
317 AppendToBuffer(dest,
"} else {\n", indent);
318 WalkAST(node->children[1], dest, indent + 2);
319 AppendToBuffer(dest,
"}\n", indent);
323 const std::string& dest,
325 AppendToBuffer(dest, RenderOutputStatement(node), indent);
326 CHECK_EQ(node->children.size(), 0);
330 const std::string& dest,
332 const int unit_id = node->unit_id;
333 const std::string new_file = fmt::format(
"tu{}.c", unit_id);
335 std::string unit_function_name, unit_function_signature,
336 unit_function_call_signature;
337 if (num_output_group_ > 1) {
339 = fmt::format(
"predict_margin_multiclass_unit{}", unit_id);
340 unit_function_signature
341 = fmt::format(
"void {}(union Entry* data, float* result)",
343 unit_function_call_signature
344 = fmt::format(
"{}(data, sum);\n", unit_function_name);
347 = fmt::format(
"predict_margin_unit{}", unit_id);
348 unit_function_signature
349 = fmt::format(
"float {}(union Entry* data)", unit_function_name);
350 unit_function_call_signature
351 = fmt::format(
"sum += {}(data);\n", unit_function_name);
353 AppendToBuffer(dest, unit_function_call_signature, indent);
354 AppendToBuffer(new_file,
355 fmt::format(
"#include \"header.h\"\n" 356 "{} {{\n", unit_function_signature), 0);
357 CHECK_EQ(node->children.size(), 1);
358 WalkAST(node->children[0], new_file, 2);
359 if (num_output_group_ > 1) {
360 AppendToBuffer(new_file,
361 fmt::format(
" for (int i = 0; i < {num_output_group}; ++i) {{\n" 362 " result[i] += sum[i];\n" 365 "num_output_group"_a = num_output_group_), 0);
367 AppendToBuffer(new_file,
" return sum;\n}\n", 0);
369 AppendToBuffer(
"header.h", fmt::format(
"{};\n", unit_function_signature), 0);
373 const std::string& dest,
376 std::string array_threshold, array_th_begin, array_th_len;
381 size_t total_num_threshold;
385 for (
const auto& e : node->cut_pts) {
392 array_threshold = formatter.
str();
397 for (
const auto& e : node->cut_pts) {
401 total_num_threshold = accum;
402 array_th_begin = formatter.
str();
406 for (
const auto& e : node->cut_pts) {
407 formatter << e.size();
409 array_th_len = formatter.
str();
411 if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
412 PrependToBuffer(dest,
413 fmt::format(native::qnode_template,
414 "total_num_threshold"_a = total_num_threshold), 0);
416 fmt::format(native::quantize_loop_template,
417 "num_feature"_a = num_feature_), indent);
419 if (!array_threshold.empty()) {
420 PrependToBuffer(dest,
421 fmt::format(
"static const double threshold[] = {{\n" 422 "{array_threshold}\n" 423 "}};\n",
"array_threshold"_a = array_threshold), 0);
425 if (!array_th_begin.empty()) {
426 PrependToBuffer(dest,
427 fmt::format(
"static const int th_begin[] = {{\n" 429 "}};\n",
"array_th_begin"_a = array_th_begin), 0);
431 if (!array_th_len.empty()) {
432 PrependToBuffer(dest,
433 fmt::format(
"static const int th_len[] = {{\n" 435 "}};\n",
"array_th_len"_a = array_th_len), 0);
437 CHECK_EQ(node->children.size(), 1);
438 WalkAST(node->children[0], dest, indent);
442 const std::string& dest,
444 CHECK_EQ(node->children.size(), 1);
445 const int node_id = node->children[0]->node_id;
446 const int tree_id = node->children[0]->tree_id;
449 std::string array_nodes, array_cat_bitmap, array_cat_begin;
451 const std::string node_array_name
452 = fmt::format(
"node_tree{}_node{}", tree_id, node_id);
456 const std::string cat_bitmap_name
457 = fmt::format(
"cat_bitmap_tree{}_node{}", tree_id, node_id);
461 const std::string cat_begin_name
462 = fmt::format(
"cat_begin_tree{}_node{}", tree_id, node_id);
464 std::string output_switch_statement;
466 common_util::RenderCodeFolderArrays(node, param.
quantize,
false,
467 "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
468 [
this](
const OutputNode* node) { return RenderOutputStatement(node); },
469 &array_nodes, &array_cat_bitmap, &array_cat_begin,
470 &output_switch_statement, &common_comp_op);
471 if (!array_nodes.empty()) {
472 AppendToBuffer(
"header.h",
473 fmt::format(
"extern const struct Node {node_array_name}[];\n",
474 "node_array_name"_a = node_array_name), 0);
475 AppendToBuffer(
"arrays.c",
476 fmt::format(
"const struct Node {node_array_name}[] = {{\n" 479 "node_array_name"_a = node_array_name,
480 "array_nodes"_a = array_nodes), 0);
483 if (!array_cat_bitmap.empty()) {
484 AppendToBuffer(
"header.h",
485 fmt::format(
"extern const uint64_t {cat_bitmap_name}[];\n",
486 "cat_bitmap_name"_a = cat_bitmap_name), 0);
487 AppendToBuffer(
"arrays.c",
488 fmt::format(
"const uint64_t {cat_bitmap_name}[] = {{\n" 489 "{array_cat_bitmap}\n" 491 "cat_bitmap_name"_a = cat_bitmap_name,
492 "array_cat_bitmap"_a = array_cat_bitmap), 0);
495 if (!array_cat_begin.empty()) {
496 AppendToBuffer(
"header.h",
497 fmt::format(
"extern const size_t {cat_begin_name}[];\n",
498 "cat_begin_name"_a = cat_begin_name), 0);
499 AppendToBuffer(
"arrays.c",
500 fmt::format(
"const size_t {cat_begin_name}[] = {{\n" 501 "{array_cat_begin}\n" 503 "cat_begin_name"_a = cat_begin_name,
504 "array_cat_begin"_a = array_cat_begin), 0);
507 if (array_nodes.empty()) {
510 fmt::format(
"nid = -1;\n" 511 "{output_switch_statement}\n",
512 "output_switch_statement"_a
513 = output_switch_statement), indent);
514 }
else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
516 fmt::format(native::eval_loop_template,
517 "node_array_name"_a = node_array_name,
518 "cat_bitmap_name"_a = cat_bitmap_name,
519 "cat_begin_name"_a = cat_begin_name,
520 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
521 "comp_op"_a =
OpName(common_comp_op),
522 "output_switch_statement"_a
523 = output_switch_statement), indent);
526 fmt::format(native::eval_loop_template_without_categorical_feature,
527 "node_array_name"_a = node_array_name,
528 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
529 "comp_op"_a =
OpName(common_comp_op),
530 "output_switch_statement"_a
531 = output_switch_statement), indent);
538 if (node->quantized) {
539 result = fmt::format(
"data[{split_index}].qvalue {opname} {threshold}",
540 "split_index"_a = node->split_index,
541 "opname"_a =
OpName(node->op),
542 "threshold"_a = node->threshold.int_val);
543 }
else if (std::isinf(node->threshold.float_val)) {
546 result = (common::CompareWithOp(0.0, node->op, node->threshold.float_val)
549 result = fmt::format(
"data[{split_index}].fvalue {opname} {threshold}",
550 "split_index"_a = node->split_index,
551 "opname"_a =
OpName(node->op),
553 = common::ToStringHighPrecision(node->threshold.float_val));
561 std::vector<uint64_t> bitmap
562 = common_util::GetCategoricalBitmap(node->left_categories);
563 CHECK_GE(bitmap.size(), 1);
564 bool all_zeros =
true;
565 for (uint64_t e : bitmap) {
566 all_zeros &= (e == 0);
571 std::ostringstream oss;
572 if (node->convert_missing_to_zero) {
575 "((tmp = (data[{0}].missing == -1 ? 0U " 576 ": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
578 if (node->default_left) {
580 "data[{0}].missing == -1 || (" 581 "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
584 "data[{0}].missing != -1 && (" 585 "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
588 oss <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 589 << bitmap[0] <<
"U >> tmp) & 1) )";
590 for (
size_t i = 1; i < bitmap.size(); ++i) {
591 oss <<
" || (tmp >= " << (i * 64)
592 <<
" && tmp < " << ((i + 1) * 64)
593 <<
" && (( (uint64_t)" << bitmap[i]
594 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
// Render one entry per feature (fid in [0, num_feature_)) into `formatter`:
// 1 if the feature is categorical, otherwise 0. Returns the accumulated text
// (presumably stored in array_is_categorical_ and later wrapped into the
// generated `is_categorical[]` C array).
// NOTE(review): is_categorical is indexed without a bounds check, so it must
// hold at least num_feature_ entries. The declaration of `formatter` is not
// shown in this listing.
603 RenderIsCategoricalArray(
const std::vector<bool>& is_categorical) {
605 for (
int fid = 0; fid < num_feature_; ++fid) {
606 formatter << (is_categorical[fid] ? 1 : 0);
608 return formatter.
str();
// Render the C statement(s) that add a leaf's output to the accumulator `sum`.
// Outputs are printed with common::ToStringHighPrecision and cast to float.
//  * multi-class, vector leaf: one `sum[g] += ...;` per output group; the leaf
//    vector length must equal num_output_group_ (CHECK_EQ);
//  * multi-class, scalar leaf: adds to group `tree_id % num_output_group_`;
//  * otherwise (num_output_group_ <= 1): a single `sum += ...;`.
611 inline std::string RenderOutputStatement(
const OutputNode* node) {
612 std::string output_statement;
613 if (num_output_group_ > 1) {
614 if (node->is_vector) {
// Vector leaf: one accumulation statement per output group.
616 CHECK_EQ(node->vector.size(),
static_cast<size_t>(num_output_group_))
617 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
618 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
620 += fmt::format(
"sum[{group_id}] += (float){output};\n",
621 "group_id"_a = group_id,
623 = common::ToStringHighPrecision(node->vector[group_id]));
// Scalar leaf in a multi-class model: the tree index selects the group.
628 = fmt::format(
"sum[{group_id}] += (float){output};\n",
629 "group_id"_a = node->tree_id % num_output_group_,
630 "output"_a = common::ToStringHighPrecision(node->scalar));
// Single-group model: accumulate into the scalar `sum`.
634 = fmt::format(
"sum += (float){output};\n",
635 "output"_a = common::ToStringHighPrecision(node->scalar));
637 return output_statement;
642 .describe(
"AST-based compiler that produces C code")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
thin wrapper for tree ensemble model
std::string OpName(Operator op)
get string representation of comparison operator
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
ModelParam param
extra parameters
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
float global_bias
global bias of the model
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
template for main function
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string pred_transform
name of prediction transform function
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Utilities for code folding.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Function to generate bitmaps for categorical splits.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
double tl_float
float type to be used internally
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators