treelite
failsafe.cc
Go to the documentation of this file.
1 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include <treelite/tree.h>
#include <treelite/compiler.h>
#include <treelite/common.h>
#include <fmt/format.h>
#include "./param.h"
#include "./pred_transform.h"
#include "./elf/elf_formatter.h"
21 
// Prefix spliced in front of the exported function declarations of the generated header
// (see the {dllexport} placeholder). Windows toolchains need __declspec(dllexport) so the
// functions are visible from the built DLL; other platforms need no keyword.
#if defined(_MSC_VER) || defined(_WIN32)
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#else
#define DLLEXPORT_KEYWORD ""
#endif

// Brings the "name"_a named-argument literal into scope for the fmt::format calls below.
using namespace fmt::literals;
29 
30 namespace {
31 
// Host-side mirror of `struct Node` from the generated C header. When nodes[] is dumped
// into an ELF object, each instance is copied byte-for-byte into the output, so the field
// order and types here must stay in sync with that C struct (sindex, info, cleft, cright).
struct NodeStructValue {
  unsigned int sindex;  // split feature index with default_left packed into bit 31; 0 for leaves
  float info;           // split threshold for internal nodes, leaf value for leaves
  int cleft;            // left child id, relative to the start of its tree; -1 marks a leaf
  int cright;           // right child id, relative to the start of its tree; -1 marks a leaf
};
38 
// Contents of the generated header.h, shared by main.c and arrays.c.
// Declares:
//   - union Entry: one input feature; `missing == -1` marks the feature as missing
//     (this is what the predict loop in main_template tests), otherwise `fvalue` is read;
//   - struct Node: one tree node, laid out like NodeStructValue;
//   - the nodes[] / nodes_row_ptr[] arrays (defined in arrays.c/arrays.o and main.c);
//   - the exported query and predict functions.
// Placeholders: {dllexport}, {predict_function_signature}.
// `{{` and `}}` are fmt escapes that produce literal braces.
const char* header_template = R"TREELITETEMPLATE(
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdint.h>

union Entry {{
  int missing;
  float fvalue;
}};

union NodeInfo {{
  float leaf_value;
  float threshold;
}};

struct Node {{
  unsigned int sindex;
  union NodeInfo info;
  int cleft;
  int cright;
}};

extern const struct Node nodes[];
extern const int nodes_row_ptr[];

{dllexport}size_t get_num_output_group(void);
{dllexport}size_t get_num_feature(void);
{dllexport}{predict_function_signature};
)TREELITETEMPLATE";
69 
// Contents of the generated main.c.
// The predict function walks every tree from its root (nid 0) until it reaches a leaf
// (cleft == -1). At each internal node:
//   - a missing feature (missing == -1) follows the default direction stored in bit 31
//     of sindex;
//   - otherwise fvalue is compared against the threshold with {compare_op}.
// The leaf value is then accumulated via {output_statement}, and {return_statement}
// applies the global bias and the optional prediction transform.
// `{{` and `}}` are fmt escapes that produce literal braces.
const char* main_template = R"TREELITETEMPLATE(
#include "header.h"

{nodes_row_ptr}

size_t get_num_output_group(void) {{
  return {num_output_group};
}}

size_t get_num_feature(void) {{
  return {num_feature};
}}

{pred_transform_function}

