2 #include <unordered_map> 4 #include "./pred_transform.h" 5 #include "./ast/builder.h" 6 #include "./java/get_num_feature.h" 7 #include "./java/get_num_output_group.h" 12 DMLC_REGISTRY_FILE_TAG(ast_java);
// Compile() body (fragment; several original lines are not visible here):
// logs the compiler choice, warns that branch annotation is unsupported,
// prepares the AST builder, emits Entry.java and pom.xml, walks the AST to
// generate Main.java (plus TU*.java units), and hands the generated file map
// and path prefix to the returned CompiledModel `cm`.
19 LOG(INFO) <<
"Using ASTJavaCompiler";
// NOTE(review): this "Main.java" entry appears to be discarded by the
// `cm.files = std::move(files_)` assignment at the end of this body, which
// replaces the whole map -- confirm whether it is still needed.
26 cm.files[
"Main.java"] =
"";
// Branch annotation is not implemented for the Java backend, so the
// `annotate_in` parameter is ignored (the guarding condition is not visible).
29 LOG(WARNING) <<
"\033[1;31mWARNING: Java backend does not support " 30 <<
"branch annotation.\u001B[0m The parameter " 31 <<
"\x1B[33mannotate_in\u001B[0m will be ignored.";
// Java source of the pred_transform function; HandleMainNode splices it
// into Main.java.
34 pred_tranform_func_ = PredTransformFunction(
"java", model);
// The condition guarding this call is not visible here (presumably
// param.quantize) -- confirm.
41 builder.QuantizeThresholds();
43 builder.CountDescendant();
// Presumably splits the ensemble into translation units bounded by
// param.max_unit_size; each TranslationUnitNode is emitted as its own
// TU<id>.java by HandleTUNode.
44 builder.BreakUpLargeUnits(param.max_unit_size);
// The headers included below presumably define the string constants
// `entry_type` and `pom_xml`. Entry.java lives under file_prefix_, while
// pom.xml is stored at the project root.
45 #include "java/entry_type.h" 46 #include "java/pom_xml.h" 47 files_[file_prefix_ +
"Entry.java"] = entry_type;
48 files_[
"pom.xml"] = pom_xml;
// Generate Main.java (and any TU*.java files) starting at indent 0.
49 WalkAST(builder.GetRootNode(),
"Main.java", 0);
// Transfer all generated files to the result; files_ is left moved-from.
51 cm.files = std::move(files_);
52 cm.file_prefix = file_prefix_;
// Number of output groups; values > 1 select the multiclass code paths
// (float[] sum, predict_multiclass).
57 int num_output_group_;
// Java source of pred_transform(), produced by PredTransformFunction("java", ...).
58 std::string pred_tranform_func_;
// Generated files: path -> contents. CommitToFile prefixes paths with
// file_prefix_; pom.xml is stored without the prefix.
59 std::unordered_map<std::string, std::string> files_;
// Extra Main.java content accumulated by HandleQNode (quantization tables
// and the quantize function).
60 std::string main_tail_;
// Directory of the generated Java sources, matching package treelite.predictor.
61 const std::string file_prefix_ =
"src/main/java/treelite/predictor/";
// Recursively emit Java source for `node` into the generated file `dest`
// (a path relative to file_prefix_), indenting emitted lines by `indent`
// spaces. Dispatches on the node's dynamic type to the matching Handle*Node
// method. Handling of an unrecognized node type is not visible in this excerpt.
63 void WalkAST(
const ASTNode* node,
64 const std::string& dest,
72 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
73 HandleMainNode(t1, dest, indent);
74 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
75 HandleACNode(t2, dest, indent);
76 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
77 HandleCondNode(t3, dest, indent);
78 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
79 HandleOutputNode(t4, dest, indent);
80 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
81 HandleTUNode(t5, dest, indent);
82 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
83 HandleQNode(t6, dest, indent);
// Append `content` to the generated file `dest`. The map key is
// file_prefix_ + dest, and the entry is created empty on first use.
89 void CommitToFile(
const std::string& dest,
90 const std::string& content) {
91 files_[file_prefix_ + dest] += content;
// Emit Main.java for the root MainNode. The output contains:
//   - the package and imports, then the `class Main` header;
//   - a static block that sets the javolution log level to INFO;
//   - the num_output_group / num_feature / pred_transform helpers;
//   - the predict (or predict_multiclass) method, whose body is produced by
//     the single child subtree.
// The method epilogue optionally divides the accumulated sum by num_tree
// (average_result) and adds global_bias. It then returns the transformed
// prediction, or the raw margin when pred_margin is set. In the multiclass
// case the scores are written to result[] and, when pred_margin is set, the
// method returns num_output_group_.
94 void HandleMainNode(
const MainNode* node,
95 const std::string& dest,
97 const std::string prototype
98 = (num_output_group_ > 1) ?
// One line of the multiclass prototype is not visible here (presumably the
// float[] result parameter used by the epilogue below).
99 "public static long predict_multiclass(Entry[] data, " 100 "boolean pred_margin, " 102 :
"public static float predict(Entry[] data, boolean pred_margin)";
104 "package treelite.predictor;\n\n" 105 "import java.lang.Math;\n" 106 "import javolution.context.LogContext;\n" 107 "import javolution.context.LogContext.Level;\n\n" 108 "public class Main {\n");
110 " static {\n LogContext ctx = LogContext.enter();\n" 111 " ctx.setLevel(Level.INFO);\n }\n");
113 get_num_output_group_func(num_output_group_) +
"\n" 114 + get_num_feature_func(node->num_feature) +
"\n" 115 + pred_tranform_func_ +
"\n" 116 + std::string(indent + 2,
' ') + prototype +
" {\n");
// The tree-evaluation code goes inside the method body, at indent + 4.
117 CHECK_EQ(node->children.size(), 1);
118 WalkAST(node->children[0], dest, indent + 4);
119 std::ostringstream oss;
// Multiclass: copy each accumulated group score into result[], applying
// optional averaging and the global bias.
120 if (num_output_group_ > 1) {
121 oss << std::string(indent + 4,
' ')
122 <<
"for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 123 << std::string(indent + 6,
' ') <<
"result[i] = sum[i]";
124 if (node->average_result) {
125 oss <<
" / " << node->num_tree;
127 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 128 << std::string(indent + 4,
' ') <<
"}\n" 129 << std::string(indent + 4,
' ') <<
"if (!pred_margin) {\n" 130 << std::string(indent + 4,
' ')
131 <<
" return pred_transform(result);\n" 132 << std::string(indent + 4,
' ') <<
"} else {\n" 133 << std::string(indent + 4,
' ')
134 <<
" return " << num_output_group_ <<
";\n" 135 << std::string(indent + 4,
' ') <<
"}\n";
// Single-output case (the enclosing else-branch line is not visible here).
137 oss << std::string(indent + 4,
' ') <<
"sum = sum";
138 if (node->average_result) {
139 oss <<
" / " << node->num_tree;
141 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 142 << std::string(indent + 4,
' ') <<
"if (!pred_margin) {\n" 143 << std::string(indent + 4,
' ') <<
" return pred_transform(sum);\n" 144 << std::string(indent + 4,
' ') <<
"} else {\n" 145 << std::string(indent + 4,
' ') <<
" return sum;\n" 146 << std::string(indent + 4,
' ') <<
"}\n";
// Close the predict method (indent + 2) and, presumably, class Main (indent).
// One line between these two closers is not visible here.
148 oss << std::string(indent + 2,
' ') <<
"}\n" 150 << std::string(indent,
' ') <<
"}\n";
151 CommitToFile(dest, oss.str());
155 const std::string& dest,
// Presumably HandleACNode (its first signature line is not visible). It
// declares the Java accumulator `sum`: a float[num_output_group_] for
// multiclass models, otherwise a float initialized to 0.0f. It also declares
// the scratch variable `tmp` used by categorical split tests, then emits
// every child at the same indent.
157 std::ostringstream oss;
158 if (num_output_group_ > 1) {
159 oss << std::string(indent,
' ')
160 <<
"float[] sum = new float[" << num_output_group_ <<
"];\n";
// Single-output accumulator (the else-branch line is not visible here).
162 oss << std::string(indent,
' ') <<
"float sum = 0.0f;\n";
164 oss << std::string(indent,
' ') <<
"int tmp;\n";
165 CommitToFile(dest, oss.str());
167 for (
ASTNode* child : node->children) {
168 WalkAST(child, dest, indent);
173 const std::string& dest,
// Emit a Java `if/else` for a split node. The test is built in `oss`:
//   - quantized numerical splits compare data[i].qvalue with an int threshold;
//   - raw numerical splits compare data[i].fvalue with a float literal;
//   - categorical splits test membership in a bitmap of left categories.
// children[0] is emitted as the true branch and children[1] as the else branch.
//
// `na_check` is true when data[split_index].missing is not -1. HandleQNode uses
// the same test to mean "value present", so -1 presumably marks a missing
// value -- confirm against Entry.
175 const unsigned split_index = node->split_index;
176 const std::string na_check
177 = std::string(
"data[") + std::to_string(split_index)
178 +
"].missing.get() != -1";
181 std::ostringstream oss;
182 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
184 oss <<
"data[" << split_index <<
"].qvalue.get() " 185 << OpName(t->op) <<
" " << t->threshold.int_val;
// NOTE(review): this saves std::cout's precision but restores it on `oss`;
// oss.precision() was presumably intended. Also, digits10 + 2 gives 8
// significant digits for float, one short of max_digits10 (9). That is not
// enough to round-trip every float threshold exactly -- consider
// std::numeric_limits<tl_float>::max_digits10.
188 const std::streamsize ss = std::cout.precision();
189 oss <<
"data[" << split_index <<
"].fvalue.get() " 190 << OpName(t->op) <<
" " 191 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
192 << t->threshold.float_val <<
"f" << std::setprecision(ss);
// Categorical split: bit (c % 64) of word (c / 64) is set for each left category c.
198 std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
199 CHECK_GE(bitmap.size(), 1);
200 bool all_zeros =
true;
201 for (uint64_t e : bitmap) {
202 all_zeros &= (e == 0);
// NOTE(review): the categorical test emitted below does not look like valid Java:
//   * Java has no comma operator, so "(tmp = ...), (tmp >= 0 ...)" will not
//     compile as an expression;
//   * `&&` requires boolean operands, but "((X >>> tmp) & 1)" has type long;
//   * bitmap words >= 2^63 are printed as unsigned decimals, which exceed the
//     range of a Java long literal ("...L").
// Consider an explicit `!= 0` test and a hex or signed literal.
207 oss <<
"(tmp = (int)(data[" << split_index <<
"].fvalue.get()) ), " 208 <<
"(tmp >= 0 && tmp < 64 && (( (long)" 209 << bitmap[0] <<
"L >>> tmp) & 1) )";
210 for (
size_t i = 1; i < bitmap.size(); ++i) {
211 oss <<
" || (tmp >= " << (i * 64)
212 <<
" && tmp < " << ((i + 1) * 64)
213 <<
" && (( (long)" << bitmap[i]
214 <<
"L >>> (tmp - " << (i * 64) <<
") ) & 1) )";
// With default_left the condition is "!(na_check) || (test)", so values
// failing na_check go to children[0]. Otherwise it is "(na_check) && (test)",
// so those values go to children[1].
219 std::string(indent,
' ') +
"if (" 220 + ((node->default_left) ? (std::string(
"!(") + na_check +
") || (")
221 : (std::string(
" (") + na_check +
") && ("))
222 + oss.str() +
") ) {\n");
223 CHECK_EQ(node->children.size(), 2);
224 WalkAST(node->children[0], dest, indent + 2);
225 CommitToFile(dest, std::string(indent,
' ') +
"} else {\n");
226 WalkAST(node->children[1], dest, indent + 2);
227 CommitToFile(dest, std::string(indent,
' ') +
"}\n");
231 const std::string& dest,
// Emit the leaf contribution into the Java accumulator `sum`:
//   - multiclass vector leaf: one "sum[g] += v_g" line per group;
//   - multiclass scalar leaf: added to group tree_id % num_output_group_;
//   - single-output model: "sum += v".
// Leaves must have no children.
233 std::ostringstream oss;
234 if (num_output_group_ > 1) {
235 if (node->is_vector) {
237 const std::vector<tl_float>& leaf_vector = node->vector;
238 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group_))
239 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
240 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
241 oss << std::string(indent,
' ')
242 <<
"sum[" << group_id <<
"] += " 243 << common::ToString(leaf_vector[group_id]) <<
"f;\n";
// Scalar leaf in a multiclass model (the else-branch line is not visible here).
247 oss << std::string(indent,
' ') <<
"sum[" 248 << (node->tree_id % num_output_group_) <<
"] += " 249 << common::ToString(node->scalar) <<
"f;\n";
// Single-output model.
252 oss << std::string(indent,
' ') <<
"sum += " 253 << common::ToString(node->scalar) <<
"f;\n";
255 CommitToFile(dest, oss.str());
256 CHECK_EQ(node->children.size(), 0);
260 const std::string& dest,
// Emit a translation unit as its own Java class TU<unit_id> in
// TU<unit_id>.java. CommitToFile places that file under file_prefix_.
// `dest` receives a call to the unit's static method:
//   - multiclass: "TU<id>.predict_margin_multiclass_unit<id>(data, sum);"
//   - otherwise:  "sum += TU<id>.predict_margin_unit<id>..."
// The unit body (the single child) is emitted into the new file at indent 4.
262 const std::string new_file
263 = std::string(
"TU") + std::to_string(node->unit_id) +
".java";
264 std::ostringstream caller_buf, callee_buf, func_name, class_name, prototype;
265 class_name <<
"TU" << node->unit_id;
266 if (num_output_group_ > 1) {
267 func_name <<
"predict_margin_multiclass_unit" << node->unit_id;
268 caller_buf << std::string(indent,
' ')
269 << class_name.str() <<
"." 270 << func_name.str() <<
"(data, sum);\n";
271 prototype <<
"public static void " << func_name.str()
272 <<
"(Entry[] data, float[] result)";
// Single-output unit (the else-branch line and the end of the caller
// expression are not visible here).
274 func_name <<
"predict_margin_unit" << node->unit_id;
275 caller_buf << std::string(indent,
' ')
276 <<
"sum += " << class_name.str() <<
"." << func_name.str()
278 prototype <<
"public static float " << func_name.str()
281 callee_buf <<
"package treelite.predictor;\n\n" 282 <<
"public class " << class_name.str() <<
" {\n" 283 <<
" " << prototype.str() <<
" {\n";
284 CommitToFile(dest, caller_buf.str());
285 CommitToFile(new_file, callee_buf.str());
286 CHECK_EQ(node->children.size(), 1);
287 WalkAST(node->children[0], new_file, 4);
// Reuse the buffer for the unit's epilogue.
288 callee_buf.str(
""); callee_buf.clear();
// NOTE(review): the multiclass epilogue ASSIGNS result[i] = sum[i]. The caller
// passes its own `sum` array as `result`, so when several translation units
// share one accumulator, each call overwrites the previous units' scores. The
// single-output path accumulates with "sum += ...". "result[i] += sum[i]" was
// presumably intended -- confirm.
289 if (num_output_group_ > 1) {
291 <<
" for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 292 <<
" result[i] = sum[i];\n" 295 callee_buf <<
" return sum;\n";
// Close the unit method and the class.
297 callee_buf <<
" }\n" <<
"}\n";
298 CommitToFile(new_file, callee_buf.str());
302 const std::string& dest,
// Emit the quantization support for the model. Four static tables are
// appended to main_tail_ (destined for Main.java):
//   - is_categorical: per-feature flags;
//   - threshold: all cut points as float literals;
//   - th_begin / th_len: the offset and count of each feature's cut points.
// The quantize function from the included header is appended as well. A
// per-request loop is then emitted into `dest`.
304 std::ostringstream oss;
// NOTE(review): implicit size_t -> int narrowing.
305 const int num_feature = node->is_categorical.size();
307 oss <<
" private static final boolean[] is_categorical = {\n ";
308 for (
int fid = 0; fid < num_feature; ++fid) {
309 if (node->is_categorical[fid]) {
310 common::WrapText(&oss, &length,
"true, ", 4, 80);
312 common::WrapText(&oss, &length,
"false, ", 4, 80);
317 oss <<
" private static final float[] threshold = {\n ";
318 for (
const auto& e : node->cut_pts) {
320 common::WrapText(&oss, &length, common::ToString(v) +
"f, ", 4, 80);
326 oss <<
" private static final int[] th_begin = {\n ";
327 for (
const auto& e : node->cut_pts) {
328 common::WrapText(&oss, &length, std::to_string(accum) +
", ", 4, 80);
333 oss <<
" private static final int[] th_len = {\n ";
334 for (
const auto& e : node->cut_pts) {
335 common::WrapText(&oss, &length, std::to_string(e.size()) +
", ", 4, 80);
// `quantize_func` presumably comes from the header included on the next line
// and holds the Java quantize() helper that reads the tables above.
338 #include "./java/quantize_func.h" 339 oss << quantize_func;
340 main_tail_ += oss.str();
// Reset the buffer, then emit a loop that quantizes every present,
// non-categorical feature into qvalue before the tree code runs.
341 oss.str(
""); oss.clear();
342 oss << std::string(indent,
' ')
343 <<
"for (int i = 0; i < " << num_feature <<
"; ++i) {\n" 344 << std::string(indent + 2,
' ')
345 <<
"if (data[i].missing.get() != -1 && !is_categorical[i]) {\n" 346 << std::string(indent + 4,
' ')
347 <<
"data[i].qvalue.set(quantize(data[i].fvalue.get(), i));\n" 348 << std::string(indent + 2,
' ') +
"}\n" 349 << std::string(indent,
' ') +
"}\n";
350 CommitToFile(dest, oss.str());
351 CHECK_EQ(node->children.size(), 1);
352 WalkAST(node->children[0], dest, indent);
// Pack the category ids in `left_categories` into a bitmap of 64-bit words:
// bit (c % 64) of word (c / 64) is set for each category c. The return
// statement is not visible in this excerpt.
// The word count is derived from the LAST element, so the input is presumably
// sorted ascending -- confirm. An unsorted vector could index past the end of
// `bitmap`.
// NOTE(review): an empty `left_categories` makes `num_left_categories - 1`
// wrap around, so that access is out of bounds. Also,
// `max_left_category + 1 + 63` is computed in uint32_t and wraps for ids near
// UINT32_MAX.
355 inline std::vector<uint64_t>
356 to_bitmap(
const std::vector<uint32_t>& left_categories)
const {
357 const size_t num_left_categories = left_categories.size();
358 const uint32_t max_left_category = left_categories[num_left_categories - 1];
359 std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
360 for (
size_t i = 0; i < left_categories.size(); ++i) {
361 const uint32_t cat = left_categories[i];
362 const size_t idx = cat / 64;
363 const uint32_t offset = cat % 64;
364 bitmap[idx] |= (
static_cast<uint64_t
>(1) << offset);
// Human-readable description attached to this compiler's registry entry.
// The TREELITE_REGISTER_COMPILER(...) head of this statement is not visible here.
371 .describe(
"AST-based compiler that produces Java code")
int num_output_group
number of output groups for multi-class classification; set to 1 for everything else ...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
parameters for tree compiler
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
CompiledModel Compile(const Model &model) override
convert tree ensemble model