10 #include <fmt/format.h> 13 #include <unordered_map> 16 #include "./pred_transform.h" 26 #if defined(_MSC_VER) || defined(_WIN32) 27 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 29 #define DLLEXPORT_KEYWORD "" 37 DMLC_REGISTRY_FILE_TAG(ast_native);
// Constructor fragment: emits two INFO-level log messages, one announcing that
// ASTNativeCompiler is in use and one warning that the dump_array_as_elf
// parameter is not applicable to this compiler.
// NOTE(review): the conditions guarding these messages are not visible in this
// excerpt; confirm against the full file.
44 LOG(INFO) <<
"Using ASTNativeCompiler";
47 LOG(INFO) <<
"Warning: 'dump_array_as_elf' parameter is not applicable " 48 "for ASTNativeCompiler";
// Compile (body fragment): builds the AST for the model, renders the generated
// C sources into files_ by walking the AST, collects a list of the .c sources
// with their line counts, writes that list through a JSON writer under the
// sources key, and finally moves files_ into the returned CompiledModel.
// NOTE(review): several statements (signature, guards, closing braces) are not
// visible in this excerpt.
54 cm.backend =
"native";
// Rendered prediction-transform function; later handed to the main template.
61 pred_tranform_func_ = PredTransformFunction(
"native", model);
65 builder.BuildAST(model);
70 = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray());
// Load branch annotation from the file named by param.annotate_in and feed the
// counts to the AST builder. Presumably guarded by a check on annotate_in;
// the guard is not visible here.
74 std::unique_ptr<dmlc::Stream> fi(
75 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
76 annotator.
Load(fi.get());
77 const auto annotation = annotator.
Get();
78 builder.LoadDataCounts(annotation);
79 LOG(INFO) <<
"Loading node frequencies from `" 84 builder.QuantizeThresholds();
// Write the AST dump to the path given by the TREELITE_DUMP_AST environment
// variable. NOTE(review): getenv may return a null pointer; the null check is
// not visible in this excerpt, confirm it exists.
88 const char* destfile = getenv(
"TREELITE_DUMP_AST");
90 std::ofstream os(destfile);
91 os << builder.GetDump() << std::endl;
// Generate code starting at the root node, targeting main.c with indent 0.
95 WalkAST(builder.GetRootNode(),
"main.c", 0);
// arrays.c only exists if some handler appended to it; give it the header include.
96 if (files_.count(
"arrays.c") > 0) {
97 PrependToBuffer(
"arrays.c",
"#include \"header.h\"\n", 0);
// One map per .c file: name (file name without the .c suffix) and length
// (number of newline characters in the content).
102 std::vector<std::unordered_map<std::string, std::string>> source_list;
103 for (
const auto& kv : files_) {
// NOTE(review): length() - 2 wraps around for a key shorter than two
// characters, in which case compare would throw std::out_of_range; confirm
// every key in files_ has at least two characters.
104 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
105 const size_t line_count
106 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
107 source_list.push_back({ {
"name",
108 kv.first.substr(0, kv.first.length() - 2)},
109 {
"length", std::to_string(line_count)} });
// Serialize the source list as JSON into an in-memory stream.
112 std::ostringstream oss;
113 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&oss));
114 writer->BeginObject();
116 writer->WriteObjectKeyValue(
"sources", source_list);
// files_ is moved from here; it must not be read again afterwards.
120 cm.files = std::move(files_);
// Compiler state read by the Handle* and Render* methods in this file.
// num_output_group_ selects between the single-output and multiclass code paths.
127 int num_output_group_;
128 std::string pred_transform_;
129 float sigmoid_alpha_;
// Assigned from PredTransformFunction in Compile and passed to the main
// template as pred_transform_function. The identifier is spelled tranform
// consistently throughout this file.
131 std::string pred_tranform_func_;
// Rendered body of the is_categorical array; wrapped into a C array definition
// by HandleMainNode when non-empty.
132 std::string array_is_categorical_;
// Output buffers keyed by file name; AppendToBuffer and PrependToBuffer edit
// the content field of each entry.
133 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
// Dispatches on the dynamic type of node: each dynamic_cast is tried in turn
// and the first one that succeeds forwards node, dest and indent to the
// matching Handle* method. If none of the seven node types matches, a FATAL
// message is logged.
// NOTE(review): the remaining parameter(s) and the declarations of t1..t7 are
// not visible in this excerpt.
135 void WalkAST(
const ASTNode* node,
136 const std::string& dest,
145 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
146 HandleMainNode(t1, dest, indent);
147 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
148 HandleACNode(t2, dest, indent);
149 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
150 HandleCondNode(t3, dest, indent);
151 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
152 HandleOutputNode(t4, dest, indent);
153 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
154 HandleTUNode(t5, dest, indent);
155 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
156 HandleQNode(t6, dest, indent);
157 }
else if ( (t7 = dynamic_cast<const CodeFolderNode*>(node)) ) {
158 HandleCodeFolderNode(t7, dest, indent);
160 LOG(FATAL) <<
"Unrecognized AST node type";
// Appends content, after passing it through common_util::IndentMultiLineString
// with the given indent, to the end of the buffer stored under dest.
// files_ is an unordered_map, so files_[dest] creates an empty entry when dest
// has not been written to before.
165 inline void AppendToBuffer(
const std::string& dest,
166 const std::string& content,
168 files_[dest].content += common_util::IndentMultiLineString(content, indent);
// Counterpart of AppendToBuffer: the indented content is placed in front of
// whatever the buffer stored under dest already holds.
// NOTE(review): the left-hand side of the assignment is not visible in this
// excerpt; presumably files_[dest].content.
172 inline void PrependToBuffer(
const std::string& dest,
173 const std::string& content,
176 = common_util::IndentMultiLineString(content, indent) + files_[dest].content;
// Emits the code for the MainNode: formats native::main_start_template and
// native::header_template (the latter appended to header.h), recurses into the
// single child with two extra columns of indent, then formats either
// native::main_end_multiclass_template or native::main_end_template depending
// on num_output_group_.
// NOTE(review): the calls that receive the main_start and main_end format
// results are not visible in this excerpt.
179 void HandleMainNode(
const MainNode* node,
180 const std::string& dest,
// C signatures of the generated query functions; the same strings are
// substituted into both the main template and the header template.
182 const char* get_num_output_group_function_signature
183 =
"size_t get_num_output_group(void)";
184 const char* get_num_feature_function_signature
185 =
"size_t get_num_feature(void)";
186 const char* get_pred_transform_function_signature
187 =
"const char* get_pred_transform(void)";
188 const char* get_sigmoid_alpha_function_signature
189 =
"float get_sigmoid_alpha(void)";
190 const char* get_global_bias_function_signature
191 =
"float get_global_bias(void)";
// The multiclass predict signature is used when there is more than one output group.
192 const char* predict_function_signature
193 = (num_output_group_ > 1) ?
194 "size_t predict_multiclass(union Entry* data, int pred_margin, " 196 :
"float predict(union Entry* data, int pred_margin)";
// Wrap the rendered array body in a C array definition named is_categorical;
// an empty body is left empty.
198 if (!array_is_categorical_.empty()) {
199 array_is_categorical_
200 = fmt::format(
"const unsigned char is_categorical[] = {{\n{}\n}}",
201 array_is_categorical_);
205 fmt::format(native::main_start_template,
206 "array_is_categorical"_a = array_is_categorical_,
207 "get_num_output_group_function_signature"_a
208 = get_num_output_group_function_signature,
209 "get_num_feature_function_signature"_a
210 = get_num_feature_function_signature,
211 "get_pred_transform_function_signature"_a
212 = get_pred_transform_function_signature,
213 "get_sigmoid_alpha_function_signature"_a
214 = get_sigmoid_alpha_function_signature,
215 "get_global_bias_function_signature"_a
216 = get_global_bias_function_signature,
217 "pred_transform_function"_a = pred_tranform_func_,
218 "predict_function_signature"_a = predict_function_signature,
219 "num_output_group"_a = num_output_group_,
220 "num_feature"_a = num_feature_,
221 "pred_transform"_a = pred_transform_,
222 "sigmoid_alpha"_a = sigmoid_alpha_,
223 "global_bias"_a = global_bias_),
225 AppendToBuffer(
"header.h",
226 fmt::format(native::header_template,
227 "dllexport"_a = DLLEXPORT_KEYWORD,
228 "get_num_output_group_function_signature"_a
229 = get_num_output_group_function_signature,
230 "get_num_feature_function_signature"_a
231 = get_num_feature_function_signature,
232 "get_pred_transform_function_signature"_a
233 = get_pred_transform_function_signature,
234 "get_sigmoid_alpha_function_signature"_a
235 = get_sigmoid_alpha_function_signature,
236 "get_global_bias_function_signature"_a
237 = get_global_bias_function_signature,
238 "predict_function_signature"_a = predict_function_signature,
// Thresholds are declared as int in the header when quantization is enabled,
// float otherwise.
239 "threshold_type"_a = (param.
quantize > 0 ?
"int" :
"float")),
// A MainNode must have exactly one child.
242 CHECK_EQ(node->children.size(), 1);
243 WalkAST(node->children[0], dest, indent + 2);
// When node->average_result is set, this field is a division by node->num_tree.
// The alternative branch of the conditional is not visible in this excerpt.
245 const std::string optional_average_field
246 = (node->average_result) ? fmt::format(
" / {}", node->num_tree)
248 if (num_output_group_ > 1) {
250 fmt::format(native::main_end_multiclass_template,
251 "num_output_group"_a = num_output_group_,
252 "optional_average_field"_a = optional_average_field,
253 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)),
257 fmt::format(native::main_end_template,
258 "optional_average_field"_a = optional_average_field,
259 "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)),
// Presumably the body of HandleACNode (the first signature line is not visible
// in this excerpt; WalkAST calls HandleACNode for AccumulatorContextNode).
// Emits the C declarations for the accumulator and scratch variables: an array
// sum with one slot per output group in the multiclass case, a scalar sum
// otherwise, plus tmp, nid, cond and fid. It then walks every child with the
// same dest and indent.
265 const std::string& dest,
267 if (num_output_group_ > 1) {
269 fmt::format(
"float sum[{num_output_group}] = {{0.0f}};\n" 270 "unsigned int tmp;\n" 271 "int nid, cond, fid; /* used for folded subtrees */\n",
272 "num_output_group"_a = num_output_group_), indent);
275 "float sum = 0.0f;\n" 276 "unsigned int tmp;\n" 277 "int nid, cond, fid; /* used for folded subtrees */\n", indent);
279 for (
ASTNode* child : node->children) {
280 WalkAST(child, dest, indent);
// Presumably the body of HandleCondNode (the first signature line is not
// visible in this excerpt; WalkAST calls HandleCondNode for ConditionNode).
// Builds the C condition for a split, optionally wraps it in a LIKELY or
// UNLIKELY hint, and emits an if / else block whose two branches are generated
// by walking children[0] and children[1] with two extra columns of indent.
285 const std::string& dest,
288 std::string condition, condition_with_na_check;
289 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
// Numerical split: combine the comparison with a missing-value test.
// With default_left set, a missing value satisfies the condition;
// otherwise the value must be present and the comparison must hold.
291 condition = ExtractNumericalCondition(t);
292 const char* condition_with_na_check_template
293 = (node->default_left) ?
294 "!(data[{split_index}].missing != -1) || ({condition})" 295 :
" (data[{split_index}].missing != -1) && ({condition})";
296 condition_with_na_check
297 = fmt::format(condition_with_na_check_template,
298 "split_index"_a = node->split_index,
299 "condition"_a = condition);
// Categorical split: the helper returns the complete condition.
304 condition_with_na_check = ExtractCategoricalCondition(t2);
// Branch hint is added only when both children carry a data count.
// NOTE(review): children[0] and children[1] are indexed here, before the
// CHECK_EQ on children.size() further down.
306 if (node->children[0]->data_count && node->children[1]->data_count) {
307 const int left_freq = node->children[0]->data_count.value();
308 const int right_freq = node->children[1]->data_count.value();
309 condition_with_na_check
310 = fmt::format(
" {keyword}( {condition} ) ",
311 "keyword"_a = ((left_freq > right_freq) ?
"LIKELY" :
"UNLIKELY"),
312 "condition"_a = condition_with_na_check);
315 fmt::format(
"if ({}) {{\n", condition_with_na_check), indent);
316 CHECK_EQ(node->children.size(), 2);
317 WalkAST(node->children[0], dest, indent + 2);
318 AppendToBuffer(dest,
"} else {\n", indent);
319 WalkAST(node->children[1], dest, indent + 2);
320 AppendToBuffer(dest,
"}\n", indent);
// Presumably the body of HandleOutputNode (the first signature line is not
// visible in this excerpt). Appends the statement produced by
// RenderOutputStatement to dest and checks that the node has no children.
324 const std::string& dest,
326 AppendToBuffer(dest, RenderOutputStatement(node), indent);
327 CHECK_EQ(node->children.size(), 0);
// Presumably the body of HandleTUNode (the first signature line is not visible
// in this excerpt; WalkAST calls HandleTUNode for TranslationUnitNode).
// Moves the subtree under this node into its own file tu<unit_id>.c as a
// separate function: a call to that function is appended to dest, the function
// definition is written to the new file, and its prototype is appended to
// header.h.
331 const std::string& dest,
333 const int unit_id = node->unit_id;
334 const std::string new_file = fmt::format(
"tu{}.c", unit_id);
336 std::string unit_function_name, unit_function_signature,
337 unit_function_call_signature;
// Multiclass: the unit function accumulates into a caller-supplied result array.
338 if (num_output_group_ > 1) {
340 = fmt::format(
"predict_margin_multiclass_unit{}", unit_id);
341 unit_function_signature
342 = fmt::format(
"void {}(union Entry* data, float* result)",
344 unit_function_call_signature
345 = fmt::format(
"{}(data, sum);\n", unit_function_name);
// Single output: the unit function returns its partial sum.
348 = fmt::format(
"predict_margin_unit{}", unit_id);
349 unit_function_signature
350 = fmt::format(
"float {}(union Entry* data)", unit_function_name);
351 unit_function_call_signature
352 = fmt::format(
"sum += {}(data);\n", unit_function_name);
354 AppendToBuffer(dest, unit_function_call_signature, indent);
355 AppendToBuffer(new_file,
356 fmt::format(
"#include \"header.h\"\n" 357 "{} {{\n", unit_function_signature), 0);
358 CHECK_EQ(node->children.size(), 1);
// The child is generated into the new file with a fixed indent of 2.
359 WalkAST(node->children[0], new_file, 2);
360 if (num_output_group_ > 1) {
361 AppendToBuffer(new_file,
362 fmt::format(
" for (int i = 0; i < {num_output_group}; ++i) {{\n" 363 " result[i] += sum[i];\n" 366 "num_output_group"_a = num_output_group_), 0);
368 AppendToBuffer(new_file,
" return sum;\n}\n", 0);
// Prototype for the unit function.
370 AppendToBuffer(
"header.h", fmt::format(
"{};\n", unit_function_signature), 0);
// Presumably the body of HandleQNode (the first signature line is not visible
// in this excerpt; WalkAST calls HandleQNode for QuantizerNode).
// Renders three arrays from node->cut_pts (threshold, th_begin, th_len),
// prepends their C definitions and the native::qnode_template code to dest,
// formats native::quantize_loop_template, and then walks the single child.
// NOTE(review): the loop bodies that fill the first two arrays and the
// declarations of formatter and accum are not visible in this excerpt.
374 const std::string& dest,
377 std::string array_threshold, array_th_begin, array_th_len;
382 size_t total_num_threshold;
386 for (
const auto& e : node->cut_pts) {
393 array_threshold = formatter.
str();
398 for (
const auto& e : node->cut_pts) {
402 total_num_threshold = accum;
403 array_th_begin = formatter.
str();
// th_len holds the number of cut points of each entry in cut_pts.
407 for (
const auto& e : node->cut_pts) {
408 formatter << e.size();
410 array_th_len = formatter.
str();
// Each PrependToBuffer call puts its text in front of the existing buffer, so
// the text prepended last ends up first in the file.
412 if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) {
413 PrependToBuffer(dest,
414 fmt::format(native::qnode_template,
415 "total_num_threshold"_a = total_num_threshold), 0);
417 fmt::format(native::quantize_loop_template,
418 "num_feature"_a = num_feature_), indent);
420 if (!array_threshold.empty()) {
421 PrependToBuffer(dest,
422 fmt::format(
"static const float threshold[] = {{\n" 423 "{array_threshold}\n" 424 "}};\n",
"array_threshold"_a = array_threshold), 0);
426 if (!array_th_begin.empty()) {
427 PrependToBuffer(dest,
428 fmt::format(
"static const int th_begin[] = {{\n" 430 "}};\n",
"array_th_begin"_a = array_th_begin), 0);
432 if (!array_th_len.empty()) {
433 PrependToBuffer(dest,
434 fmt::format(
"static const int th_len[] = {{\n" 436 "}};\n",
"array_th_len"_a = array_th_len), 0);
438 CHECK_EQ(node->children.size(), 1);
439 WalkAST(node->children[0], dest, indent);
// Presumably the body of HandleCodeFolderNode (the first signature line is not
// visible in this excerpt; WalkAST calls it for CodeFolderNode).
// Asks common_util::RenderCodeFolderArrays to render the folded subtree into a
// node array, optional categorical bitmap and begin arrays, and an output
// switch statement. Non-empty arrays are declared extern in header.h and
// defined in arrays.c; the evaluation code is then formatted from one of
// three forms depending on which arrays are present.
// NOTE(review): the calls that receive the three evaluation-code format
// results are not visible in this excerpt.
443 const std::string& dest,
445 CHECK_EQ(node->children.size(), 1);
// Array names are made unique by the tree id and node id of the single child.
446 const int node_id = node->children[0]->node_id;
447 const int tree_id = node->children[0]->tree_id;
450 std::string array_nodes, array_cat_bitmap, array_cat_begin;
452 const std::string node_array_name
453 = fmt::format(
"node_tree{}_node{}", tree_id, node_id);
457 const std::string cat_bitmap_name
458 = fmt::format(
"cat_bitmap_tree{}_node{}", tree_id, node_id);
462 const std::string cat_begin_name
463 = fmt::format(
"cat_begin_tree{}_node{}", tree_id, node_id);
465 std::string output_switch_statement;
// The lambda lets the helper render leaf outputs with this compiler's own
// RenderOutputStatement; results come back through the pointer arguments.
467 common_util::RenderCodeFolderArrays(node, param.
quantize,
false,
468 "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}",
469 [
this](
const OutputNode* node) { return RenderOutputStatement(node); },
470 &array_nodes, &array_cat_bitmap, &array_cat_begin,
471 &output_switch_statement, &common_comp_op);
472 if (!array_nodes.empty()) {
473 AppendToBuffer(
"header.h",
474 fmt::format(
"extern const struct Node {node_array_name}[];\n",
475 "node_array_name"_a = node_array_name), 0);
476 AppendToBuffer(
"arrays.c",
477 fmt::format(
"const struct Node {node_array_name}[] = {{\n" 480 "node_array_name"_a = node_array_name,
481 "array_nodes"_a = array_nodes), 0);
484 if (!array_cat_bitmap.empty()) {
485 AppendToBuffer(
"header.h",
486 fmt::format(
"extern const uint64_t {cat_bitmap_name}[];\n",
487 "cat_bitmap_name"_a = cat_bitmap_name), 0);
488 AppendToBuffer(
"arrays.c",
489 fmt::format(
"const uint64_t {cat_bitmap_name}[] = {{\n" 490 "{array_cat_bitmap}\n" 492 "cat_bitmap_name"_a = cat_bitmap_name,
493 "array_cat_bitmap"_a = array_cat_bitmap), 0);
496 if (!array_cat_begin.empty()) {
497 AppendToBuffer(
"header.h",
498 fmt::format(
"extern const size_t {cat_begin_name}[];\n",
499 "cat_begin_name"_a = cat_begin_name), 0);
500 AppendToBuffer(
"arrays.c",
501 fmt::format(
"const size_t {cat_begin_name}[] = {{\n" 502 "{array_cat_begin}\n" 504 "cat_begin_name"_a = cat_begin_name,
505 "array_cat_begin"_a = array_cat_begin), 0);
// No node array: only the output switch is emitted, with nid set to -1.
508 if (array_nodes.empty()) {
511 fmt::format(
"nid = -1;\n" 512 "{output_switch_statement}\n",
513 "output_switch_statement"_a
514 = output_switch_statement), indent);
515 }
// Both categorical arrays present: use the evaluation loop that references them.
else if (!array_cat_bitmap.empty() && !array_cat_begin.empty()) {
517 fmt::format(native::eval_loop_template,
518 "node_array_name"_a = node_array_name,
519 "cat_bitmap_name"_a = cat_bitmap_name,
520 "cat_begin_name"_a = cat_begin_name,
521 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
522 "comp_op"_a =
OpName(common_comp_op),
523 "output_switch_statement"_a
524 = output_switch_statement), indent);
// Remaining case: evaluation loop without categorical features.
527 fmt::format(native::eval_loop_template_without_categorical_feature,
528 "node_array_name"_a = node_array_name,
529 "data_field"_a = (param.
quantize > 0 ?
"qvalue" :
"fvalue"),
530 "comp_op"_a =
OpName(common_comp_op),
531 "output_switch_statement"_a
532 = output_switch_statement), indent);
// Presumably the body of ExtractNumericalCondition (the signature is not
// visible in this excerpt; HandleCondNode calls it for numerical splits).
// Produces the C comparison expression for a numerical split, in one of
// three forms.
// Quantized threshold: compare the qvalue field against the integer threshold.
539 if (node->quantized) {
540 result = fmt::format(
"data[{split_index}].qvalue {opname} {threshold}",
541 "split_index"_a = node->split_index,
542 "opname"_a =
OpName(node->op),
543 "threshold"_a = node->threshold.int_val);
544 }
// Infinite threshold: evaluate the comparison here with 0.0 as the left operand
// and emit the constant 1 or 0 instead of a comparison.
else if (std::isinf(node->threshold.float_val)) {
547 result = (
CompareWithOp(0.0, node->op, node->threshold.float_val) ?
"1" :
"0");
// Finite float threshold: compare the fvalue field against the threshold cast to float.
549 result = fmt::format(
"data[{split_index}].fvalue {opname} (float){threshold}",
550 "split_index"_a = node->split_index,
551 "opname"_a =
OpName(node->op),
552 "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val));
// Presumably the body of ExtractCategoricalCondition (the signature is not
// visible in this excerpt; HandleCondNode calls it for categorical splits).
// Builds a C expression that assigns the feature value to tmp and tests
// whether the corresponding bit is set in the bitmap derived from
// node->left_categories.
560 std::vector<uint64_t> bitmap
561 = common_util::GetCategoricalBitmap(node->left_categories);
562 CHECK_GE(bitmap.size(), 1);
// all_zeros ends up true only if every bitmap word is zero.
// NOTE(review): how all_zeros is used is not visible in this excerpt.
563 bool all_zeros =
true;
564 for (uint64_t e : bitmap) {
565 all_zeros &= (e == 0);
570 std::ostringstream oss;
// Prefix variants: treat a missing value as category 0, or short-circuit on
// the missing flag according to default_left.
571 if (node->convert_missing_to_zero) {
574 "((tmp = (data[{0}].missing == -1 ? 0U " 575 ": (unsigned int)(data[{0}].fvalue) )), ", node->split_index);
577 if (node->default_left) {
579 "data[{0}].missing == -1 || (" 580 "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
583 "data[{0}].missing != -1 && (" 584 "(tmp = (unsigned int)(data[{0}].fvalue) ), ", node->split_index);
// Word 0 of the bitmap covers tmp values 0 to 63.
587 oss <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 588 << bitmap[0] <<
"U >> tmp) & 1) )";
// Word i covers tmp values 64*i up to 64*(i+1) - 1; the shift is relative to 64*i.
589 for (
size_t i = 1; i < bitmap.size(); ++i) {
590 oss <<
" || (tmp >= " << (i * 64)
591 <<
" && tmp < " << ((i + 1) * 64)
592 <<
" && (( (uint64_t)" << bitmap[i]
593 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
// Streams 1 or 0 into formatter for every feature index from 0 up to
// num_feature_ - 1, according to is_categorical, and returns the resulting
// string.
// NOTE(review): is_categorical is indexed up to num_feature_ without a visible
// size check; confirm it always holds at least num_feature_ entries. The
// return type and the declaration of formatter are not visible in this excerpt.
602 RenderIsCategoricalArray(
const std::vector<bool>& is_categorical) {
604 for (
int fid = 0; fid < num_feature_; ++fid) {
605 formatter << (is_categorical[fid] ? 1 : 0);
607 return formatter.
str();
// Renders the C statement(s) that add a leaf output to the accumulator.
// With more than one output group: a vector leaf adds one value to every slot
// of sum, while a scalar leaf adds to the single slot tree_id modulo
// num_output_group_. With one output group the scalar is added to sum.
// Returns the generated statement text.
610 inline std::string RenderOutputStatement(
const OutputNode* node) {
611 std::string output_statement;
612 if (num_output_group_ > 1) {
613 if (node->is_vector) {
// A vector leaf must supply exactly one value per output group.
615 CHECK_EQ(node->vector.size(),
static_cast<size_t>(num_output_group_))
616 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
617 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
619 += fmt::format(
"sum[{group_id}] += (float){output};\n",
620 "group_id"_a = group_id,
621 "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]));
// Scalar leaf in the multiclass case: the target slot is derived from the tree id.
626 = fmt::format(
"sum[{group_id}] += (float){output};\n",
627 "group_id"_a = node->tree_id % num_output_group_,
628 "output"_a = common_util::ToStringHighPrecision(node->scalar));
// Single output group.
632 = fmt::format(
"sum += (float){output};\n",
633 "output"_a = common_util::ToStringHighPrecision(node->scalar));
635 return output_statement;
640 .describe(
"AST-based compiler that produces C code")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Parameters for tree compiler.
bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs)
perform comparison between two floats using a comparison operator. The comparison will be in the form...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::string OpName(Operator op)
get string representation of comparison operator
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
ModelParam param
extra parameters
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
float global_bias
global bias of the model
Interface of compiler that compiles a tree ensemble model.
template for main function
template for evaluation logic for folded code
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
code template for QuantizerNode
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Utilities for code folding.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Function to generate bitmaps for categorical splits.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Operator
comparison operators