{predict_function_signature} {{
  {accumulator_definition};

  for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{
    int nid = 0;
    const struct Node* tree = &nodes[nodes_row_ptr[tree_id]];
    while (tree[nid].cleft != -1) {{
      const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U);
      const unsigned char default_left = (tree[nid].sindex >> 31) != 0;
      if (data[feature_id].missing == -1) {{
        nid = (default_left ? tree[nid].cleft : tree[nid].cright);
      }} else {{
        nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold
               ? tree[nid].cleft : tree[nid].cright);
      }}
    }}
    {output_statement}
  }}
  {return_statement}
}}
)TREELITETEMPLATE";
106 
// Tail of predict_multiclass(): copies each group's margin plus the global bias into
// `result`. If pred_margin is zero, it returns whatever pred_transform(result) returns;
// otherwise it returns the number of output groups.
// Placeholders: {num_output_group}, {global_bias}.
const char* return_multiclass_template =
R"TREELITETEMPLATE(
  for (int i = 0; i < {num_output_group}; ++i) {{
    result[i] = sum[i] + (float)({global_bias});
  }}
  if (!pred_margin) {{
    return pred_transform(result);
  }} else {{
    return {num_output_group};
  }}
)TREELITETEMPLATE";  // only for multiclass classification
118 
// Tail of the single-output predict(): adds the global bias to the accumulated sum.
// It returns the raw margin when pred_margin is nonzero, and the transformed value otherwise.
// Placeholder: {global_bias}.
const char* return_template =
R"TREELITETEMPLATE(
  sum += (float)({global_bias});
  if (!pred_margin) {{
    return pred_transform(sum);
  }} else {{
    return sum;
  }}
)TREELITETEMPLATE";
128 
// Contents of arrays.c, used when nodes[] is emitted as C source rather than as an ELF
// object. Placeholder: {nodes}, the full nodes[] definition.
const char* arrays_template = R"TREELITETEMPLATE(
#include "header.h"

