// Dependencies of the AST-based native (C code) compiler, followed by the
// dmlc registry file tag for this translation unit (tag name: ast_native).
// NOTE(review): this listing appears to be extracted from a line-numbered
// source view; the leading integers are original line numbers, not code, and
// some original lines are missing.
4 #include <unordered_map> 7 #include "./pred_transform.h" 8 #include "./ast/builder.h" 9 #include "./native/get_num_feature.h" 10 #include "./native/get_num_output_group.h" 15 DMLC_REGISTRY_FILE_TAG(ast_native);
// Informational log line announcing that this compiler is in use.
// NOTE(review): presumably part of the ASTNativeCompiler constructor; the
// enclosing signature, any verbosity guard, and the closing brace are not
// present in this listing (original lines before 22 and after it are absent).
22 LOG(INFO) <<
"Using ASTNativeCompiler";
// Compile(): convert a tree ensemble model into generated C sources.
// Visible steps, in order:
//   1. mark the result as the "native" backend and seed an empty main.c;
//   2. build the prediction-transform function text for this model;
//   3. load a branch-annotation file (param.annotate_in) into the builder;
//   4. quantize thresholds;
//   5. emit header.h, then walk the AST into main.c starting at indent 0
//      (translation-unit nodes may add tu<N>.c files along the way);
//   6. write recipe.json describing the generated .c files.
// NOTE(review): several original lines are absent from this listing
// (e.g. 30-31, 33-38, 45-48, 50, 54-56); guards such as an "annotation file
// given" or "quantize enabled" check may be among them - confirm against the
// full file.
28 cm.backend =
"native";
29 cm.files[
"main.c"] =
"";
32 pred_tranform_func_ = PredTransformFunction(
"native", model);
// Read the branch annotation and attach it to the AST; presumably this sets
// the branch_hint (LIKELY/UNLIKELY) that the condition-node emitter reads.
39 std::unique_ptr<dmlc::Stream> fi(
40 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
41 annotator.
Load(fi.get());
42 const auto annotation = annotator.
Get();
43 builder.AnnotateBranches(annotation);
44 LOG(INFO) <<
"Using branch annotation file `" 49 builder.QuantizeThresholds();
// The #include below presumably defines the `header` string stored as
// header.h; the AST root is then emitted into main.c at indentation 0.
51 #include "./native/header.h" 52 files_[
"header.h"] = header;
53 WalkAST(builder.GetRootNode(),
"main.c", 0);
// recipe.json: one {"name", "length"} entry per generated .c file, where
// "name" is the file name without ".c" and "length" is its newline count.
// NOTE(review): `auto kv` copies every (file name, file contents) pair;
// `const auto& kv` would avoid copying the generated source text. The
// `length() - 2` arithmetic would underflow (and std::string::compare would
// throw std::out_of_range) for a name shorter than two characters; all file
// names produced in this listing are longer.
57 std::vector<std::unordered_map<std::string, std::string>> source_list;
58 for (
auto kv : files_) {
59 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
60 const size_t line_count
61 = std::count(kv.second.begin(), kv.second.end(),
'\n');
62 source_list.push_back({ {
"name",
63 kv.first.substr(0, kv.first.length() - 2)},
64 {
"length", std::to_string(line_count)} });
67 std::ostringstream oss;
68 auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
69 writer->BeginObject();
70 writer->WriteObjectKeyValue(
"target", std::string(
"predictor"));
71 writer->WriteObjectKeyValue(
"sources", source_list);
// NOTE(review): original line 72 is absent; presumably writer->EndObject().
// If it were truly missing, the JSON object in recipe.json would never be
// closed.
73 files_[
"recipe.json"] = oss.str();
// Hand all generated files to the result; files_ is left moved-from.
75 cm.files = std::move(files_);
// Number of output groups; a value > 1 selects the multiclass code paths.
80 int num_output_group_;
// Source text of the prediction-transform function; appended to the file
// that receives the main prediction function.
81 std::string pred_tranform_func_;
// Generated files keyed by file name (e.g. "main.c", "header.h", "tu<N>.c",
// "recipe.json"), each mapped to its full text.
82 std::unordered_map<std::string, std::string> files_;
// WalkAST: dispatch `node` to the emitter for its concrete AST type
// (MainNode, AccumulatorContextNode, ConditionNode, OutputNode,
// TranslationUnitNode, QuantizerNode). The emitters append C code to
// files_[dest], indenting by `indent` spaces.
// NOTE(review): original lines 86-92 (the `indent` parameter and the t1..t6
// declarations), 105 and 107+ are absent from this listing.
84 void WalkAST(
const ASTNode* node,
85 const std::string& dest,
93 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
94 HandleMainNode(t1, dest, indent);
95 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
96 HandleACNode(t2, dest, indent);
97 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
98 HandleCondNode(t3, dest, indent);
99 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
100 HandleOutputNode(t4, dest, indent);
101 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
102 HandleTUNode(t5, dest, indent);
103 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
104 HandleQNode(t6, dest, indent);
// Presumably the final `else`: no known node type matched.
// NOTE(review): the message does not name the offending node type; a more
// descriptive message would help debugging.
106 LOG(FATAL) <<
"oops";
// HandleMainNode: emit the top-level prediction function into files_[dest].
// Output, in order: #include "header.h"; the num-output-group and
// num-feature getter functions; the prediction-transform function; then the
// predict / predict_multiclass function. Its body is the single child node
// (emitted at indent + 2), followed by an epilogue that adds the global bias
// (dividing by num_tree first when average_result is set) and returns either
// the transformed or the raw result. Matching prototypes are appended to
// header.h.
// NOTE(review): original lines 112, 116, 135, 145, 149 and 156 are absent
// from this listing (e.g. the `indent` parameter and the rest of the
// multiclass prototype, presumably an output array named `result`).
110 void HandleMainNode(
const MainNode* node,
111 const std::string& dest,
113 const std::string prototype
114 = (num_output_group_ > 1) ?
115 "size_t predict_multiclass(union Entry* data, int pred_margin, " 117 :
"float predict(union Entry* data, int pred_margin)";
// Preamble: header include, helper getters, transform function, and the
// opening line of the prediction function.
118 files_[dest] += std::string(indent,
' ') +
"#include \"header.h\"\n\n";
119 files_[dest] += get_num_output_group_func(num_output_group_) +
"\n" 120 + get_num_feature_func(node->num_feature) +
"\n" 121 + pred_tranform_func_ +
"\n" 122 + std::string(indent,
' ') + prototype +
" {\n";
// Declare the same functions in header.h.
123 files_[
"header.h"] += get_num_output_group_func_prototype();
124 files_[
"header.h"] += get_num_feature_func_prototype();
125 files_[
"header.h"] += prototype +
";\n";
// The function body is the single child subtree.
126 CHECK_EQ(node->children.size(), 1);
127 WalkAST(node->children[0], dest, indent + 2);
// Multiclass epilogue: result[i] = sum[i] (optionally / num_tree) + bias for
// each group. Returns pred_transform(result) unless pred_margin is set; in
// that case the number of output groups is returned.
128 std::ostringstream oss;
129 if (num_output_group_ > 1) {
130 oss << std::string(indent + 2,
' ')
131 <<
"for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 132 << std::string(indent + 4,
' ') <<
"result[i] = sum[i]";
133 if (node->average_result) {
134 oss <<
" / " << node->num_tree;
136 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 137 << std::string(indent + 2,
' ') <<
"}\n" 138 << std::string(indent + 2,
' ') <<
"if (!pred_margin) {\n" 139 << std::string(indent + 2,
' ')
140 <<
" return pred_transform(result);\n" 141 << std::string(indent + 2,
' ') <<
"} else {\n" 142 << std::string(indent + 2,
' ')
143 <<
" return " << num_output_group_ <<
";\n" 144 << std::string(indent + 2,
' ') <<
"}\n";
// Single-output epilogue (presumably the `else` branch): finalize sum the
// same way, then return pred_transform(sum), or the raw sum when pred_margin
// is set.
146 oss << std::string(indent + 2,
' ') <<
"sum = sum";
147 if (node->average_result) {
148 oss <<
" / " << node->num_tree;
150 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 151 << std::string(indent + 2,
' ') <<
"if (!pred_margin) {\n" 152 << std::string(indent + 2,
' ') <<
" return pred_transform(sum);\n" 153 << std::string(indent + 2,
' ') <<
"} else {\n" 154 << std::string(indent + 2,
' ') <<
" return sum;\n" 155 << std::string(indent + 2,
' ') <<
"}\n";
// Close the prediction function and append the epilogue to the file.
157 oss << std::string(indent,
' ') <<
"}\n";
158 files_[dest] += oss.str();
// Accumulator-context emitter (presumably HandleACNode; its signature line,
// original line 161, is not in this listing). It declares the running sum:
// a zero-initialized float array of num_output_group_ entries for multiclass
// models, otherwise a scalar float. It also declares an `unsigned int tmp`
// scratch variable, which the categorical split tests assign to. Every child
// is then emitted at the same indentation.
// NOTE(review): original lines 163, 168, 170 and 175+ are absent here.
162 const std::string& dest,
164 std::ostringstream oss;
165 if (num_output_group_ > 1) {
166 oss << std::string(indent,
' ')
167 <<
"float sum[" << num_output_group_ <<
"] = {0.0f};\n";
169 oss << std::string(indent,
' ') <<
"float sum = 0.0f;\n";
171 oss << std::string(indent,
' ') <<
"unsigned int tmp;\n";
172 files_[dest] += oss.str();
173 for (
ASTNode* child : node->children) {
174 WalkAST(child, dest, indent);
// Condition-node emitter (presumably HandleCondNode; its signature line,
// original line 178, is not in this listing). It emits
//   if (<test>) { <children[0]> } else { <children[1]> }
// where <test> combines a missing-value check with either a numerical
// comparison or a categorical membership test. The test is optionally
// wrapped in LIKELY()/UNLIKELY() according to the node's branch hint.
// NOTE(review): many original lines are absent from this listing (e.g. 180,
// 184-185, 188, 192-193, 195-197, 203-207, 213-216, 225-228, 231, 235,
// 238-239); the comments below describe only what is visible.
179 const std::string& dest,
181 const unsigned split_index = node->split_index;
// Feature-present test for this split's feature; presumably `missing == -1`
// marks an absent value.
182 const std::string na_check
183 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
186 std::ostringstream oss;
// Numerical split. Visible cases:
//   - compare data[i].qvalue with the integer threshold (presumably the
//     quantized path);
//   - if the float threshold is infinite, evaluate the comparison at compile
//     time via CompareWithOp;
//   - otherwise compare data[i].fvalue with the threshold printed at
//     digits10 + 2 significant digits.
187 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
189 oss <<
"data[" << split_index <<
"].qvalue " 190 << OpName(t->op) <<
" " << t->threshold.int_val;
191 }
else if (std::isinf(t->threshold.float_val)) {
194 oss << (common::CompareWithOp(0.0, t->op, t->threshold.float_val)
// NOTE(review): the precision saved and restored onto `oss` is read from
// std::cout, not from `oss`; oss.precision() looks intended. The two
// coincide only while std::cout keeps its default precision.
// NOTE(review): with tl_float = float, digits10 + 2 is 8, which is below
// std::numeric_limits<float>::max_digits10 (9). Some thresholds may
// therefore not round-trip exactly.
198 const std::streamsize ss = std::cout.precision();
199 oss <<
"data[" << split_index <<
"].fvalue " 200 << OpName(t->op) <<
" " 201 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
202 << t->threshold.float_val << std::setprecision(ss);
// Categorical split: pack the left-child categories into 64-bit words.
// all_zeros is true when no category goes left; presumably it selects a
// constant test on the absent lines 213-216.
208 std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
209 CHECK_GE(bitmap.size(), 1);
210 bool all_zeros =
true;
211 for (uint64_t e : bitmap) {
212 all_zeros &= (e == 0);
// Membership test: convert the feature value to unsigned int (tmp), then
// check bit (tmp - 64*i) of word i for tmp in [64*i, 64*(i+1)).
// NOTE(review): in the generated C, `tmp >= 0` is always true because tmp
// is unsigned. Converting a negative or out-of-range fvalue to unsigned int
// is undefined behavior in C.
217 oss <<
"(tmp = (unsigned int)(data[" << split_index <<
"].fvalue) ), " 218 <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 219 << bitmap[0] <<
"U >> tmp) & 1) )";
220 for (
size_t i = 1; i < bitmap.size(); ++i) {
221 oss <<
" || (tmp >= " << (i * 64)
222 <<
" && tmp < " << ((i + 1) * 64)
223 <<
" && (( (uint64_t)" << bitmap[i]
224 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
// Combine with the missing-value check. With default_left, a missing value
// satisfies the test and takes the first branch; otherwise the feature must
// be present for the test to hold. The `cond` declaration (line 228) and the
// rest of this expression (line 231) are absent.
229 = ((node->default_left) ? (std::string(
"!(") + na_check +
") || (")
230 : (std::string(
" (") + na_check +
") && ("))
// Wrap the test in LIKELY()/UNLIKELY() per the branch hint.
// NOTE(review): original line 235 between the two cases is absent. If it is
// not a `break`, kLikely falls through and the test is wrapped twice.
232 switch (node->branch_hint) {
233 case BranchHint::kLikely:
234 cond = std::string(
" LIKELY( ") + cond +
" ) ";
236 case BranchHint::kUnlikely:
237 cond = std::string(
" UNLIKELY( ") + cond +
" ) ";
// Emit the if/else, with both child subtrees at indent + 2.
240 files_[dest] += std::string(indent,
' ') +
"if (" + cond +
") {\n";
241 CHECK_EQ(node->children.size(), 2);
242 WalkAST(node->children[0], dest, indent + 2);
243 files_[dest] += std::string(indent,
' ') +
"} else {\n";
244 WalkAST(node->children[1], dest, indent + 2);
245 files_[dest] += std::string(indent,
' ') +
"}\n";
// Leaf (output) emitter (presumably HandleOutputNode; its signature line,
// original line 248, is not in this listing). It adds the leaf value(s) to
// the running `sum`, presumably declared by an enclosing accumulator context:
//   - multiclass, vector leaf: sum[g] += leaf_vector[g] for every group g
//     (the vector length must equal num_output_group_);
//   - multiclass, scalar leaf: sum[tree_id % num_output_group_] += scalar;
//   - single output: sum += scalar.
// A leaf must have no children.
// NOTE(review): original lines 250, 254, 262-264, 268-269, 272 and 275+ are
// absent from this listing.
249 const std::string& dest,
251 std::ostringstream oss;
252 if (num_output_group_ > 1) {
// Vector leaf: one contribution per output group.
253 if (node->is_vector) {
255 const std::vector<tl_float>& leaf_vector = node->vector;
256 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group_))
257 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
258 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
259 oss << std::string(indent,
' ')
260 <<
"sum[" << group_id <<
"] += (float)" 261 << common::ToString(leaf_vector[group_id]) <<
";\n";
// Scalar leaf in a multiclass model: the tree id selects the group.
265 oss << std::string(indent,
' ') <<
"sum[" 266 << (node->tree_id % num_output_group_) <<
"] += (float)" 267 << common::ToString(node->scalar) <<
";\n";
// Single-output model.
270 oss << std::string(indent,
' ') <<
"sum += (float)" 271 << common::ToString(node->scalar) <<
";\n";
273 files_[dest] += oss.str();
274 CHECK_EQ(node->children.size(), 0);
// Translation-unit emitter (presumably HandleTUNode; its signature line,
// original line 277, is not in this listing). It moves a subtree into its
// own file, tu<unit_id>.c:
//   - the current file (dest) gets a call to the unit function;
//   - the new file gets #include "header.h", the function definition whose
//     body is the child subtree (at indent 2), and an epilogue;
//   - the unit's prototype is appended to header.h.
// Function shapes:
//   - multiclass: void predict_margin_multiclass_unit<N>(union Entry* data,
//     float* result), called as (data, sum);
//   - single output: float predict_margin_unit<N>(union Entry* data), called
//     as `sum += ...(data)`.
// NOTE(review): original lines 279, 290, 295, 305-306, 308-309 and 312+ are
// absent from this listing.
278 const std::string& dest,
280 const std::string new_file
281 = std::string(
"tu") + std::to_string(node->unit_id) +
".c";
282 std::ostringstream caller_buf, callee_buf, func_name, prototype;
283 callee_buf <<
"#include \"header.h\"\n";
// Multiclass: the unit writes into the caller's accumulator array.
284 if (num_output_group_ > 1) {
285 func_name <<
"predict_margin_multiclass_unit" << node->unit_id;
286 caller_buf << std::string(indent,
' ')
287 << func_name.str() <<
"(data, sum);\n";
288 prototype <<
"void " << func_name.str()
289 <<
"(union Entry* data, float* result)";
// Single output (presumably the `else` branch): the unit returns its
// partial sum, and the caller accumulates it.
291 func_name <<
"predict_margin_unit" << node->unit_id;
292 caller_buf << std::string(indent,
' ')
293 <<
"sum += " << func_name.str() <<
"(data);\n";
294 prototype <<
"float " << func_name.str() <<
"(union Entry* data)";
// Emit the call site and the opening of the unit function, then its body.
296 callee_buf << prototype.str() <<
" {\n";
297 files_[dest] += caller_buf.str();
298 files_[new_file] += callee_buf.str();
299 CHECK_EQ(node->children.size(), 1);
300 WalkAST(node->children[0], new_file, 2);
// Epilogue: reset callee_buf and emit the end of the unit function.
// Multiclass copies the unit's sum[] into result[]; single output returns
// sum.
// NOTE(review): the caller passes its own accumulator as `result`, and the
// emitted `result[i] = sum[i]` overwrites it. The single-output path
// accumulates (`sum += ...`), but with several translation units the
// multiclass path appears to keep only the last unit's contribution.
// `result[i] += sum[i]` looks intended - confirm.
301 callee_buf.str(
""); callee_buf.clear();
302 if (num_output_group_ > 1) {
303 callee_buf <<
" for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 304 <<
" result[i] = sum[i];\n" 307 callee_buf <<
" return sum;\n";
310 files_[new_file] += callee_buf.str();
// Declare the unit function in header.h so the caller's file can use it.
311 files_[
"header.h"] += prototype.str() +
";\n";
// Quantizer emitter (presumably HandleQNode; its signature line, original
// line 314, is not in this listing). It builds static lookup tables
// (is_categorical, threshold, th_begin, th_len) from node->is_categorical
// and node->cut_pts, wrapping table text at 80 columns. It then rewrites
// files_[dest] as:
//   tables + quantize_func + previous contents of dest + a loop.
// The loop stores quantize(data[i].fvalue, i) into data[i].qvalue for every
// present, non-categorical feature. Entry is a union, so this presumably
// overwrites the float value in place. The single child is then emitted at
// the same indentation.
// NOTE(review): original lines 316, 319, 324, 326-329, 332, 334-338,
// 342-345 and 349-350 are absent; `length`, `v` and `accum` are declared
// there.
315 const std::string& dest,
317 std::ostringstream oss;
// NOTE(review): size() (size_t) is narrowed to int here.
318 const int num_feature = node->is_categorical.size();
// is_categorical[]: 1 for categorical features, 0 otherwise.
320 oss <<
"static const unsigned char is_categorical[] = {\n ";
321 for (
int fid = 0; fid < num_feature; ++fid) {
322 if (node->is_categorical[fid]) {
323 common::WrapText(&oss, &length,
"1, ", 2, 80);
325 common::WrapText(&oss, &length,
"0, ", 2, 80);
// threshold[]: cut points of all features, concatenated.
330 oss <<
"static const float threshold[] = {\n ";
331 for (
const auto& e : node->cut_pts) {
333 common::WrapText(&oss, &length, common::ToString(v) +
", ", 2, 80);
// th_begin[]: presumably each feature's starting offset into threshold[].
339 oss <<
"static const int th_begin[] = {\n ";
340 for (
const auto& e : node->cut_pts) {
341 common::WrapText(&oss, &length, std::to_string(accum) +
", ", 2, 80);
// th_len[]: number of cut points per feature.
346 oss <<
"static const int th_len[] = {\n ";
347 for (
const auto& e : node->cut_pts) {
348 common::WrapText(&oss, &length, std::to_string(e.size()) +
", ", 2, 80);
// The #include below presumably defines `quantize_func`. The tables and the
// function are placed before the existing contents of dest; the quantization
// loop is placed after them.
351 #include "./native/quantize_func.h" 352 oss << quantize_func << files_[dest] << std::string(indent,
' ')
353 <<
"for (int i = 0; i < " << num_feature <<
"; ++i) {\n" 354 << std::string(indent + 2,
' ')
355 <<
"if (data[i].missing != -1 && !is_categorical[i]) {\n" 356 << std::string(indent + 4,
' ')
357 <<
"data[i].qvalue = quantize(data[i].fvalue, i);\n" 358 << std::string(indent + 2,
' ') +
"}\n" 359 << std::string(indent,
' ') +
"}\n";
360 files_[dest] = oss.str();
361 CHECK_EQ(node->children.size(), 1);
362 WalkAST(node->children[0], dest, indent);
// to_bitmap: pack a list of category ids into 64-bit words. Bit (c % 64) of
// word (c / 64) is set for every category c. An empty list yields a single
// zero word, so the result always has at least one word.
// NOTE(review): the bitmap is sized from the LAST element, so the input must
// be sorted in ascending order. An unsorted list whose maximum is not last
// would index past the end of `bitmap`. Also, `max_left_category + 1 + 63`
// is computed in uint32_t and would wrap for ids near UINT32_MAX.
// NOTE(review): original lines 370 and 378+ (closing braces and,
// presumably, `return bitmap;`) are absent from this listing.
365 inline std::vector<uint64_t>
366 to_bitmap(
const std::vector<uint32_t>& left_categories)
const {
367 const size_t num_left_categories = left_categories.size();
368 if (num_left_categories == 0) {
369 return std::vector<uint64_t>{0};
371 const uint32_t max_left_category = left_categories[num_left_categories - 1];
372 std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
373 for (
size_t i = 0; i < num_left_categories; ++i) {
374 const uint32_t cat = left_categories[i];
375 const size_t idx = cat / 64;
376 const uint32_t offset = cat % 64;
377 bitmap[idx] |= (
static_cast<uint64_t
>(1) << offset);
// Tail of the compiler registration: attaches a human-readable description
// to the registered compiler. The preceding registration line (presumably
// TREELITE_REGISTER_COMPILER(...)) is absent from this listing.
384 .describe(
"AST-based compiler that produces C code")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
parameters for tree compiler
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages