3 #include <unordered_map> 5 #include "./pred_transform.h" 6 #include "./ast/builder.h" 7 #include "./native/get_num_feature.h" 8 #include "./native/get_num_output_group.h" 13 DMLC_REGISTRY_FILE_TAG(ast_native);
20 LOG(INFO) <<
"Using ASTNativeCompiler";
26 cm.backend =
"native";
27 cm.files[
"main.c"] =
"";
30 pred_tranform_func_ = PredTransformFunction(
"native", model);
37 std::unique_ptr<dmlc::Stream> fi(
38 dmlc::Stream::Create(param.
annotate_in.c_str(),
"r"));
39 annotator.
Load(fi.get());
40 const auto annotation = annotator.
Get();
41 builder.AnnotateBranches(annotation);
42 LOG(INFO) <<
"Using branch annotation file `" 47 builder.QuantizeThresholds();
49 #include "./native/header.h" 50 files_[
"header.h"] = header;
51 WalkAST(builder.GetRootNode(),
"main.c", 0);
55 std::vector<std::unordered_map<std::string, std::string>> source_list;
56 for (
auto kv : files_) {
57 if (kv.first.compare(kv.first.length() - 2, 2,
".c") == 0) {
58 const size_t line_count
59 = std::count(kv.second.begin(), kv.second.end(),
'\n');
60 source_list.push_back({ {
"name",
61 kv.first.substr(0, kv.first.length() - 2)},
62 {
"length", std::to_string(line_count)} });
65 std::ostringstream oss;
66 auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
67 writer->BeginObject();
68 writer->WriteObjectKeyValue(
"target", std::string(
"predictor"));
69 writer->WriteObjectKeyValue(
"sources", source_list);
71 files_[
"recipe.json"] = oss.str();
73 cm.files = std::move(files_);
78 int num_output_group_;
79 std::string pred_tranform_func_;
80 std::unordered_map<std::string, std::string> files_;
82 void WalkAST(
const ASTNode* node,
83 const std::string& dest,
91 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
92 HandleMainNode(t1, dest, indent);
93 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
94 HandleACNode(t2, dest, indent);
95 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
96 HandleCondNode(t3, dest, indent);
97 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
98 HandleOutputNode(t4, dest, indent);
99 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
100 HandleTUNode(t5, dest, indent);
101 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
102 HandleQNode(t6, dest, indent);
104 LOG(FATAL) <<
"oops";
108 void HandleMainNode(
const MainNode* node,
109 const std::string& dest,
111 const std::string prototype
112 = (num_output_group_ > 1) ?
113 "size_t predict_multiclass(union Entry* data, int pred_margin, " 115 :
"float predict(union Entry* data, int pred_margin)";
116 files_[dest] += std::string(indent,
' ') +
"#include \"header.h\"\n\n";
117 files_[dest] += get_num_output_group_func(num_output_group_) +
"\n" 118 + get_num_feature_func(node->num_feature) +
"\n" 119 + pred_tranform_func_ +
"\n" 120 + std::string(indent,
' ') + prototype +
" {\n";
121 files_[
"header.h"] += get_num_output_group_func_prototype();
122 files_[
"header.h"] += get_num_feature_func_prototype();
123 files_[
"header.h"] += prototype +
";\n";
124 CHECK_EQ(node->children.size(), 1);
125 WalkAST(node->children[0], dest, indent + 2);
126 std::ostringstream oss;
127 if (num_output_group_ > 1) {
128 oss << std::string(indent + 2,
' ')
129 <<
"for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 130 << std::string(indent + 4,
' ') <<
"result[i] = sum[i]";
131 if (node->average_result) {
132 oss <<
" / " << node->num_tree;
134 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 135 << std::string(indent + 2,
' ') <<
"}\n" 136 << std::string(indent + 2,
' ') <<
"if (!pred_margin) {\n" 137 << std::string(indent + 2,
' ')
138 <<
" return pred_transform(result);\n" 139 << std::string(indent + 2,
' ') <<
"} else {\n" 140 << std::string(indent + 2,
' ')
141 <<
" return " << num_output_group_ <<
";\n" 142 << std::string(indent + 2,
' ') <<
"}\n";
144 oss << std::string(indent + 2,
' ') <<
"sum = sum";
145 if (node->average_result) {
146 oss <<
" / " << node->num_tree;
148 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 149 << std::string(indent + 2,
' ') <<
"if (!pred_margin) {\n" 150 << std::string(indent + 2,
' ') <<
" return pred_transform(sum);\n" 151 << std::string(indent + 2,
' ') <<
"} else {\n" 152 << std::string(indent + 2,
' ') <<
" return sum;\n" 153 << std::string(indent + 2,
' ') <<
"}\n";
155 oss << std::string(indent,
' ') <<
"}\n";
156 files_[dest] += oss.str();
160 const std::string& dest,
162 std::ostringstream oss;
163 if (num_output_group_ > 1) {
164 oss << std::string(indent,
' ')
165 <<
"float sum[" << num_output_group_ <<
"] = {0.0f};\n";
167 oss << std::string(indent,
' ') <<
"float sum = 0.0f;\n";
169 oss << std::string(indent,
' ') <<
"unsigned int tmp;\n";
170 files_[dest] += oss.str();
171 for (
ASTNode* child : node->children) {
172 WalkAST(child, dest, indent);
177 const std::string& dest,
179 const unsigned split_index = node->split_index;
180 const std::string na_check
181 = std::string(
"data[") + std::to_string(split_index) +
"].missing != -1";
184 std::ostringstream oss;
185 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
187 oss <<
"data[" << split_index <<
"].qvalue " 188 << OpName(t->op) <<
" " << t->threshold.int_val;
191 const std::streamsize ss = std::cout.precision();
192 oss <<
"data[" << split_index <<
"].fvalue " 193 << OpName(t->op) <<
" " 194 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
195 << t->threshold.float_val << std::setprecision(ss);
201 std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
202 CHECK_GE(bitmap.size(), 1);
203 bool all_zeros =
true;
204 for (uint64_t e : bitmap) {
205 all_zeros &= (e == 0);
210 oss <<
"(tmp = (unsigned int)(data[" << split_index <<
"].fvalue) ), " 211 <<
"(tmp >= 0 && tmp < 64 && (( (uint64_t)" 212 << bitmap[0] <<
"U >> tmp) & 1) )";
213 for (
size_t i = 1; i < bitmap.size(); ++i) {
214 oss <<
" || (tmp >= " << (i * 64)
215 <<
" && tmp < " << ((i + 1) * 64)
216 <<
" && (( (uint64_t)" << bitmap[i]
217 <<
"U >> (tmp - " << (i * 64) <<
") ) & 1) )";
222 = ((node->default_left) ? (std::string(
"!(") + na_check +
") || (")
223 : (std::string(
" (") + na_check +
") && ("))
225 switch (node->branch_hint) {
226 case BranchHint::kLikely:
227 cond = std::string(
" LIKELY( ") + cond +
" ) ";
229 case BranchHint::kUnlikely:
230 cond = std::string(
" UNLIKELY( ") + cond +
" ) ";
233 files_[dest] += std::string(indent,
' ') +
"if (" + cond +
") {\n";
234 CHECK_EQ(node->children.size(), 2);
235 WalkAST(node->children[0], dest, indent + 2);
236 files_[dest] += std::string(indent,
' ') +
"} else {\n";
237 WalkAST(node->children[1], dest, indent + 2);
238 files_[dest] += std::string(indent,
' ') +
"}\n";
242 const std::string& dest,
244 std::ostringstream oss;
245 if (num_output_group_ > 1) {
246 if (node->is_vector) {
248 const std::vector<tl_float>& leaf_vector = node->vector;
249 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group_))
250 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
251 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
252 oss << std::string(indent,
' ')
253 <<
"sum[" << group_id <<
"] += (float)" 254 << common::ToString(leaf_vector[group_id]) <<
";\n";
258 oss << std::string(indent,
' ') <<
"sum[" 259 << (node->tree_id % num_output_group_) <<
"] += (float)" 260 << common::ToString(node->scalar) <<
";\n";
263 oss << std::string(indent,
' ') <<
"sum += (float)" 264 << common::ToString(node->scalar) <<
";\n";
266 files_[dest] += oss.str();
267 CHECK_EQ(node->children.size(), 0);
271 const std::string& dest,
273 const std::string new_file
274 = std::string(
"tu") + std::to_string(node->unit_id) +
".c";
275 std::ostringstream caller_buf, callee_buf, func_name, prototype;
276 callee_buf <<
"#include \"header.h\"\n";
277 if (num_output_group_ > 1) {
278 func_name <<
"predict_margin_multiclass_unit" << node->unit_id;
279 caller_buf << std::string(indent,
' ')
280 << func_name.str() <<
"(data, sum);\n";
281 prototype <<
"void " << func_name.str()
282 <<
"(union Entry* data, float* result)";
284 func_name <<
"predict_margin_unit" << node->unit_id;
285 caller_buf << std::string(indent,
' ')
286 <<
"sum += " << func_name.str() <<
"(data);\n";
287 prototype <<
"float " << func_name.str() <<
"(union Entry* data)";
289 callee_buf << prototype.str() <<
" {\n";
290 files_[dest] += caller_buf.str();
291 files_[new_file] += callee_buf.str();
292 CHECK_EQ(node->children.size(), 1);
293 WalkAST(node->children[0], new_file, 2);
294 callee_buf.str(
""); callee_buf.clear();
295 if (num_output_group_ > 1) {
296 callee_buf <<
" for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 297 <<
" result[i] = sum[i];\n" 300 callee_buf <<
" return sum;\n";
303 files_[new_file] += callee_buf.str();
304 files_[
"header.h"] += prototype.str() +
";\n";
308 const std::string& dest,
310 std::ostringstream oss;
311 const int num_feature = node->is_categorical.size();
313 oss <<
"static const unsigned char is_categorical[] = {\n ";
314 for (
int fid = 0; fid < num_feature; ++fid) {
315 if (node->is_categorical[fid]) {
316 common::WrapText(&oss, &length,
"1, ", 2, 80);
318 common::WrapText(&oss, &length,
"0, ", 2, 80);
323 oss <<
"static const float threshold[] = {\n ";
324 for (
const auto& e : node->cut_pts) {
326 common::WrapText(&oss, &length, common::ToString(v) +
", ", 2, 80);
332 oss <<
"static const int th_begin[] = {\n ";
333 for (
const auto& e : node->cut_pts) {
334 common::WrapText(&oss, &length, std::to_string(accum) +
", ", 2, 80);
339 oss <<
"static const int th_len[] = {\n ";
340 for (
const auto& e : node->cut_pts) {
341 common::WrapText(&oss, &length, std::to_string(e.size()) +
", ", 2, 80);
344 #include "./native/quantize_func.h" 345 oss << quantize_func << files_[dest] << std::string(indent,
' ')
346 <<
"for (int i = 0; i < " << num_feature <<
"; ++i) {\n" 347 << std::string(indent + 2,
' ')
348 <<
"if (data[i].missing != -1 && !is_categorical[i]) {\n" 349 << std::string(indent + 4,
' ')
350 <<
"data[i].qvalue = quantize(data[i].fvalue, i);\n" 351 << std::string(indent + 2,
' ') +
"}\n" 352 << std::string(indent,
' ') +
"}\n";
353 files_[dest] = oss.str();
354 CHECK_EQ(node->children.size(), 1);
355 WalkAST(node->children[0], dest, indent);
358 inline std::vector<uint64_t>
359 to_bitmap(
const std::vector<uint32_t>& left_categories)
const {
360 const size_t num_left_categories = left_categories.size();
361 const uint32_t max_left_category = left_categories[num_left_categories - 1];
362 std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
363 for (
size_t i = 0; i < left_categories.size(); ++i) {
364 const uint32_t cat = left_categories[i];
365 const size_t idx = cat / 64;
366 const uint32_t offset = cat % 64;
367 bitmap[idx] |= (
static_cast<uint64_t
>(1) << offset);
374 .describe(
"AST-based compiler that produces C code")
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
parameters for tree compiler
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::vector< std::vector< size_t > > Get() const
fetch branch annotation. Usage example:
void Load(dmlc::Stream *fi)
load branch annotation from a JSON file
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evely distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages