Treelite
failsafe.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <treelite/compiler.h>
#include <fmt/format.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./pred_transform.h"
#include "./common/format_util.h"
#include "./elf/elf_formatter.h"
#include "./native/main_template.h"
23 
// Export specifier spliced in front of the declarations emitted into the generated header.h.
// The choice is made by the preprocessor when *this* file is built, i.e. it reflects the
// platform treelite itself is compiled on.
#if !defined(_MSC_VER) && !defined(_WIN32)
#define DLLEXPORT_KEYWORD ""
#else
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#endif
29 
30 using namespace fmt::literals;
31 
32 namespace {
33 
// Host-side mirror of `struct Node` declared in header_template. FormatNodesArrayELF() copies
// instances of this struct byte-for-byte into the ELF object, so the member types and their
// order must stay in sync with the generated C struct (`info` stands in for `union NodeInfo`,
// whose members are both float).
struct NodeStructValue {
  unsigned int sindex;  // feature index in the low 31 bits; bit 31 holds the default-left flag
  float info;           // split threshold for internal nodes, leaf value for leaves
  int cleft;            // left child id within the tree, or -1 for a leaf
  int cright;           // right child id within the tree, or -1 for a leaf
};
40 
// Template for the generated header.h (formatted with fmt::format, hence the doubled braces).
// Declares:
//   - `union Entry`: one input feature; `missing == -1` marks a missing value (see main_template)
//   - `struct Node`: node layout shared with NodeStructValue
//   - the nodes[] array (defined in arrays.c or arrays.o) and nodes_row_ptr[] (defined in main.c)
//   - the query functions and the prediction function.
// Placeholders: {query_functions_prototype}, {dllexport}, {predict_function_signature}.
const char* const header_template = R"TREELITETEMPLATE(
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <stdint.h>

union Entry {{
 int missing;
 float fvalue;
}};

union NodeInfo {{
 float leaf_value;
 float threshold;
}};

struct Node {{
 unsigned int sindex;
 union NodeInfo info;
 int cleft;
 int cright;
}};

extern const struct Node nodes[];
extern const int nodes_row_ptr[];

{query_functions_prototype}
{dllexport}{predict_function_signature};
)TREELITETEMPLATE";
70 
// Template for the generated main.c (formatted with fmt::format, hence the doubled braces).
// The prediction function walks every tree; tree `tree_id` starts at
// nodes[nodes_row_ptr[tree_id]], and traversal stops at a leaf (cleft == -1).
// At each internal node:
//   - the low 31 bits of sindex give the feature id and bit 31 the default-left flag;
//   - a feature with `missing == -1` follows the default branch;
//   - otherwise the node goes left when `fvalue {compare_op} threshold` holds.
// The reached leaf value is folded in by {output_statement}, and {return_statement} produces
// the result.
const char* const main_template = R"TREELITETEMPLATE(
#include "header.h"

{nodes_row_ptr}

{query_functions_definition}

{pred_transform_function}

{predict_function_signature} {{
 {accumulator_definition};

 for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{
 int nid = 0;
 const struct Node* tree = &nodes[nodes_row_ptr[tree_id]];
 while (tree[nid].cleft != -1) {{
 const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U);
 const unsigned char default_left = (tree[nid].sindex >> 31) != 0;
 if (data[feature_id].missing == -1) {{
 nid = (default_left ? tree[nid].cleft : tree[nid].cright);
 }} else {{
 nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold
 ? tree[nid].cleft : tree[nid].cright);
 }}
 }}
 {output_statement}
 }}
 {return_statement}
}}
)TREELITETEMPLATE";
101 
// Tail of the multi-class prediction function (predict_multiclass).
// It writes each per-class sum plus the global bias into `result`.
// With pred_margin == 0 it returns pred_transform(result); otherwise it returns {num_class}.
const char* const return_multiclass_template =
R"TREELITETEMPLATE(
 for (int i = 0; i < {num_class}; ++i) {{
 result[i] = sum[i] + (float)({global_bias});
 }}
 if (!pred_margin) {{
 return pred_transform(result);
 }} else {{
 return {num_class};
 }}
)TREELITETEMPLATE"; // only for multiclass classification
113 
// Tail of the single-output prediction function (predict). It adds the global bias to the
// scalar sum and returns pred_transform(sum), or the raw sum when pred_margin is set.
const char* const return_template =
R"TREELITETEMPLATE(
 sum += (float)({global_bias});
 if (!pred_margin) {{
 return pred_transform(sum);
 }} else {{
 return sum;
 }}
)TREELITETEMPLATE";
123 
// Template for arrays.c, used when nodes[] is emitted as C source rather than as an ELF object.
// {nodes} receives the full `const struct Node nodes[] = {...};` definition.
const char* const arrays_template = R"TREELITETEMPLATE(
#include "header.h"

