Treelite
failsafe.cc
Go to the documentation of this file.
1 
9 #include <treelite/tree.h>
10 #include <treelite/compiler.h>
12 #include <fmt/format.h>
13 #include <unordered_map>
14 #include <set>
15 #include <tuple>
16 #include <utility>
17 #include <cmath>
18 #include "./pred_transform.h"
19 #include "./common/format_util.h"
20 #include "./elf/elf_formatter.h"
21 
22 #if defined(_MSC_VER) || defined(_WIN32)
23 #define DLLEXPORT_KEYWORD "__declspec(dllexport) "
24 #else
25 #define DLLEXPORT_KEYWORD ""
26 #endif
27 
28 using namespace fmt::literals;
29 
30 namespace {
31 
// Host-side image of a single tree node. The members mirror, in the same order, the
// `struct Node` declared in the generated header (see header_template): an unsigned
// split index, a float payload, and two child indices. FormatNodesArrayELF() copies
// objects of this type byte-for-byte (std::memcpy) into the ELF payload, so the member
// order and types must stay in sync with the generated `struct Node`.
32 struct NodeStructValue {
// Split feature index with the default-left flag OR-ed into bit 31; written as 0 for
// leaf nodes.
33  unsigned int sindex;
// Leaf output for leaf nodes, split threshold for internal nodes (the generated code
// reads it through `union NodeInfo`).
34  float info;
// Index of the left child within the same tree; -1 marks a leaf.
35  int cleft;
// Index of the right child within the same tree; -1 marks a leaf.
36  int cright;
37 };
38 
// fmt::format template for the generated "header.h".
// Doubled braces ({{ and }}) are fmt escapes that come out as literal braces in the
// generated C code. Named placeholders filled in by FailSafeCompiler::Compile():
//   {dllexport}                  - export keyword (DLLEXPORT_KEYWORD)
//   {predict_function_signature} - prototype of predict() or predict_multiclass()
// The generated header declares the Entry / NodeInfo unions, the Node struct and the
// extern nodes[] / nodes_row_ptr[] tables shared by the other generated files.
39 const char* header_template = R"TREELITETEMPLATE(
40 #include <stdlib.h>
41 #include <string.h>
42 #include <math.h>
43 #include <stdint.h>
44 
45 union Entry {{
46  int missing;
47  float fvalue;
48 }};
49 
50 union NodeInfo {{
51  float leaf_value;
52  float threshold;
53 }};
54 
55 struct Node {{
56  unsigned int sindex;
57  union NodeInfo info;
58  int cleft;
59  int cright;
60 }};
61 
62 extern const struct Node nodes[];
63 extern const int nodes_row_ptr[];
64 
65 {dllexport}size_t get_num_output_group(void);
66 {dllexport}size_t get_num_feature(void);
67 {dllexport}{predict_function_signature};
68 )TREELITETEMPLATE";
69 
// fmt::format template for the generated "main.c".
// Placeholders filled in by FailSafeCompiler::Compile(): {nodes_row_ptr},
// {num_output_group}, {num_feature}, {pred_transform_function},
// {predict_function_signature}, {accumulator_definition}, {num_tree}, {compare_op},
// {output_statement} and {return_statement}.
// The generated prediction function visits every tree in turn: starting at node 0 of
// the tree, it descends until it reaches a node whose cleft is -1 (a leaf). At each
// internal node the low 31 bits of sindex select the feature and bit 31 gives the
// direction taken when the feature's `missing` field equals -1; otherwise the feature
// value is compared against the node threshold with {compare_op}.
70 const char* main_template = R"TREELITETEMPLATE(
71 #include "header.h"
72 
73 {nodes_row_ptr}
74 
75 size_t get_num_output_group(void) {{
76  return {num_output_group};
77 }}
78 
79 size_t get_num_feature(void) {{
80  return {num_feature};
81 }}
82 
83 {pred_transform_function}
84 
85 {predict_function_signature} {{
86  {accumulator_definition};
87 
88  for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{
89  int nid = 0;
90  const struct Node* tree = &nodes[nodes_row_ptr[tree_id]];
91  while (tree[nid].cleft != -1) {{
92  const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U);
93  const unsigned char default_left = (tree[nid].sindex >> 31) != 0;
94  if (data[feature_id].missing == -1) {{
95  nid = (default_left ? tree[nid].cleft : tree[nid].cright);
96  }} else {{
97  nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold
98  ? tree[nid].cleft : tree[nid].cright);
99  }}
100  }}
101  {output_statement}
102  }}
103  {return_statement}
104 }}
105 )TREELITETEMPLATE";
106 
// fmt::format template for the tail of the generated predict_multiclass() function.
// Placeholders: {num_output_group}, {global_bias}.
// The generated code adds the global bias to each per-group sum, stores it in result[],
// then returns pred_transform(result) unless pred_margin is set, in which case it
// returns the number of output groups.
107 const char* return_multiclass_template =
108 R"TREELITETEMPLATE(
109  for (int i = 0; i < {num_output_group}; ++i) {{
110  result[i] = sum[i] + (float)({global_bias});
111  }}
112  if (!pred_margin) {{
113  return pred_transform(result);
114  }} else {{
115  return {num_output_group};
116  }}
117 )TREELITETEMPLATE"; // only for multiclass classification
118 
// fmt::format template for the tail of the generated single-output predict() function.
// Placeholder: {global_bias}.
// The generated code adds the global bias to the accumulated sum and returns either
// pred_transform(sum) or, when pred_margin is set, the raw sum.
119 const char* return_template =
120 R"TREELITETEMPLATE(
121  sum += (float)({global_bias});
122  if (!pred_margin) {{
123  return pred_transform(sum);
124  }} else {{
125  return sum;
126  }}
127 )TREELITETEMPLATE";
128 
// fmt::format template for the generated "arrays.c", which holds the definition of the
// nodes[] table. Placeholder: {nodes} (the formatted array definition returned by
// FormatNodesArray()). Used only when the arrays are not dumped as an ELF object.
129 const char* arrays_template = R"TREELITETEMPLATE(
130 #include "header.h"
131 
132 {nodes}
133 )TREELITETEMPLATE";
134 
135 // Returns formatted nodes[] and nodes_row_ptr[] arrays
136 // nodes[]: stores nodes from all decision trees
137 // nodes_row_ptr[]: marks boundaries between decision trees. The nodes belonging to Tree [i] are
138 // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]]
// Returns a pair of C source snippets: (definition of nodes[], definition of
// nodes_row_ptr[]). Aborts via CHECK if a leaf carries a leaf vector or if a split is
// not a plain numerical split.
139 inline std::pair<std::string, std::string> FormatNodesArray(const treelite::Model& model) {
// NOTE(review): the declarations of `nodes` and `nodes_row_ptr` (listing lines 140-141)
// are not visible here; both are used below as formatters that accept items via
// operator<< and yield the joined text via str(). Confirm against the full source.
// Running total of nodes emitted so far; becomes the next entry of nodes_row_ptr[].
142  int node_count = 0;
// The first tree always starts at offset 0.
143  nodes_row_ptr << "0";
144  for (const auto& tree : model.trees) {
145  for (int nid = 0; nid < tree.num_nodes; ++nid) {
146  if (tree.IsLeaf(nid)) {
147  CHECK(!tree.HasLeafVector(nid))
148  << "multi-class random forest classifier is not supported in FailSafeCompiler";
// Leaf entry: sindex 0, info = leaf value, both child indices -1 (the generated code
// recognizes a leaf by cleft == -1).
149  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
150  "sindex"_a = 0,
151  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
152  "cleft"_a = -1,
153  "cright"_a = -1);
154  } else {
155  CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
156  && tree.LeftCategories(nid).empty())
157  << "categorical splits are not supported in FailSafeCompiler";
// Internal entry: sindex packs the split feature index with the default-left flag in
// bit 31; info = threshold; children are node indices within this tree.
158  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
159  "sindex"_a
160  = (tree.SplitIndex(nid) |(static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31U)),
161  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
162  "cleft"_a = tree.LeftChild(nid),
163  "cright"_a = tree.RightChild(nid));
164  }
165  }
// Record where the next tree begins in the flattened nodes[] array.
166  node_count += tree.num_nodes;
167  nodes_row_ptr << std::to_string(node_count);
168  }
169  return std::make_pair(fmt::format("const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170  fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
171  nodes_row_ptr.str()));
172 }
173 
174 // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary
// Returns a pair: (byte buffer holding the nodes, C source snippet defining
// nodes_row_ptr[]). Aborts via CHECK under the same conditions as FormatNodesArray().
// NOTE(review): listing lines 177, 179 and 203 are not visible here. The declaration of
// `nodes_row_ptr` and, presumably, the calls that reserve the ELF header and finalize
// the ELF object live on those lines; confirm against the full source.
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(const treelite::Model& model) {
176  std::vector<char> nodes_elf;
178 
// Scratch node image, overwritten for every node before being appended to nodes_elf.
180  NodeStructValue val;
// Running total of nodes emitted so far; becomes the next entry of nodes_row_ptr[].
181  int node_count = 0;
// The first tree always starts at offset 0.
182  nodes_row_ptr << "0";
183  for (const auto& tree : model.trees) {
184  for (int nid = 0; nid < tree.num_nodes; ++nid) {
185  if (tree.IsLeaf(nid)) {
186  CHECK(!tree.HasLeafVector(nid))
187  << "multi-class random forest classifier is not supported in FailSafeCompiler";
// Leaf: sindex 0, leaf value narrowed to float, both child indices -1.
188  val = {0, static_cast<float>(tree.LeafValue(nid)), -1, -1};
189  } else {
190  CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
191  && tree.LeftCategories(nid).empty())
192  << "categorical splits are not supported in FailSafeCompiler";
// Internal node: split index with the default-left flag in bit 31, threshold narrowed
// to float, child indices within this tree.
193  val = {(tree.SplitIndex(nid) | (static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31)),
194  static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
195  }
// Append the raw bytes of the node image to the end of the buffer.
196  const size_t beg = nodes_elf.size();
197  nodes_elf.resize(beg + sizeof(NodeStructValue));
198  std::memcpy(&nodes_elf[beg], &val, sizeof(NodeStructValue));
199  }
// Record where the next tree begins in the flattened node sequence.
200  node_count += tree.num_nodes;
201  nodes_row_ptr << std::to_string(node_count);
202  }
204 
205  return std::make_pair(nodes_elf, fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
206  nodes_row_ptr.str()));
207 }
208 
209 // Get the comparison op used in the tree ensemble model
210 // If splits have more than one op, throw an error
211 inline std::string GetCommonOp(const treelite::Model& model) {
212  std::set<treelite::Operator> ops;
213  for (const auto& tree : model.trees) {
214  for (int nid = 0; nid < tree.num_nodes; ++nid) {
215  if (!tree.IsLeaf(nid)) {
216  ops.insert(tree.ComparisonOp(nid));
217  }
218  }
219  }
220  // sanity check: all numerical splits must have identical comparison operators
221  CHECK_EQ(ops.size(), 1)
222  << "FailSafeCompiler only supports models where all splits use identical comparison operator.";
223  return treelite::OpName(*ops.begin());
224 }
225 
226 
// Return true iff `str` terminates with `suffix`; an empty suffix always matches
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;  // a suffix longer than the string can never match
  }
  const auto tail_start = str.size() - suffix.size();
  return str.compare(tail_start, suffix.size(), suffix) == 0;
}
232 
233 } // anonymous namespace
234 
235 namespace treelite {
236 namespace compiler {
237 
238 DMLC_REGISTRY_FILE_TAG(failsafe);
239 
// Compiler back-end that emits the tree ensemble as data tables (nodes[] and
// nodes_row_ptr[]) plus one generic traversal loop, using the templates defined at the
// top of this file. Compile() produces main.c, header.h, recipe.json and either
// arrays.c or arrays.o.
240 class FailSafeCompiler : public Compiler {
241  public:
  // Stores the compiler parameters. Logs a banner when param.verbose > 0 and logs a
  // notice for every parameter that is set but has no effect on this compiler
  // (annotate_in, quantize, parallel_comp, code_folding_req).
242  explicit FailSafeCompiler(const CompilerParam& param)
243  : param(param) {
244  if (param.verbose > 0) {
245  LOG(INFO) << "Using FailSafeCompiler";
246  }
247  if (param.annotate_in != "NULL") {
248  LOG(INFO) << "Warning: 'annotate_in' parameter is not applicable for "
249  "FailSafeCompiler";
250  }
251  if (param.quantize > 0) {
252  LOG(INFO) << "Warning: 'quantize' parameter is not applicable for "
253  "FailSafeCompiler";
254  }
255  if (param.parallel_comp > 0) {
256  LOG(INFO) << "Warning: 'parallel_comp' parameter is not applicable for "
257  "FailSafeCompiler";
258  }
259  if (std::isfinite(param.code_folding_req)) {
260  LOG(INFO) << "Warning: 'code_folding_req' parameter is not applicable "
261  "for FailSafeCompiler";
262  }
263  }
264 
  // Translates `model` into a CompiledModel with backend "native".
  // Aborts via CHECK when model.random_forest_flag is set; the helpers called below add
  // further CHECKs (leaf vectors, categorical splits, mixed comparison operators).
265  CompiledModel Compile(const Model& model) override {
266  CompiledModel cm;
267  cm.backend = "native";
268 
269  num_feature_ = model.num_feature;
270  num_output_group_ = model.num_output_group;
271  CHECK(!model.random_forest_flag)
272  << "Only gradient boosted trees supported in FailSafeCompiler";
273  pred_tranform_func_ = PredTransformFunction("native", model);
  // Start from an empty file table on every call.
274  files_.clear();
275 
  // More than one output group selects the multi-class prototype, which writes its
  // outputs into a caller-provided `result` buffer.
276  const char* predict_function_signature
277  = (num_output_group_ > 1) ?
278  "size_t predict_multiclass(union Entry* data, int pred_margin, "
279  "float* result)"
280  : "float predict(union Entry* data, int pred_margin)";
281 
282  std::ostringstream main_program;
  // Accumulator: one float per output group for multi-class, a single float otherwise.
283  std::string accumulator_definition
284  = (num_output_group_ > 1
285  ? fmt::format("float sum[{num_output_group}] = {{0.0f}}",
286  "num_output_group"_a = num_output_group_)
287  : std::string("float sum = 0.0f"));
288 
  // Multi-class: the leaf value of tree `tree_id` is added to output group
  // tree_id % num_output_group.
289  std::string output_statement
290  = (num_output_group_ > 1
291  ? fmt::format("sum[tree_id % {num_output_group}] += tree[nid].info.leaf_value;",
292  "num_output_group"_a = num_output_group_)
293  : std::string("sum += tree[nid].info.leaf_value;"));
294 
  // Tail of the prediction function: adds the global bias and applies pred_transform
  // unless the caller asked for the margin.
295  std::string return_statement
296  = (num_output_group_ > 1
297  ? fmt::format(return_multiclass_template,
298  "num_output_group"_a = num_output_group_,
299  "global_bias"_a
300  = compiler::common_util::ToStringHighPrecision(model.param.global_bias))
301  : fmt::format(return_template,
302  "global_bias"_a
303  = compiler::common_util::ToStringHighPrecision(model.param.global_bias)));
304 
  // Build the node table either as a binary buffer (dump_array_as_elf > 0) or as C
  // source text; nodes_row_ptr[] is C source text in both cases.
305  std::string nodes, nodes_row_ptr;
306  std::vector<char> nodes_elf;
307  if (param.dump_array_as_elf > 0) {
308  if (param.verbose > 0) {
309  LOG(INFO) << "Dumping arrays as an ELF relocatable object...";
310  }
311  std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
312  } else {
313  std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
314  }
315 
  // main.c: nodes_row_ptr[] definition, model metadata getters, pred_transform and the
  // prediction loop.
316  main_program << fmt::format(main_template,
317  "nodes_row_ptr"_a = nodes_row_ptr,
318  "pred_transform_function"_a = pred_tranform_func_,
319  "predict_function_signature"_a = predict_function_signature,
320  "num_output_group"_a = num_output_group_,
321  "num_feature"_a = num_feature_,
322  "num_tree"_a = model.trees.size(),
323  "compare_op"_a = GetCommonOp(model),
324  "accumulator_definition"_a = accumulator_definition,
325  "output_statement"_a = output_statement,
326  "return_statement"_a = return_statement);
327 
328  files_["main.c"] = CompiledModel::FileEntry(main_program.str());
329 
  // nodes[] goes into arrays.o (binary) or arrays.c (source), matching the choice made
  // above.
330  if (param.dump_array_as_elf > 0) {
331  files_["arrays.o"] = CompiledModel::FileEntry(std::move(nodes_elf));
332  } else {
333  files_["arrays.c"] = CompiledModel::FileEntry(fmt::format(arrays_template,
334  "nodes"_a = nodes));
335  }
336 
337  files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template,
338  "dllexport"_a = DLLEXPORT_KEYWORD,
339  "predict_function_signature"_a = predict_function_signature));
340 
341  {
342  /* write recipe.json */
  // "sources" lists every generated ".c" file by name without the extension, together
  // with its number of newline characters as "length"; ".o" files are listed under
  // "extra", which is written only when non-empty.
343  std::vector<std::unordered_map<std::string, std::string>> source_list;
344  std::vector<std::string> extra_file_list;
345  for (const auto& kv : files_) {
346  if (EndsWith(kv.first, ".c")) {
347  const size_t line_count
348  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
349  source_list.push_back({ {"name",
350  kv.first.substr(0, kv.first.length() - 2)},
351  {"length", std::to_string(line_count)} });
352  } else if (EndsWith(kv.first, ".o")) {
353  extra_file_list.push_back(kv.first);
354  }
355  }
356  std::ostringstream oss;
357  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&oss));
358  writer->BeginObject();
359  writer->WriteObjectKeyValue("target", param.native_lib_name);
360  writer->WriteObjectKeyValue("sources", source_list);
361  if (!extra_file_list.empty()) {
362  writer->WriteObjectKeyValue("extra", extra_file_list);
363  }
364  writer->EndObject();
365  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
366  }
  // Hand the file table over to the result; files_ is cleared again at the start of
  // the next Compile() call.
