treelite
ast_java.cc
#include <treelite/compiler.h>
#include <treelite/common.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "./param.h"
#include "./pred_transform.h"
#include "./ast/builder.h"
#include "./java/get_num_feature.h"
#include "./java/get_num_output_group.h"
10 
11 namespace treelite {
12 namespace compiler {
13 
14 DMLC_REGISTRY_FILE_TAG(ast_java);
15 
16 class ASTJavaCompiler : public Compiler {
17  public:
18  explicit ASTJavaCompiler(const CompilerParam& param)
19  : param(param) {
20  if (param.verbose > 0) {
21  LOG(INFO) << "Using ASTJavaCompiler";
22  }
23  }
24 
25  CompiledModel Compile(const Model& model) override {
26  CompiledModel cm;
27  cm.backend = "java";
28  cm.files["Main.java"] = "";
29 
30  if (param.annotate_in != "NULL") {
31  LOG(WARNING) << "\033[1;31mWARNING: Java backend does not support "
32  << "branch annotation.\u001B[0m The parameter "
33  << "\x1B[33mannotate_in\u001B[0m will be ignored.";
34  }
35  num_output_group_ = model.num_output_group;
36  pred_tranform_func_ = PredTransformFunction("java", model);
37  files_.clear();
38 
39  ASTBuilder builder;
40  builder.Build(model);
41  builder.Split(param.parallel_comp);
42  if (param.quantize > 0) {
43  builder.QuantizeThresholds();
44  }
45  builder.CountDescendant();
46  builder.BreakUpLargeUnits(param.max_unit_size);
47  #include "java/entry_type.h"
48  #include "java/pom_xml.h"
49  files_[file_prefix_ + "Entry.java"] = entry_type;
50  files_["pom.xml"] = pom_xml;
51  WalkAST(builder.GetRootNode(), "Main.java", 0);
52 
53  cm.files = std::move(files_);
54  cm.file_prefix = file_prefix_;
55  return cm;
56  }
57  private:
58  CompilerParam param;
59  int num_output_group_;
60  std::string pred_tranform_func_;
61  std::unordered_map<std::string, std::string> files_;
62  std::string main_tail_;
63  const std::string file_prefix_ = "src/main/java/treelite/predictor/";
64 
65  void WalkAST(const ASTNode* node,
66  const std::string& dest,
67  int indent) {
68  const MainNode* t1;
69  const AccumulatorContextNode* t2;
70  const ConditionNode* t3;
71  const OutputNode* t4;
72  const TranslationUnitNode* t5;
73  const QuantizerNode* t6;
74  if ( (t1 = dynamic_cast<const MainNode*>(node)) ) {
75  HandleMainNode(t1, dest, indent);
76  } else if ( (t2 = dynamic_cast<const AccumulatorContextNode*>(node)) ) {
77  HandleACNode(t2, dest, indent);
78  } else if ( (t3 = dynamic_cast<const ConditionNode*>(node)) ) {
79  HandleCondNode(t3, dest, indent);
80  } else if ( (t4 = dynamic_cast<const OutputNode*>(node)) ) {
81  HandleOutputNode(t4, dest, indent);
82  } else if ( (t5 = dynamic_cast<const TranslationUnitNode*>(node)) ) {
83  HandleTUNode(t5, dest, indent);
84  } else if ( (t6 = dynamic_cast<const QuantizerNode*>(node)) ) {
85  HandleQNode(t6, dest, indent);
86  } else {
87  LOG(FATAL) << "oops";
88  }
89  }
90 
91  void CommitToFile(const std::string& dest,
92  const std::string& content) {
93  files_[file_prefix_ + dest] += content;
94  }
95 
96  void HandleMainNode(const MainNode* node,
97  const std::string& dest,
98  int indent) {
99  const std::string prototype
100  = (num_output_group_ > 1) ?
101  "public static long predict_multiclass(Entry[] data, "
102  "boolean pred_margin, "
103  "float[] result)"
104  : "public static float predict(Entry[] data, boolean pred_margin)";
105  CommitToFile(dest,
106  "package treelite.predictor;\n\n"
107  "import java.lang.Math;\n"
108  "import javolution.context.LogContext;\n"
109  "import javolution.context.LogContext.Level;\n\n"
110  "public class Main {\n");
111  CommitToFile(dest,
112  " static {\n LogContext ctx = LogContext.enter();\n"
113  " ctx.setLevel(Level.INFO);\n }\n");
114  CommitToFile(dest,
115  get_num_output_group_func(num_output_group_) + "\n"
116  + get_num_feature_func(node->num_feature) + "\n"
117  + pred_tranform_func_ + "\n"
118  + std::string(indent + 2, ' ') + prototype + " {\n");
119  CHECK_EQ(node->children.size(), 1);
120  WalkAST(node->children[0], dest, indent + 4);
121  std::ostringstream oss;
122  if (num_output_group_ > 1) {
123  oss << std::string(indent + 4, ' ')
124  << "for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
125  << std::string(indent + 6, ' ') << "result[i] = sum[i]";
126  if (node->average_result) {
127  oss << " / " << node->num_tree;
128  }
129  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
130  << std::string(indent + 4, ' ') << "}\n"
131  << std::string(indent + 4, ' ') << "if (!pred_margin) {\n"
132  << std::string(indent + 4, ' ')
133  << " return pred_transform(result);\n"
134  << std::string(indent + 4, ' ') << "} else {\n"
135  << std::string(indent + 4, ' ')
136  << " return " << num_output_group_ << ";\n"
137  << std::string(indent + 4, ' ') << "}\n";
138  } else {
139  oss << std::string(indent + 4, ' ') << "sum = sum";
140  if (node->average_result) {
141  oss << " / " << node->num_tree;
142  }
143  oss << " + (float)(" << common::ToString(node->global_bias) << ");\n"
144  << std::string(indent + 4, ' ') << "if (!pred_margin) {\n"
145  << std::string(indent + 4, ' ') << " return pred_transform(sum);\n"
146  << std::string(indent + 4, ' ') << "} else {\n"
147  << std::string(indent + 4, ' ') << " return sum;\n"
148  << std::string(indent + 4, ' ') << "}\n";
149  }
150  oss << std::string(indent + 2, ' ') << "}\n"
151  << main_tail_
152  << std::string(indent, ' ') << "}\n";
153  CommitToFile(dest, oss.str());
154  }
155 
156  void HandleACNode(const AccumulatorContextNode* node,
157  const std::string& dest,
158  int indent) {
159  std::ostringstream oss;
160  if (num_output_group_ > 1) {
161  oss << std::string(indent, ' ')
162  << "float[] sum = new float[" << num_output_group_ << "];\n";
163  } else {
164  oss << std::string(indent, ' ') << "float sum = 0.0f;\n";
165  }
166  oss << std::string(indent, ' ') << "int tmp;\n";
167  CommitToFile(dest, oss.str());
168 
169  for (ASTNode* child : node->children) {
170  WalkAST(child, dest, indent);
171  }
172  }
173 
174  void HandleCondNode(const ConditionNode* node,
175  const std::string& dest,
176  int indent) {
177  const unsigned split_index = node->split_index;
178  const std::string na_check
179  = std::string("data[") + std::to_string(split_index)
180  + "].missing.get() != -1";
181 
182  const NumericalConditionNode* t;
183  std::ostringstream oss; // prepare logical statement for evaluating split
184  if ( (t = dynamic_cast<const NumericalConditionNode*>(node)) ) {
185  if (t->quantized) { // quantized threshold
186  oss << "data[" << split_index << "].qvalue.get() "
187  << OpName(t->op) << " " << t->threshold.int_val;
188  } else if (std::isinf(t->threshold.float_val)) { // infinite threshold
189  // According to IEEE 754, the result of comparison [lhs] < infinity
190  // must be identical for all finite [lhs]. Same goes for operator >.
191  oss << (common::CompareWithOp(0.0, t->op, t->threshold.float_val)
192  ? "true" : "false");
193  } else { // finite threshold
194  // to restore default precision
195  const std::streamsize ss = std::cout.precision();
196  oss << "data[" << split_index << "].fvalue.get() "
197  << OpName(t->op) << " "
198  << std::setprecision(std::numeric_limits<tl_float>::digits10 + 2)
199  << t->threshold.float_val << "f" << std::setprecision(ss);
200  }
201  } else { // categorical split
202  const CategoricalConditionNode* t2
203  = dynamic_cast<const CategoricalConditionNode*>(node);
204  CHECK(t2);
205  std::vector<uint64_t> bitmap = to_bitmap(t2->left_categories);
206  CHECK_GE(bitmap.size(), 1);
207  bool all_zeros = true;
208  for (uint64_t e : bitmap) {
209  all_zeros &= (e == 0);
210  }
211  if (all_zeros) {
212  oss << "0";
213  } else {
214  oss << "(tmp = (int)(data[" << split_index << "].fvalue.get()) ), "
215  << "(tmp >= 0 && tmp < 64 && (( (long)"
216  << bitmap[0] << "L >>> tmp) & 1) )";
217  for (size_t i = 1; i < bitmap.size(); ++i) {
218  oss << " || (tmp >= " << (i * 64)
219  << " && tmp < " << ((i + 1) * 64)
220  << " && (( (long)" << bitmap[i]
221  << "L >>> (tmp - " << (i * 64) << ") ) & 1) )";
222  }
223  }
224  }
225  CommitToFile(dest,
226  std::string(indent, ' ') + "if ("
227  + ((node->default_left) ? (std::string("!(") + na_check + ") || (")
228  : (std::string(" (") + na_check + ") && ("))
229  + oss.str() + ") ) {\n");
230  CHECK_EQ(node->children.size(), 2);
231  WalkAST(node->children[0], dest, indent + 2);
232  CommitToFile(dest, std::string(indent, ' ') + "} else {\n");
233  WalkAST(node->children[1], dest, indent + 2);
234  CommitToFile(dest, std::string(indent, ' ') + "}\n");
235  }
236 
237  void HandleOutputNode(const OutputNode* node,
238  const std::string& dest,
239  int indent) {
240  std::ostringstream oss;
241  if (num_output_group_ > 1) {
242  if (node->is_vector) {
243  // multi-class classification with random forest
244  const std::vector<tl_float>& leaf_vector = node->vector;
245  CHECK_EQ(leaf_vector.size(), static_cast<size_t>(num_output_group_))
246  << "Ill-formed model: leaf vector must be of length [num_output_group]";
247  for (int group_id = 0; group_id < num_output_group_; ++group_id) {
248  oss << std::string(indent, ' ')
249  << "sum[" << group_id << "] += "
250  << common::ToString(leaf_vector[group_id]) << "f;\n";
251  }
252  } else {
253  // multi-class classification with gradient boosted trees
254  oss << std::string(indent, ' ') << "sum["
255  << (node->tree_id % num_output_group_) << "] += "
256  << common::ToString(node->scalar) << "f;\n";
257  }
258  } else {
259  oss << std::string(indent, ' ') << "sum += "
260  << common::ToString(node->scalar) << "f;\n";
261  }
262  CommitToFile(dest, oss.str());
263  CHECK_EQ(node->children.size(), 0);
264  }
265 
266  void HandleTUNode(const TranslationUnitNode* node,
267  const std::string& dest,
268  int indent) {
269  const std::string new_file
270  = std::string("TU") + std::to_string(node->unit_id) + ".java";
271  std::ostringstream caller_buf, callee_buf, func_name, class_name, prototype;
272  class_name << "TU" << node->unit_id;
273  if (num_output_group_ > 1) {
274  func_name << "predict_margin_multiclass_unit" << node->unit_id;
275  caller_buf << std::string(indent, ' ')
276  << class_name.str() << "."
277  << func_name.str() << "(data, sum);\n";
278  prototype << "public static void " << func_name.str()
279  << "(Entry[] data, float[] result)";
280  } else {
281  func_name << "predict_margin_unit" << node->unit_id;
282  caller_buf << std::string(indent, ' ')
283  << "sum += " << class_name.str() << "." << func_name.str()
284  << "(data);\n";
285  prototype << "public static float " << func_name.str()
286  << "(Entry[] data)";
287  }
288  callee_buf << "package treelite.predictor;\n\n"
289  << "public class " << class_name.str() << " {\n"
290  << " " << prototype.str() << " {\n";
291  CommitToFile(dest, caller_buf.str());
292  CommitToFile(new_file, callee_buf.str());
293  CHECK_EQ(node->children.size(), 1);
294  WalkAST(node->children[0], new_file, 4);
295  callee_buf.str(""); callee_buf.clear();
296  if (num_output_group_ > 1) {
297  callee_buf
298  << " for (int i = 0; i < " << num_output_group_ << "; ++i) {\n"
299  << " result[i] = sum[i];\n"
300  << " }\n";
301  } else {
302  callee_buf << " return sum;\n";
303  }
304  callee_buf << " }\n" << "}\n";
305  CommitToFile(new_file, callee_buf.str());
306  }
307 
308  void HandleQNode(const QuantizerNode* node,
309  const std::string& dest,
310  int indent) {
311  std::ostringstream oss; // prepare a preamble
312  const int num_feature = node->is_categorical.size();
313  size_t length = 4;
314  oss << " private static final boolean[] is_categorical = {\n ";
315  for (int fid = 0; fid < num_feature; ++fid) {
316  if (node->is_categorical[fid]) {
317  common::WrapText(&oss, &length, "true, ", 4, 80);
318  } else {
319  common::WrapText(&oss, &length, "false, ", 4, 80);
320  }
321  }
322  oss << "\n };\n";
323  length = 4;
324  oss << " private static final float[] threshold = {\n ";
325  for (const auto& e : node->cut_pts) {
326  for (tl_float v : e) {
327  common::WrapText(&oss, &length, common::ToString(v) + "f, ", 4, 80);
328  }
329  }
330  oss << "\n };\n";
331  length = 4;
332  size_t accum = 0;
333  oss << " private static final int[] th_begin = {\n ";
334  for (const auto& e : node->cut_pts) {
335  common::WrapText(&oss, &length, std::to_string(accum) + ", ", 4, 80);
336  accum += e.size();
337  }
338  oss << "\n };\n";
339  length = 4;
340  oss << " private static final int[] th_len = {\n ";
341  for (const auto& e : node->cut_pts) {
342  common::WrapText(&oss, &length, std::to_string(e.size()) + ", ", 4, 80);
343  }
344  oss << "\n };\n";
345  #include "./java/quantize_func.h"
346  oss << quantize_func;
347  main_tail_ += oss.str();
348  oss.str(""); oss.clear();
349  oss << std::string(indent, ' ')
350  << "for (int i = 0; i < " << num_feature << "; ++i) {\n"
351  << std::string(indent + 2, ' ')
352  << "if (data[i].missing.get() != -1 && !is_categorical[i]) {\n"
353  << std::string(indent + 4, ' ')
354  << "data[i].qvalue.set(quantize(data[i].fvalue.get(), i));\n"
355  << std::string(indent + 2, ' ') + "}\n"
356  << std::string(indent, ' ') + "}\n";
357  CommitToFile(dest, oss.str());
358  CHECK_EQ(node->children.size(), 1);
359  WalkAST(node->children[0], dest, indent);
360  }
361 
362  inline std::vector<uint64_t>
363  to_bitmap(const std::vector<uint32_t>& left_categories) const {
364  const size_t num_left_categories = left_categories.size();
365  const uint32_t max_left_category = left_categories[num_left_categories - 1];
366  std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
367  for (size_t i = 0; i < left_categories.size(); ++i) {
368  const uint32_t cat = left_categories[i];
369  const size_t idx = cat / 64;
370  const uint32_t offset = cat % 64;
371  bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
372  }
373  return bitmap;
374  }
375 };
376 
378 .describe("AST-based compiler that produces Java code")
379 .set_body([](const CompilerParam& param) -> Compiler* {
380  return new ASTJavaCompiler(param);
381  });
382 } // namespace compiler
383 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
parameters for tree compiler
Definition: param.h:16
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:32
Interface of compiler that compiles a tree ensemble model.
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:24
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:70
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:32
Some useful utilities.
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:26
int verbose
if >0, produce extra messages
Definition: param.h:34
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: ast_java.cc:25