{nodes}
)TREELITETEMPLATE";
129 
130 // Returns formatted nodes[] and nodes_row_ptr[] arrays
131 // nodes[]: stores nodes from all decision trees
132 // nodes_row_ptr[]: marks bounaries between decision trees. The nodes belonging to Tree [i] are
133 // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]]
134 inline std::pair<std::string, std::string> FormatNodesArray(
135  const treelite::ModelImpl<float, float>& model) {
138  int node_count = 0;
139  nodes_row_ptr << "0";
140  for (const auto& tree : model.trees) {
141  for (int nid = 0; nid < tree.num_nodes; ++nid) {
142  if (tree.IsLeaf(nid)) {
143  CHECK(!tree.HasLeafVector(nid))
144  << "multi-class random forest classifier is not supported in FailSafeCompiler";
145  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
146  "sindex"_a = 0,
147  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
148  "cleft"_a = -1,
149  "cright"_a = -1);
150  } else {
151  CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
152  && !tree.HasMatchingCategories(nid))
153  << "categorical splits are not supported in FailSafeCompiler";
154  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
155  "sindex"_a
156  = (tree.SplitIndex(nid) |(static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31U)),
157  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
158  "cleft"_a = tree.LeftChild(nid),
159  "cright"_a = tree.RightChild(nid));
160  }
161  }
162  node_count += tree.num_nodes;
163  nodes_row_ptr << std::to_string(node_count);
164  }
165  return std::make_pair(fmt::format("const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
166  fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
167  nodes_row_ptr.str()));
168 }
169 
170 // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary
171 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
172  const treelite::ModelImpl<float, float>& model) {
173  std::vector<char> nodes_elf;
175 
177  NodeStructValue val;
178  int node_count = 0;
179  nodes_row_ptr << "0";
180  for (const auto& tree : model.trees) {
181  for (int nid = 0; nid < tree.num_nodes; ++nid) {
182  if (tree.IsLeaf(nid)) {
183  CHECK(!tree.HasLeafVector(nid))
184  << "multi-class random forest classifier is not supported in FailSafeCompiler";
185  val = {0, static_cast<float>(tree.LeafValue(nid)), -1, -1};
186  } else {
187  CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
188  && !tree.HasMatchingCategories(nid))
189  << "categorical splits are not supported in FailSafeCompiler";
190  val = {(tree.SplitIndex(nid) | (static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31)),
191  static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
192  }
193  const size_t beg = nodes_elf.size();
194  nodes_elf.resize(beg + sizeof(NodeStructValue));
195  std::memcpy(&nodes_elf[beg], &val, sizeof(NodeStructValue));
196  }
197  node_count += tree.num_nodes;
198  nodes_row_ptr << std::to_string(node_count);
199  }
201 
202  return std::make_pair(nodes_elf, fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
203  nodes_row_ptr.str()));
204 }
205 
206 // Get the comparison op used in the tree ensemble model
207 // If splits have more than one op, throw an error
208 inline std::string GetCommonOp(const treelite::ModelImpl<float, float>& model) {
209  std::set<treelite::Operator> ops;
210  for (const auto& tree : model.trees) {
211  for (int nid = 0; nid < tree.num_nodes; ++nid) {
212  if (!tree.IsLeaf(nid)) {
213  ops.insert(tree.ComparisonOp(nid));
214  }
215  }
216  }
217  // sanity check: all numerical splits must have identical comparison operators
218  CHECK_EQ(ops.size(), 1)
219  << "FailSafeCompiler only supports models where all splits use identical comparison operator.";
220  return treelite::OpName(*ops.begin());
221 }
222 
223 
// Returns true when `suffix` is a trailing substring of `str`.
// An empty suffix matches every string; a suffix longer than `str` never matches.
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  if (str.size() < suffix.size()) {
    return false;
  }
  const std::string::size_type offset = str.size() - suffix.size();
  return str.compare(offset, std::string::npos, suffix) == 0;
}
229 
230 } // anonymous namespace
231 
232 namespace treelite {
233 namespace compiler {
234 
// Registry file tag for this translation unit, named `failsafe`.
// NOTE(review): presumably paired with a DMLC_REGISTRY_LINK_TAG(failsafe) elsewhere so that the
// compiler registration at the end of this file is linked in — confirm.
DMLC_REGISTRY_FILE_TAG(failsafe);
236 
237 class FailSafeCompiler : public Compiler {
238  public:
239  explicit FailSafeCompiler(const CompilerParam& param)
240  : param(param) {
241  if (param.verbose > 0) {
242  LOG(INFO) << "Using FailSafeCompiler";
243  }
244  if (param.annotate_in != "NULL") {
245  LOG(INFO) << "Warning: 'annotate_in' parameter is not applicable for "
246  "FailSafeCompiler";
247  }
248  if (param.quantize > 0) {
249  LOG(INFO) << "Warning: 'quantize' parameter is not applicable for "
250  "FailSafeCompiler";
251  }
252  if (param.parallel_comp > 0) {
253  LOG(INFO) << "Warning: 'parallel_comp' parameter is not applicable for "
254  "FailSafeCompiler";
255  }
256  if (std::isfinite(param.code_folding_req)) {
257  LOG(INFO) << "Warning: 'code_folding_req' parameter is not applicable "
258  "for FailSafeCompiler";
259  }
260  }
261 
262  CompiledModel Compile(const Model& model_ptr) override {
263  CHECK(model_ptr.GetThresholdType() == TypeInfo::kFloat32
264  && model_ptr.GetLeafOutputType() == TypeInfo::kFloat32)
265  << "Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs";
266  const auto& model = dynamic_cast<const ModelImpl<float, float>&>(model_ptr);
267 
268  CompiledModel cm;
269  cm.backend = "native";
270 
271  num_feature_ = model.num_feature;
272  num_class_ = model.task_param.num_class;
273  CHECK(!model.average_tree_output)
274  << "Averaging tree output is not supported in FailSafeCompiler";
275  CHECK(model.task_type == TaskType::kBinaryClfRegr
276  || model.task_type == TaskType::kMultiClfGrovePerClass)
277  << "Model task type unsupported by FailSafeCompiler";
278  CHECK_EQ(model.task_param.leaf_vector_size, 1)
279  << "Model with leaf vectors is not support by FailSafeCompiler";
280  pred_tranform_func_ = PredTransformFunction("native", model_ptr);
281  files_.clear();
282 
283  const char* predict_function_signature
284  = (num_class_ > 1) ?
285  "size_t predict_multiclass(union Entry* data, int pred_margin, float* result)"
286  : "float predict(union Entry* data, int pred_margin)";
287 
288  std::ostringstream main_program;
289  std::string accumulator_definition
290  = (num_class_ > 1
291  ? fmt::format("float sum[{num_class}] = {{0.0f}}",
292  "num_class"_a = num_class_)
293  : std::string("float sum = 0.0f"));
294 
295  std::string output_statement
296  = (num_class_ > 1
297  ? fmt::format("sum[tree_id % {num_class}] += tree[nid].info.leaf_value;",
298  "num_class"_a = num_class_)
299  : std::string("sum += tree[nid].info.leaf_value;"));
300 
301  std::string return_statement
302  = (num_class_ > 1
303  ? fmt::format(return_multiclass_template,
304  "num_class"_a = num_class_,
305  "global_bias"_a
306  = compiler::common_util::ToStringHighPrecision(model.param.global_bias))
307  : fmt::format(return_template,
308  "global_bias"_a
309  = compiler::common_util::ToStringHighPrecision(model.param.global_bias)));
310 
311  std::string nodes, nodes_row_ptr;
312  std::vector<char> nodes_elf;
313  if (param.dump_array_as_elf > 0) {
314  if (param.verbose > 0) {
315  LOG(INFO) << "Dumping arrays as an ELF relocatable object...";
316  }
317  std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
318  } else {
319  std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
320  }
321 
322  const ModelParam model_param = model.param;
323  const std::string query_functions_definition
324  = fmt::format(native::query_functions_definition_template,
325  "num_class"_a = num_class_,
326  "num_feature"_a = num_feature_,
327  "pred_transform"_a = model_param.pred_transform,
328  "sigmoid_alpha"_a = model_param.sigmoid_alpha,
329  "global_bias"_a = model_param.global_bias,
330  "threshold_type_str"_a = TypeInfoToString(TypeToInfo<float>()),
331  "leaf_output_type_str"_a = TypeInfoToString(TypeToInfo<float>()));
332 
333  main_program << fmt::format(main_template,
334  "nodes_row_ptr"_a = nodes_row_ptr,
335  "query_functions_definition"_a = query_functions_definition,
336  "pred_transform_function"_a = pred_tranform_func_,
337  "predict_function_signature"_a = predict_function_signature,
338  "num_tree"_a = model.trees.size(),
339  "compare_op"_a = GetCommonOp(model),
340  "accumulator_definition"_a = accumulator_definition,
341  "output_statement"_a = output_statement,
342  "return_statement"_a = return_statement);
343 
344  files_["main.c"] = CompiledModel::FileEntry(main_program.str());
345 
346  if (param.dump_array_as_elf > 0) {
347  files_["arrays.o"] = CompiledModel::FileEntry(std::move(nodes_elf));
348  } else {
349  files_["arrays.c"] = CompiledModel::FileEntry(fmt::format(arrays_template,
350  "nodes"_a = nodes));
351  }
352 
353  const std::string query_functions_prototype
354  = fmt::format(native::query_functions_prototype_template,
355  "dllexport"_a = DLLEXPORT_KEYWORD);
356  files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template,
357  "dllexport"_a = DLLEXPORT_KEYWORD,
358  "query_functions_prototype"_a = query_functions_prototype,
359  "predict_function_signature"_a = predict_function_signature));
360 
361  {
362  /* write recipe.json */
363  std::vector<std::unordered_map<std::string, std::string>> source_list;
364  std::vector<std::string> extra_file_list;
365  for (const auto& kv : files_) {
366  if (EndsWith(kv.first, ".c")) {
367  const size_t line_count
368  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
369  source_list.push_back({ {"name",
370  kv.first.substr(0, kv.first.length() - 2)},
371  {"length", std::to_string(line_count)} });
372  } else if (EndsWith(kv.first, ".o")) {
373  extra_file_list.push_back(kv.first);
374  }
375  }
376  std::ostringstream oss;
377  std::unique_ptr<dmlc::JSONWriter> writer(new dmlc::JSONWriter(&oss));
378  writer->BeginObject();
379  writer->WriteObjectKeyValue("target", param.native_lib_name);
380  writer->WriteObjectKeyValue("sources", source_list);
381  if (!extra_file_list.empty()) {
382  writer->WriteObjectKeyValue("extra", extra_file_list);
383  }
384  writer->EndObject();
385  files_["recipe.json"] = CompiledModel::FileEntry(oss.str());
386  }
387  cm.files = std::move(files_);
388  return cm;
389  }
390 
391  private:
392  CompilerParam param;
393  int num_feature_;
394  unsigned int num_class_;
395  std::string pred_tranform_func_;
396  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
397 };
398 
400 .describe("Simple compiler to express trees as a tight for-loop")
401 .set_body([](const CompilerParam& param) -> Compiler* {
402  return new FailSafeCompiler(param);
403  });
404 } // namespace compiler
405 } // namespace treelite
Parameters for tree compiler.
CompiledModel Compile(const Model &model_ptr) override
convert tree ensemble model
Definition: failsafe.cc:262
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:43
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:599
model structure for tree ensemble
void AllocateELFHeader(std::vector< char > *elf_buffer)
Pre-allocate space in a buffer to fit an ELF header.
template for header
float global_bias
global bias of the model
Definition: tree.h:606
interface of compiler
Definition: compiler.h:54
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:705
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Definition: compiler.h:27
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
Definition: compiler.h:92
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
thin wrapper for tree ensemble model
Definition: tree.h:632
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
Generate a relocatable object file containing a constant, read-only array.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:673
void FormatArrayAsELF(std::vector< char > *elf_buffer)
Format a relocatable ELF object file containing a constant, read-only array.
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:591