367  cm.files = std::move(files_);
368  return cm;
369  }
370 
371  private:
  // Compiler parameters captured at construction.
372  CompilerParam param;
  // Copied from the model at the start of Compile().
373  int num_feature_;
374  int num_output_group_;
  // Source text of the pred_transform function inserted into main.c.
375  std::string pred_tranform_func_;
  // Generated files keyed by file name; moved into the CompiledModel by Compile().
376  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
377 };
378 
// Tail of the compiler registration statement; its opening line (listing line 379) is
// not visible here. NOTE(review): presumably a TREELITE_REGISTER_COMPILER(...)
// invocation; confirm against the full source.
// The registered factory allocates a FailSafeCompiler from the given CompilerParam with
// `new` and returns it as a Compiler*.
380 .describe("Simple compiler to express trees as a tight for-loop")
381 .set_body([](const CompilerParam& param) -> Compiler* {
382  return new FailSafeCompiler(param);
383  });
384 } // namespace compiler
385 } // namespace treelite
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:419
Parameters for tree compiler.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
Definition: failsafe.cc:265
thin wrapper for tree ensemble model
Definition: tree.h:409
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:40
std::vector< Tree > trees
member trees
Definition: tree.h:411
parameters for tree compiler
ModelParam param
extra parameters
Definition: tree.h:424
model structure for tree ensemble
void AllocateELFHeader(std::vector< char > *elf_buffer)
Pre-allocate space in a buffer to fit an ELF header.
float global_bias
global bias of the model
Definition: tree.h:383
interface of compiler
Definition: compiler.h:54
Interface of compiler that compiles a tree ensemble model.
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Definition: compiler.h:27
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:92
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:422
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
Generate a relocatable object file containing a constant, read-only array.
void FormatArrayAsELF(std::vector< char > *elf_buffer)
Format a relocatable ELF object file containing a constant, read-only array.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416