treelite
ast_java.cc
1 #include <treelite/compiler.h>
2 #include <unordered_map>
3 #include "./param.h"
4 #include "./pred_transform.h"
5 #include "./ast/builder.h"
6 #include "./java/get_num_feature.h"
7 #include "./java/get_num_output_group.h"
8 
9 namespace treelite {
10 namespace compiler {
11 
12 DMLC_REGISTRY_FILE_TAG(ast_java);
13 
14 class ASTJavaCompiler : public Compiler {
15  public:
16  explicit ASTJavaCompiler(const CompilerParam& param)
17  : param(param) {
18  if (param.verbose > 0) {
19  LOG(INFO) << "Using ASTJavaCompiler";
20  }
21  }
22 
23  CompiledModel Compile(const Model& model) override {
24  CompiledModel cm;
25  cm.backend = "java";
26  cm.files["Main.java"] = "";
27 
28  if (param.annotate_in != "NULL") {
29  LOG(WARNING) << "\033[1;31mWARNING: Java backend does not support "
30  << "branch annotation.\u001B[0m The parameter "
31  << "\x1B[33mannotate_in\u001B[0m will be ignored.";
32  }
33  num_output_group_ = model.num_output_group;
34  pred_tranform_func_ = PredTransformFunction("java", model);
35  files_.clear();
36 
37  ASTBuilder builder;
38  builder.Build(model);
39  builder.Split(param.parallel_comp);
40  if (param.quantize > 0) {
41  builder.QuantizeThresholds();
42  }
43  builder.CountDescendant();
44  builder.BreakUpLargeUnits(param.max_unit_size);
45  #include "java/entry_type.h"
46  #include "java/pom_xml.h"
47  files_[file_prefix_ + "Entry.java"] = entry_type;
48  files_["pom.xml"] = pom_xml;
49  WalkAST(builder.GetRootNode(), "Main.java", 0);
50 
51  cm.files = std::move(files_);
52  cm.file_prefix = file_prefix_;
53  return cm;
54  }
55  private:
56  CompilerParam param;
57  int num_output_group_;
58  std::string pred_tranform_func_;
59  std::unordered_map<std::string, std::string> files_;
60  std::string main_tail_;
61  const std::string file_prefix_ = "src/main/java/treelite/predictor/";
62 
63  void WalkAST(const ASTNode* node,
64  const std::string& dest,
65  int indent) {
66  const MainNode* t1;
67  const AccumulatorContextNode* t2;
68  const ConditionNode* t3;
69  const OutputNode* t4;
70  const TranslationUnitNode* t5;
71  const QuantizerNode* t6;
72  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
73  HandleMainNode(t1, dest, indent);
74  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
75  HandleACNode(t2, dest, indent);
76  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
77  HandleCondNode(t3, dest, indent);
78  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
79  HandleOutputNode(t4, dest, indent);
80  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
81  HandleTUNode(t5, dest, indent);
82  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
83  HandleQNode(t6, dest, indent);
84  } else {
85  LOG(FATAL) << "oops";
86  }
87  }
88 
89  void CommitToFile(const std::string& dest,
90  const std::string& content) {
91  files_[file_prefix_ + dest] += content;
92  }
93 
94  void HandleMainNode(const MainNode* node,
95  const std::string& dest,
96  int indent) {
97  const std::string prototype
98  = (num_output_group_ > 1) ?
99  "public static long predict_multiclass(Entry[] data, "
100  "boolean pred_margin, "
101  "float[] result)"
102  : "public static float predict(Entry[] data, boolean pred_margin)";
103  CommitToFile(dest,
104  "package treelite.predictor;\n\n"
105  "import java.lang.Math;\n"
106  "import javolution.context.LogContext;\n"
107  "import javolution.context.LogContext.Level;\n\n"
108  "public class Main {\n");
109  CommitToFile(dest,
110  " static {\n LogContext ctx = LogContext.enter();\n"
111  " ctx.setLevel(Level.INFO);\n }\n");
112  CommitToFile(dest,
113  get_num_output_group_func(num_output_group_) + "\n"
114  + get_num_feature_func(node->num_feature) + "\n"
115  + pred_tranform_func_ + "\n"
116  + std::string(indent + 2, ' ') + prototype + " {\n");
117  CHECK_EQ(node->children.size(), 1);
118  WalkAST(node->children[0], dest, indent + 4);
119  std::ostringstream oss;
120  if (num_output_group_ > 1) {
121  oss << std::string(indent + 4, ' ')
122  << "for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
123  << std::string(indent + 6, ' ') << "result[i] = sum[i]";
124  if (node->average_result) {
125  oss << " / " << node->num_tree;
126  }
127  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
128  << std::string(indent + 4, ' ') << "}\n"
129  << std::string(indent + 4, ' ') << "if (!pred_margin) {\n"
130  << std::string(indent + 4, ' ')
131  << " return pred_transform(result);\n"
132  << std::string(indent + 4, ' ') << "} else {\n"
133  << std::string(indent + 4, ' ')
134  << " return " << num_output_group_ << ";\n"
135  << std::string(indent + 4, ' ') << "}\n";
136  } else {
137  oss << std::string(indent + 4, ' ') << "sum = sum";
138  if (node->average_result) {
139  oss << " / " << node->num_tree;
140  }
141  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
142  << std::string(indent + 4, ' ') << "if (!pred_margin) {\n"
143  << std::string(indent + 4, ' ') << " return pred_transform(sum);\n"
144  << std::string(indent + 4, ' ') << "} else {\n"
145  << std::string(indent + 4, ' ') << " return sum;\n"
146  << std::string(indent + 4, ' ') << "}\n";
147  }
148  oss << std::string(indent + 2, ' ') << "}\n"
149  << main_tail_
150  << std::string(indent, ' ') << "}\n";
151  CommitToFile(dest, oss.str());
152  }
153 
154  void HandleACNode(const AccumulatorContextNode* node,
155  const std::string& dest,
156  int indent) {
157  std::ostringstream oss;
158  if (num_output_group_ > 1) {
159  oss << std::string(indent, ' ')
160  << "float[] sum = new float[" << num_output_group_ << "];\n";
161  } else {
162  oss << std::string(indent, ' ') << "float sum = 0.0f;\n";
163  }
164  oss << std::string(indent, ' ') << "int tmp;\n";
165  CommitToFile(dest, oss.str());
166 
167  for (ASTNode* child : node->children) {
168  WalkAST(child, dest, indent);
169  }
170  }
171 
172  void HandleCondNode(const ConditionNode* node,
173  const std::string& dest,
174  int indent) {
175  const unsigned split_index = node->split_index;
176  const std::string na_check
177  = std::string("data[") + std::to_string(split_index)
178  + "].missing.get() != -1";
179 
180  const NumericalConditionNode* t;
181  std::ostringstream oss; // prepare logical statement for evaluating split
182  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
183  if (t->quantized) {
184  oss << "data[" << split_index << "].qvalue.get() "
185  << OpName(t->op) << " " << t->threshold.int_val;
186  } else {
187  // to restore default precision
188  const std::streamsize ss = std::cout.precision();
189  oss << "data[" << split_index << "].fvalue.get() "
190  << OpName(t->op) << " "
191  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
192  << t->threshold.float_val << "f" << std::setprecision(ss);
193  }
194  } else { // categorical split
195  const CategoricalConditionNode* t2
196  = dynamic_cast<const CategoricalConditionNode*>(node);
197  CHECK(t2);
198  std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
199  CHECK_GE(bitmap.size(), 1);
200  bool all_zeros = true;
201  for (uint64_t e : bitmap) {
202  all_zeros &= (e == 0);
203  }
204  if (all_zeros) {
205  oss << "0";
206  } else {
207  oss << "(tmp = (int)(data[" << split_index << "].fvalue.get()) ), "
208  << "(tmp >= 0 && tmp < 64 && (( (long)"
209  << bitmap[0] << "L >>> tmp) & 1) )";
210  for (size_t i = 1; i < bitmap.size(); ++i) {
211  oss << " || (tmp >= " << (i * 64)
212  << " && tmp < " << ((i + 1) * 64)
213  << " && (( (long)" << bitmap[i]
214  << "L >>> (tmp - " << (i * 64) << ") ) & 1) )";
215  }
216  }
217  }
218  CommitToFile(dest,
219  std::string(indent, ' ') + "if ("
220  + ((node->default_left) ? (std::string("!(") + na_check + ") || (")
221  : (std::string(" (") + na_check + ") && ("))
222  + oss.str() + ") ) {\n");
223  CHECK_EQ(node->children.size(), 2);
224  WalkAST(node->children[0], dest, indent + 2);
225  CommitToFile(dest, std::string(indent, ' ') + "} else {\n");
226  WalkAST(node->children[1], dest, indent + 2);
227  CommitToFile(dest, std::string(indent, ' ') + "}\n");
228  }
229 
230  void HandleOutputNode(const OutputNode* node,
231  const std::string& dest,
232  int indent) {
233  std::ostringstream oss;
234  if (num_output_group_ > 1) {
235  if (node->is_vector) {
236  // multi-class classification with random forest
237  const std::vector<tl_float>& leaf_vector = node->vector;
238  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group_))
239  << "Ill-formed model: leaf vector must be of length [num_output_group]";
240  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
241  oss << std::string(indent, ' ')
242  << "sum[" << group_id << "] += "
243  << common::ToString(leaf_vector[group_id]) << "f;\n";
244  }
245  } else {
246  // multi-class classification with gradient boosted trees
247  oss << std::string(indent, ' ') << "sum["
248  << (node->tree_id % num_output_group_) << "] += "
249  << common::ToString(node->scalar) << "f;\n";
250  }
251  } else {
252  oss << std::string(indent, ' ') << "sum += "
253  << common::ToString(node->scalar) << "f;\n";
254  }
255  CommitToFile(dest, oss.str());
256  CHECK_EQ(node->children.size(), 0);
257  }
258 
259  void HandleTUNode(const TranslationUnitNode* node,
260  const std::string& dest,
261  int indent) {
262  const std::string new_file
263  = std::string("TU") + std::to_string(node->unit_id) + ".java";
264  std::ostringstream caller_buf, callee_buf, func_name, class_name, prototype;
265  class_name << "TU" << node->unit_id;
266  if (num_output_group_ > 1) {
267  func_name << "predict_margin_multiclass_unit" << node->unit_id;
268  caller_buf << std::string(indent, ' ')
269  << class_name.str() << "."
270  << func_name.str() << "(data, sum);\n";
271  prototype << "public static void " << func_name.str()
272  << "(Entry[] data, float[] result)";
273  } else {
274  func_name << "predict_margin_unit" << node->unit_id;
275  caller_buf << std::string(indent, ' ')
276  << "sum += " << class_name.str() << "." << func_name.str()
277  << "(data);\n";
278  prototype << "public static float " << func_name.str()
279  << "(Entry[] data)";
280  }
281  callee_buf << "package treelite.predictor;\n\n"
282  << "public class " << class_name.str() << " {\n"
283  << " " << prototype.str() << " {\n";
284  CommitToFile(dest, caller_buf.str());
285  CommitToFile(new_file, callee_buf.str());
286  CHECK_EQ(node->children.size(), 1);
287  WalkAST(node->children[0], new_file, 4);
288  callee_buf.str(""); callee_buf.clear();
289  if (num_output_group_ > 1) {
290  callee_buf
291  << " for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
292  << " result[i] = sum[i];\n"
293  << " }\n";
294  } else {
295  callee_buf << " return sum;\n";
296  }
297  callee_buf << " }\n" << "}\n";
298  CommitToFile(new_file, callee_buf.str());
299  }
300 
301  void HandleQNode(const QuantizerNode* node,
302  const std::string& dest,
303  int indent) {
304  std::ostringstream oss; // prepare a preamble
305  const int num_feature = node->is_categorical.size();
306  size_t length = 4;
307  oss << " private static final boolean[] is_categorical = {\n ";
308  for (int fid = 0; fid < num_feature; ++fid) {
309  if (node->is_categorical[fid]) {
310  common::WrapText(&oss, &length, "true, ", 4, 80);
311  } else {
312  common::WrapText(&oss, &length, "false, ", 4, 80);
313  }
314  }
315  oss << "\n };\n";
316  length = 4;
317  oss << " private static final float[] threshold = {\n ";
318  for (const auto& e : node->cut_pts) {
319  for (tl_float v : e) {
320  common::WrapText(&oss, &length, common::ToString(v) + "f, ", 4, 80);
321  }
322  }
323  oss << "\n };\n";
324  length = 4;
325  size_t accum = 0;
326  oss << " private static final int[] th_begin = {\n ";
327  for (const auto& e : node->cut_pts) {
328  common::WrapText(&oss, &length, std::to_string(accum) + ", ", 4, 80);
329  accum += e.size();
330  }
331  oss << "\n };\n";
332  length = 4;
333  oss << " private static final int[] th_len = {\n ";
334  for (const auto& e : node->cut_pts) {
335  common::WrapText(&oss, &length, std::to_string(e.size()) + ", ", 4, 80);
336  }
337  oss << "\n };\n";
338  #include "./java/quantize_func.h"
339  oss << quantize_func;
340  main_tail_ += oss.str();
341  oss.str(""); oss.clear();
342  oss << std::string(indent, ' ')
343  << "for (int i = 0; i < " << num_feature << "; ++i) {\n"
344  << std::string(indent + 2, ' ')
345  << "if (data[i].missing.get() != -1 && !is_categorical[i]) {\n"
346  << std::string(indent + 4, ' ')
347  << "data[i].qvalue.set(quantize(data[i].fvalue.get(), i));\n"
348  << std::string(indent + 2, ' ') + "}\n"
349  << std::string(indent, ' ') + "}\n";
350  CommitToFile(dest, oss.str());
351  CHECK_EQ(node->children.size(), 1);
352  WalkAST(node->children[0], dest, indent);
353  }
354 
/*!
 * \brief Pack a list of category indices into 64-bit bitmap words.
 *
 * Bit (c % 64) of word (c / 64) is set for every category c in the list.
 * The result always has at least one word and is just long enough to hold
 * the largest category, so an empty list yields a single zero word. The
 * list does not have to be sorted and may contain duplicates.
 *
 * The function does not use any state of the object, hence it is static.
 *
 * \param left_categories category indices to mark in the bitmap
 * \return bitmap words, lowest categories first
 */
static inline std::vector<uint64_t>
to_bitmap(const std::vector<uint32_t>& left_categories) {
  std::vector<uint64_t> bitmap(1, 0);
  for (const uint32_t cat : left_categories) {
    const size_t idx = cat / 64;
    // Grow on demand instead of trusting the last element to be the
    // largest; this also avoids computing (max + 64) in 32-bit arithmetic.
    if (idx >= bitmap.size()) {
      bitmap.resize(idx + 1, 0);
    }
    const uint32_t offset = cat % 64;
    bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
  }
  return bitmap;
}
368 };
369 
371 .describe("AST-based compiler that produces Java code")
372 .set_body([](const CompilerParam& param) -> Compiler* {
373  return new ASTJavaCompiler(param);
374  });
375 } // namespace compiler
376 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
parameters for tree compiler
Definition: param.h:16
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:32
Interface of compiler that compiles a tree ensemble model.
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:70
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
int verbose
if >0, produce extra messages
Definition: param.h:34
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_java.cc:23