3 #include <unordered_map> 6 #include "./pred_transform.h" 7 #include "./ast/builder.h" 8 #include "./java/get_num_feature.h" 9 #include "./java/get_num_output_group.h" 14 DMLC_REGISTRY_FILE_TAG(ast_java);
21 LOG(INFO) <<
"Using ASTJavaCompiler";
28 cm.files[
"Main.java"] =
"";
31 LOG(WARNING) <<
"\033[1;31mWARNING: Java backend does not support " 32 <<
"branch annotation.\u001B[0m The parameter " 33 <<
"\x1B[33mannotate_in\u001B[0m will be ignored.";
36 pred_tranform_func_ = PredTransformFunction(
"java", model);
43 builder.QuantizeThresholds();
45 builder.CountDescendant();
46 builder.BreakUpLargeUnits(param.max_unit_size);
47 #include "java/entry_type.h" 48 #include "java/pom_xml.h" 49 files_[file_prefix_ +
"Entry.java"] = entry_type;
50 files_[
"pom.xml"] = pom_xml;
51 WalkAST(builder.GetRootNode(),
"Main.java", 0);
53 cm.files = std::move(files_);
54 cm.file_prefix = file_prefix_;
59 int num_output_group_;
60 std::string pred_tranform_func_;
61 std::unordered_map<std::string, std::string> files_;
62 std::string main_tail_;
63 const std::string file_prefix_ =
"src/main/java/treelite/predictor/";
65 void WalkAST(
const ASTNode* node,
66 const std::string& dest,
74 if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
75 HandleMainNode(t1, dest, indent);
76 }
else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
77 HandleACNode(t2, dest, indent);
78 }
else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
79 HandleCondNode(t3, dest, indent);
80 }
else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
81 HandleOutputNode(t4, dest, indent);
82 }
else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
83 HandleTUNode(t5, dest, indent);
84 }
else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
85 HandleQNode(t6, dest, indent);
91 void CommitToFile(
const std::string& dest,
92 const std::string& content) {
93 files_[file_prefix_ + dest] += content;
96 void HandleMainNode(
const MainNode* node,
97 const std::string& dest,
99 const std::string prototype
100 = (num_output_group_ > 1) ?
101 "public static long predict_multiclass(Entry[] data, " 102 "boolean pred_margin, " 104 :
"public static float predict(Entry[] data, boolean pred_margin)";
106 "package treelite.predictor;\n\n" 107 "import java.lang.Math;\n" 108 "import javolution.context.LogContext;\n" 109 "import javolution.context.LogContext.Level;\n\n" 110 "public class Main {\n");
112 " static {\n LogContext ctx = LogContext.enter();\n" 113 " ctx.setLevel(Level.INFO);\n }\n");
115 get_num_output_group_func(num_output_group_) +
"\n" 116 + get_num_feature_func(node->num_feature) +
"\n" 117 + pred_tranform_func_ +
"\n" 118 + std::string(indent + 2,
' ') + prototype +
" {\n");
119 CHECK_EQ(node->children.size(), 1);
120 WalkAST(node->children[0], dest, indent + 4);
121 std::ostringstream oss;
122 if (num_output_group_ > 1) {
123 oss << std::string(indent + 4,
' ')
124 <<
"for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 125 << std::string(indent + 6,
' ') <<
"result[i] = sum[i]";
126 if (node->average_result) {
127 oss <<
" / " << node->num_tree;
129 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 130 << std::string(indent + 4,
' ') <<
"}\n" 131 << std::string(indent + 4,
' ') <<
"if (!pred_margin) {\n" 132 << std::string(indent + 4,
' ')
133 <<
" return pred_transform(result);\n" 134 << std::string(indent + 4,
' ') <<
"} else {\n" 135 << std::string(indent + 4,
' ')
136 <<
" return " << num_output_group_ <<
";\n" 137 << std::string(indent + 4,
' ') <<
"}\n";
139 oss << std::string(indent + 4,
' ') <<
"sum = sum";
140 if (node->average_result) {
141 oss <<
" / " << node->num_tree;
143 oss <<
" + (float)(" << common::ToString(node->global_bias) <<
");\n" 144 << std::string(indent + 4,
' ') <<
"if (!pred_margin) {\n" 145 << std::string(indent + 4,
' ') <<
" return pred_transform(sum);\n" 146 << std::string(indent + 4,
' ') <<
"} else {\n" 147 << std::string(indent + 4,
' ') <<
" return sum;\n" 148 << std::string(indent + 4,
' ') <<
"}\n";
150 oss << std::string(indent + 2,
' ') <<
"}\n" 152 << std::string(indent,
' ') <<
"}\n";
153 CommitToFile(dest, oss.str());
157 const std::string& dest,
159 std::ostringstream oss;
160 if (num_output_group_ > 1) {
161 oss << std::string(indent,
' ')
162 <<
"float[] sum = new float[" << num_output_group_ <<
"];\n";
164 oss << std::string(indent,
' ') <<
"float sum = 0.0f;\n";
166 oss << std::string(indent,
' ') <<
"int tmp;\n";
167 CommitToFile(dest, oss.str());
169 for (
ASTNode* child : node->children) {
170 WalkAST(child, dest, indent);
175 const std::string& dest,
177 const unsigned split_index = node->split_index;
178 const std::string na_check
179 = std::string(
"data[") + std::to_string(split_index)
180 +
"].missing.get() != -1";
183 std::ostringstream oss;
184 if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
186 oss <<
"data[" << split_index <<
"].qvalue.get() " 187 << OpName(t->op) <<
" " << t->threshold.int_val;
188 }
else if (std::isinf(t->threshold.float_val)) {
191 oss << (common::CompareWithOp(0.0, t->op, t->threshold.float_val)
195 const std::streamsize ss = std::cout.precision();
196 oss <<
"data[" << split_index <<
"].fvalue.get() " 197 << OpName(t->op) <<
" " 198 << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
199 << t->threshold.float_val <<
"f" << std::setprecision(ss);
205 std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
206 CHECK_GE(bitmap.size(), 1);
207 bool all_zeros =
true;
208 for (uint64_t e : bitmap) {
209 all_zeros &= (e == 0);
214 oss <<
"(tmp = (int)(data[" << split_index <<
"].fvalue.get()) ), " 215 <<
"(tmp >= 0 && tmp < 64 && (( (long)" 216 << bitmap[0] <<
"L >>> tmp) & 1) )";
217 for (
size_t i = 1; i < bitmap.size(); ++i) {
218 oss <<
" || (tmp >= " << (i * 64)
219 <<
" && tmp < " << ((i + 1) * 64)
220 <<
" && (( (long)" << bitmap[i]
221 <<
"L >>> (tmp - " << (i * 64) <<
") ) & 1) )";
226 std::string(indent,
' ') +
"if (" 227 + ((node->default_left) ? (std::string(
"!(") + na_check +
") || (")
228 : (std::string(
" (") + na_check +
") && ("))
229 + oss.str() +
") ) {\n");
230 CHECK_EQ(node->children.size(), 2);
231 WalkAST(node->children[0], dest, indent + 2);
232 CommitToFile(dest, std::string(indent,
' ') +
"} else {\n");
233 WalkAST(node->children[1], dest, indent + 2);
234 CommitToFile(dest, std::string(indent,
' ') +
"}\n");
238 const std::string& dest,
240 std::ostringstream oss;
241 if (num_output_group_ > 1) {
242 if (node->is_vector) {
244 const std::vector<tl_float>& leaf_vector = node->vector;
245 CHECK_EQ(leaf_vector.size(),
static_cast<size_t>(num_output_group_))
246 <<
"Ill-formed model: leaf vector must be of length [num_output_group]";
247 for (
int group_id = 0; group_id < num_output_group_; ++group_id) {
248 oss << std::string(indent,
' ')
249 <<
"sum[" << group_id <<
"] += " 250 << common::ToString(leaf_vector[group_id]) <<
"f;\n";
254 oss << std::string(indent,
' ') <<
"sum[" 255 << (node->tree_id % num_output_group_) <<
"] += " 256 << common::ToString(node->scalar) <<
"f;\n";
259 oss << std::string(indent,
' ') <<
"sum += " 260 << common::ToString(node->scalar) <<
"f;\n";
262 CommitToFile(dest, oss.str());
263 CHECK_EQ(node->children.size(), 0);
267 const std::string& dest,
269 const std::string new_file
270 = std::string(
"TU") + std::to_string(node->unit_id) +
".java";
271 std::ostringstream caller_buf, callee_buf, func_name, class_name, prototype;
272 class_name <<
"TU" << node->unit_id;
273 if (num_output_group_ > 1) {
274 func_name <<
"predict_margin_multiclass_unit" << node->unit_id;
275 caller_buf << std::string(indent,
' ')
276 << class_name.str() <<
"." 277 << func_name.str() <<
"(data, sum);\n";
278 prototype <<
"public static void " << func_name.str()
279 <<
"(Entry[] data, float[] result)";
281 func_name <<
"predict_margin_unit" << node->unit_id;
282 caller_buf << std::string(indent,
' ')
283 <<
"sum += " << class_name.str() <<
"." << func_name.str()
285 prototype <<
"public static float " << func_name.str()
288 callee_buf <<
"package treelite.predictor;\n\n" 289 <<
"public class " << class_name.str() <<
" {\n" 290 <<
" " << prototype.str() <<
" {\n";
291 CommitToFile(dest, caller_buf.str());
292 CommitToFile(new_file, callee_buf.str());
293 CHECK_EQ(node->children.size(), 1);
294 WalkAST(node->children[0], new_file, 4);
295 callee_buf.str(
""); callee_buf.clear();
296 if (num_output_group_ > 1) {
298 <<
" for (int i = 0; i < " << num_output_group_ <<
"; ++i) {\n" 299 <<
" result[i] = sum[i];\n" 302 callee_buf <<
" return sum;\n";
304 callee_buf <<
" }\n" <<
"}\n";
305 CommitToFile(new_file, callee_buf.str());
309 const std::string& dest,
311 std::ostringstream oss;
312 const int num_feature = node->is_categorical.size();
314 oss <<
" private static final boolean[] is_categorical = {\n ";
315 for (
int fid = 0; fid < num_feature; ++fid) {
316 if (node->is_categorical[fid]) {
317 common::WrapText(&oss, &length,
"true, ", 4, 80);
319 common::WrapText(&oss, &length,
"false, ", 4, 80);
324 oss <<
" private static final float[] threshold = {\n ";
325 for (
const auto& e : node->cut_pts) {
327 common::WrapText(&oss, &length, common::ToString(v) +
"f, ", 4, 80);
333 oss <<
" private static final int[] th_begin = {\n ";
334 for (
const auto& e : node->cut_pts) {
335 common::WrapText(&oss, &length, std::to_string(accum) +
", ", 4, 80);
340 oss <<
" private static final int[] th_len = {\n ";
341 for (
const auto& e : node->cut_pts) {
342 common::WrapText(&oss, &length, std::to_string(e.size()) +
", ", 4, 80);
345 #include "./java/quantize_func.h" 346 oss << quantize_func;
347 main_tail_ += oss.str();
348 oss.str(
""); oss.clear();
349 oss << std::string(indent,
' ')
350 <<
"for (int i = 0; i < " << num_feature <<
"; ++i) {\n" 351 << std::string(indent + 2,
' ')
352 <<
"if (data[i].missing.get() != -1 && !is_categorical[i]) {\n" 353 << std::string(indent + 4,
' ')
354 <<
"data[i].qvalue.set(quantize(data[i].fvalue.get(), i));\n" 355 << std::string(indent + 2,
' ') +
"}\n" 356 << std::string(indent,
' ') +
"}\n";
357 CommitToFile(dest, oss.str());
358 CHECK_EQ(node->children.size(), 1);
359 WalkAST(node->children[0], dest, indent);
362 inline std::vector<uint64_t>
363 to_bitmap(
const std::vector<uint32_t>& left_categories)
const {
364 const size_t num_left_categories = left_categories.size();
365 const uint32_t max_left_category = left_categories[num_left_categories - 1];
366 std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
367 for (
size_t i = 0; i < left_categories.size(); ++i) {
368 const uint32_t cat = left_categories[i];
369 const size_t idx = cat / 64;
370 const uint32_t offset = cat % 64;
371 bitmap[idx] |= (
static_cast<uint64_t
>(1) << offset);
378 .describe(
"AST-based compiler that produces Java code")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
parameters for tree compiler
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
CompiledModel Compile(const Model &model) override
convert tree ensemble model