{nodes}
)TREELITETEMPLATE";
134 
135 // Returns formatted nodes[] and nodes_row_ptr[] arrays
136 // nodes[]: stores nodes from all decision trees
137 // nodes_row_ptr[]: marks bounaries between decision trees. The nodes belonging to Tree [i] are
138 // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]]
139 inline std::pair<std::string, std::string> FormatNodesArray(const treelite::Model& model) {
140  treelite::common::ArrayFormatter nodes(100, 2);
141  treelite::common::ArrayFormatter nodes_row_ptr(100, 2);
142  int node_count = 0;
143  nodes_row_ptr << "0";
144  for (const auto& tree : model.trees) {
145  for (int nid = 0; nid < tree.num_nodes; ++nid) {
146  const auto& node = tree[nid];
147  if (node.is_leaf()) {
148  CHECK(!node.has_leaf_vector())
149  << "multi-class random forest classifier is not supported in FailSafeCompiler";
150  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
151  "sindex"_a = 0,
152  "info"_a = treelite::common::ToStringHighPrecision(node.leaf_value()),
153  "cleft"_a = -1,
154  "cright"_a = -1);
155  } else {
156  CHECK(node.split_type() == treelite::SplitFeatureType::kNumerical
157  && node.left_categories().empty())
158  << "categorical splits are not supported in FailSafeCompiler";
159  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
160  "sindex"_a = (node.split_index() | (static_cast<uint32_t>(node.default_left()) << 31)),
161  "info"_a = treelite::common::ToStringHighPrecision(node.threshold()),
162  "cleft"_a = node.cleft(),
163  "cright"_a = node.cright());
164  }
165  }
166  node_count += tree.num_nodes;
167  nodes_row_ptr << std::to_string(node_count);
168  }
169  return std::make_pair(fmt::format("const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170  fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
171  nodes_row_ptr.str()));
172 }
173 
174 // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(const treelite::Model& model) {
176  std::vector<char> nodes_elf;
178 
179  treelite::common::ArrayFormatter nodes_row_ptr(100, 2);
180  NodeStructValue val;
181  int node_count = 0;
182  nodes_row_ptr << "0";
183  for (const auto& tree : model.trees) {
184  for (int nid = 0; nid < tree.num_nodes; ++nid) {
185  const auto& node = tree[nid];
186  if (node.is_leaf()) {
187  CHECK(!node.has_leaf_vector())
188  << "multi-class random forest classifier is not supported in FailSafeCompiler";
189  val = {0, static_cast<float>(node.leaf_value()), -1, -1};
190  } else {
191  CHECK(node.split_type() == treelite::SplitFeatureType::kNumerical
192  && node.left_categories().empty())
193  << "categorical splits are not supported in FailSafeCompiler";
194  val = {(node.split_index() | (static_cast<uint32_t>(node.default_left()) << 31)),
195  static_cast<float>(node.threshold()), node.cleft(), node.cright()};
196  }
197  const size_t beg = nodes_elf.size();
198  nodes_elf.resize(beg + sizeof(NodeStructValue));
199  std::memcpy(&nodes_elf[beg], &val, sizeof(NodeStructValue));
200  }
201  node_count += tree.num_nodes;
202  nodes_row_ptr << std::to_string(node_count);
203  }
205 
206  return std::make_pair(nodes_elf, fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
207  nodes_row_ptr.str()));
208 }
209 
210 // Get the comparison op used in the tree ensemble model
211 // If splits have more than one op, throw an error
212 inline std::string GetCommonOp(const treelite::Model& model) {
213  std::set<treelite::Operator> ops;
214  for (const auto& tree : model.trees) {
215  for (int nid = 0; nid < tree.num_nodes; ++nid) {
216  const auto& node = tree[nid];
217  if (!node.is_leaf()) {
218  ops.insert(node.comparison_op());
219  }
220  }
221  }
222  // sanity check: all numerical splits must have identical comparison operators
223  CHECK_EQ(ops.size(), 1)
224  << "FailSafeCompiler only supports models where all splits use identical comparison operator.";
225  return treelite::OpName(*ops.begin());
226 }
227 
228 
// Reports whether `str` finishes with `suffix`; an empty suffix matches every string.
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;  // a suffix longer than the string can never fit
  }
  const std::size_t tail_start = str.size() - suffix.size();
  return str.compare(tail_start, std::string::npos, suffix) == 0;
}
234 
235 } // anonymous namespace
236 
237 namespace treelite {
238 namespace compiler {
239 
// dmlc-core registry file tag named `failsafe` for this translation unit.
// NOTE(review): presumably referenced by a matching link tag elsewhere so the linker keeps
// this file's compiler registration; verify against the registry setup.
DMLC_REGISTRY_FILE_TAG(failsafe);
241 
242 class FailSafeCompiler : public Compiler {
243  public:
244  explicit FailSafeCompiler(const CompilerParam& param)
245  : param(param) {
246  if (param.verbose > 0) {
247  LOG(INFO) << "Using FailSafeCompiler";
248  }
249  if (param.annotate_in != "NULL") {
250  LOG(INFO) << "Warning: 'annotate_in' parameter is not applicable for "
251  "FailSafeCompiler";
252  }
253  if (param.quantize > 0) {
254  LOG(INFO) << "Warning: 'quantize' parameter is not applicable for "
255  "FailSafeCompiler";
256  }
257  if (param.parallel_comp > 0) {
258  LOG(INFO) << "Warning: 'parallel_comp' parameter is not applicable for "
259  "FailSafeCompiler";
260  }
261  if (std::isfinite(param.code_folding_req)) {
262  LOG(INFO) << "Warning: 'code_folding_req' parameter is not applicable "
263  "for FailSafeCompiler";
264  }
265  }
266 
267  CompiledModel Compile(const Model& model) override {
268  CompiledModel cm;
269  cm.backend = "native";
270 
271  num_feature_ = model.num_feature;
272  num_output_group_ = model.num_output_group;
273  CHECK(!model.random_forest_flag)
274  << "Only gradient boosted trees supported in FailSafeCompiler";
275  pred_tranform_func_ = PredTransformFunction("native", model);
276  files_.clear();
277 
278  const char* predict_function_signature
279  = (num_output_group_ > 1) ?
280  "size_t predict_multiclass(union Entry* data, int pred_margin, "
281  "float* result)"
282  : "float predict(union Entry* data, int pred_margin)";
283 
284  std::ostringstream main_program;
285  std::string accumulator_definition
286  = (num_output_group_ > 1
287  ? fmt::format("float sum[{num_output_group}] = {{0.0f}}",
288  "num_output_group"_a = num_output_group_)
289  : std::string("float sum = 0.0f"));
290 
291  std::string output_statement
292  = (num_output_group_ > 1
293  ? fmt::format("sum[tree_id % {num_output_group}] += tree[nid].info.leaf_value;",
294  "num_output_group"_a = num_output_group_)
295  : std::string("sum += tree[nid].info.leaf_value;"));
296 
297  std::string return_statement
298  = (num_output_group_ > 1
299  ? fmt::format(return_multiclass_template,
300  "num_output_group"_a = num_output_group_,
301  "global_bias"_a = common::ToStringHighPrecision(model.param.global_bias))
302  : fmt::format(return_template,
303  "global_bias"_a = common::ToStringHighPrecision(model.param.global_bias)));
304 
305  std::string nodes, nodes_row_ptr;
306  std::vector<char> nodes_elf;
307  if (param.dump_array_as_elf > 0) {
308  if (param.verbose > 0) {
309  LOG(INFO) << "Dumping arrays as an ELF relocatable object...";
310  }
311  std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
312  } else {
313  std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
314  }
315 
316  main_program << fmt::format(main_template,
317  "nodes_row_ptr"_a = nodes_row_ptr,
318  "pred_transform_function"_a = pred_tranform_func_,
319  "predict_function_signature"_a = predict_function_signature,
320  "num_output_group"_a = num_output_group_,
321  "num_feature"_a = num_feature_,
322  "num_tree"_a = model.trees.size(),
323  "compare_op"_a = GetCommonOp(model),
324  "accumulator_definition"_a = accumulator_definition,
325  "output_statement"_a = output_statement,
326  "return_statement"_a = return_statement);
327 
328  files_["main.c"] = CompiledModel::FileEntry(main_program.str());
329 
330  if (param.dump_array_as_elf > 0) {
331  files_["arrays.o"] = CompiledModel::FileEntry(std::move(nodes_elf));
332  } else {
333  files_["arrays.c"] = CompiledModel::FileEntry(fmt::format(arrays_template,
334  "nodes"_a = nodes));
335  }
336 
337  files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template,
338  "dllexport"_a = DLLEXPORT_KEYWORD,
339  "predict_function_signature"_a = predict_function_signature));
340 
341  {
342  /* write recipe.json */
343  std::vector<std::unordered_map<std::string, std::string>> source_list;
344  std::vector<std::string> extra_file_list;
345  for (const auto& kv : files_) {
346  if (EndsWith(kv.first, ".c")) {
347  const size_t line_count
348  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
349  source_list.push_back({ {"name",
350  kv.first.substr(0, kv.first.length() - 2)},
351  {"length", std::to_string(line_count)} });
352  } else if (EndsWith(kv.first, ".o")) {
353  extra_file_list.push_back(kv.first);
354  }
355  }
356  std::ostringstream oss;
357  auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
358  writer->BeginObject();
359  writer->WriteObjectKeyValue("target", param.native_lib_name);
360  writer->WriteObjectKeyValue("sources", source_list);
361  if (!extra_file_list.empty()) {
362  writer->WriteObjectKeyValue("extra", extra_file_list);
363  }
364  writer->EndObject();
365  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
366  }
367  cm.files = std::move(files_);
368  return cm;
369  }
370 
371  private:
372  CompilerParam param;
373  int num_feature_;
374  int num_output_group_;
375  std::string pred_tranform_func_;
376  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
377 };
378 
380 .describe("Simple compiler to express trees as a tight for-loop")
381 .set_body([](const CompilerParam& param) -> Compiler* {
382  return new FailSafeCompiler(param);
383  });
384 } // namespace compiler
385 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:438
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: failsafe.cc:267
thin wrapper for tree ensemble model
Definition: tree.h:428
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
std::vector< Tree > trees
member trees
Definition: tree.h:430
parameters for tree compiler
Definition: param.h:18
ModelParam param
extra parameters
Definition: tree.h:443
model structure for tree
void AllocateELFHeader(std::vector< char > *elf_buffer)
Pre-allocate space in a buffer to fit an ELF header.
float global_bias
global bias of the model
Definition: tree.h:399
Parameters for tree compiler.
interface of compiler
Definition: compiler.h:55
Interface of compiler that compiles a tree ensemble model.
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: common.h:126
std::string native_lib_name
native lib name (without extension)
Definition: param.h:38
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
Definition: param.h:26
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
Definition: param.h:49
Definition: compiler.h:28
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:93
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:441
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
Definition: param.h:44
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
Definition: param.h:34
int quantize
whether to quantize threshold points (0: no, >0: yes)
Definition: param.h:28
int verbose
if >0, produce extra messages
Definition: param.h:36
Generate a relocatable object file containing a constant, read-only array.
void FormatArrayAsELF(std::vector< char > *elf_buffer)
Format a relocatable ELF object file containing a constant, read-only array.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435