Treelite
failsafe.cc
Go to the documentation of this file.
1 
9 #include <treelite/tree.h>
10 #include <treelite/compiler.h>
12 #include <treelite/logging.h>
13 #include <fmt/format.h>
14 #include <rapidjson/stringbuffer.h>
15 #include <rapidjson/writer.h>
16 #include <unordered_map>
17 #include <set>
18 #include <tuple>
19 #include <utility>
20 #include <cmath>
21 #include "./failsafe.h"
22 #include "./pred_transform.h"
23 #include "./common/format_util.h"
24 #include "./elf/elf_formatter.h"
25 #include "./native/main_template.h"
27 
// Prefix placed in front of the exported declarations of the generated header.h.
// On Windows toolchains it marks the generated functions for export from the DLL;
// elsewhere nothing is needed, so it expands to an empty string.
#if !defined(_MSC_VER) && !defined(_WIN32)
#define DLLEXPORT_KEYWORD ""
#else
#define DLLEXPORT_KEYWORD "__declspec(dllexport) "
#endif
33 
34 using namespace fmt::literals;
35 
36 namespace {
37 
// Host-side mirror of the generated C `struct Node` declared in header_template.
// FormatNodesArrayELF() memcpy's one of these per node into the nodes[] byte buffer,
// so the field order and field types must stay in sync with `struct Node`
// (`info` stands in for `union NodeInfo`, whose two members are both float).
38 struct NodeStructValue {
39  unsigned int sindex;  // bits 0-30: split feature index; bit 31: default-left flag (0 for leaves)
40  float info;  // split threshold for test nodes, leaf value for leaves
41  int cleft;  // left child id, local to the owning tree's slice of nodes[]; -1 for leaves
42  int cright;  // right child id (same convention as cleft); -1 for leaves
43 };
44 
// fmt format string for the generated header.h. Doubled braces ({{ and }}) are literal
// braces; {query_functions_prototype}, {dllexport} and {predict_function_signature} are
// substituted in FailSafeCompilerImpl::Compile().
// `union Entry` is one input feature: `missing == -1` marks an absent value (see the
// check in main_template). `struct Node` must keep the same layout as NodeStructValue,
// because the ELF code path writes NodeStructValue bytes directly into nodes[].
45 const char* const header_template = R"TREELITETEMPLATE(
46 #include <stdlib.h>
47 #include <string.h>
48 #include <math.h>
49 #include <stdint.h>
50 
51 union Entry {{
52  int missing;
53  float fvalue;
54 }};
55 
56 union NodeInfo {{
57  float leaf_value;
58  float threshold;
59 }};
60 
61 struct Node {{
62  unsigned int sindex;
63  union NodeInfo info;
64  int cleft;
65  int cright;
66 }};
67 
68 extern const struct Node nodes[];
69 extern const int nodes_row_ptr[];
70 
71 {query_functions_prototype}
72 {dllexport}{predict_function_signature};
73 )TREELITETEMPLATE";
74 
// fmt format string for main.c. The generated predict function walks every tree from its
// root (nid 0) until it reaches a leaf (cleft == -1) and accumulates that leaf's value
// via {output_statement}.
// Each node's sindex holds the feature index in bits 0-30 and the default-left flag in
// bit 31. A feature whose `missing == -1` follows the default direction; otherwise the
// split is decided by `fvalue {compare_op} threshold` (true -> left child).
// Child ids are local to a tree: `tree` points at the tree's first entry in nodes[],
// located through nodes_row_ptr.
75 const char* const main_template = R"TREELITETEMPLATE(
76 #include "header.h"
77 
78 {nodes_row_ptr}
79 
80 {query_functions_definition}
81 
82 {pred_transform_function}
83 
84 {predict_function_signature} {{
85  {accumulator_definition};
86 
87  for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{
88  int nid = 0;
89  const struct Node* tree = &nodes[nodes_row_ptr[tree_id]];
90  while (tree[nid].cleft != -1) {{
91  const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U);
92  const unsigned char default_left = (tree[nid].sindex >> 31) != 0;
93  if (data[feature_id].missing == -1) {{
94  nid = (default_left ? tree[nid].cleft : tree[nid].cright);
95  }} else {{
96  nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold
97  ? tree[nid].cleft : tree[nid].cright);
98  }}
99  }}
100  {output_statement}
101  }}
102  {return_statement}
103 }}
104 )TREELITETEMPLATE";
105 
// fmt format string for the tail of predict_multiclass(). It adds {global_bias} to every
// per-class sum and stores the margins in `result`. It then returns
// pred_transform(result), or {num_class} when pred_margin is set.
106 const char* const return_multiclass_template =
107 R"TREELITETEMPLATE(
108  for (int i = 0; i < {num_class}; ++i) {{
109  result[i] = sum[i] + (float)({global_bias});
110  }}
111  if (!pred_margin) {{
112  return pred_transform(result);
113  }} else {{
114  return {num_class};
115  }}
116 )TREELITETEMPLATE"; // only for multiclass classification
117 
// fmt format string for the tail of the scalar predict() function. It adds {global_bias}
// to `sum`, then returns pred_transform(sum), or the raw margin when pred_margin is set.
118 const char* const return_template =
119 R"TREELITETEMPLATE(
120  sum += (float)({global_bias});
121  if (!pred_margin) {{
122  return pred_transform(sum);
123  }} else {{
124  return sum;
125  }}
126 )TREELITETEMPLATE";
127 
// fmt format string for arrays.c. Compile() emits it only when nodes[] is not dumped as
// an ELF object; {nodes} receives the nodes[] definition from FormatNodesArray().
128 const char* const arrays_template = R"TREELITETEMPLATE(
129 #include "header.h"
130 
131 {nodes}
132 )TREELITETEMPLATE";
133 
134 // Returns formatted nodes[] and nodes_row_ptr[] arrays (as C source definitions)
135 // nodes[]: stores nodes from all decision trees
136 // nodes_row_ptr[]: marks boundaries between decision trees. The nodes belonging to Tree [i] are
137 // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]]
// Child ids (cleft/cright) are node ids local to each tree. Leaves are written as
// { 0x0, leaf_value, -1, -1 }. Split nodes pack the feature index into the low bits of
// sindex and the default-left flag into bit 31. Thresholds and leaf values are rendered
// with ToStringHighPrecision(). Fails (TREELITE_CHECK) on leaf vectors or on categorical
// splits.
138 inline std::pair<std::string, std::string> FormatNodesArray(
139  const treelite::ModelImpl<float, float>& model) {
  // NOTE(review): source lines 140-141 are missing from this listing; they presumably
  // declare `nodes` and `nodes_row_ptr`. Judging from their use below (`<<` appends an
  // element, `.str()` yields the joined text), they are array formatters that insert the
  // separators between elements, likely from common/format_util.h. Confirm.
142  int node_count = 0;
143  nodes_row_ptr << "0";  // the first tree starts at offset 0
144  for (const auto& tree : model.trees) {
145  for (int nid = 0; nid < tree.num_nodes; ++nid) {
146  if (tree.IsLeaf(nid)) {
147  TREELITE_CHECK(!tree.HasLeafVector(nid))
148  << "multi-class random forest classifier is not supported in FailSafeCompiler";
149  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
150  "sindex"_a = 0,
151  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
152  "cleft"_a = -1,
153  "cright"_a = -1);
154  } else {
155  TREELITE_CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
156  && tree.MatchingCategories(nid).empty())
157  << "categorical splits are not supported in FailSafeCompiler";
158  nodes << fmt::format("{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
159  "sindex"_a
160  = (tree.SplitIndex(nid) |(static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31U)),
161  "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
162  "cleft"_a = tree.LeftChild(nid),
163  "cright"_a = tree.RightChild(nid));
164  }
165  }
166  node_count += tree.num_nodes;
167  nodes_row_ptr << std::to_string(node_count);  // end of this tree == start of the next
168  }
169  return std::make_pair(fmt::format("const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170  fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
171  nodes_row_ptr.str()));
172 }
173 
174 // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary
// Returns a pair:
//   - the nodes[] payload as a byte buffer (ELF-wrapped per the function name; see the
//     NOTE below);
//   - the C source definition of nodes_row_ptr[].
// Each node is appended as the raw in-memory bytes of a NodeStructValue, so the payload
// uses the host's byte order and struct layout. Same checks as FormatNodesArray().
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
176  const treelite::ModelImpl<float, float>& model) {
177  std::vector<char> nodes_elf;
  // NOTE(review): source lines 178 and 180 (and 204 further down) are missing from this
  // listing. Line 180 presumably declares `nodes_row_ptr` (same kind of formatter as in
  // FormatNodesArray). Lines 178/204 presumably call AllocateELFHeader(&nodes_elf) and
  // FormatArrayAsELF(&nodes_elf), which this page cross-references. Confirm.
179 
181  NodeStructValue val;
182  int node_count = 0;
183  nodes_row_ptr << "0";  // the first tree starts at offset 0
184  for (const auto& tree : model.trees) {
185  for (int nid = 0; nid < tree.num_nodes; ++nid) {
186  if (tree.IsLeaf(nid)) {
187  TREELITE_CHECK(!tree.HasLeafVector(nid))
188  << "multi-class random forest classifier is not supported in FailSafeCompiler";
189  val = {0, static_cast<float>(tree.LeafValue(nid)), -1, -1};
190  } else {
191  TREELITE_CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
192  && tree.MatchingCategories(nid).empty())
193  << "categorical splits are not supported in FailSafeCompiler";
194  val = {(tree.SplitIndex(nid) | (static_cast<uint32_t>(tree.DefaultLeft(nid)) << 31)),
195  static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
196  }
  // Append the raw bytes of `val` to the end of the buffer.
197  const size_t beg = nodes_elf.size();
198  nodes_elf.resize(beg + sizeof(NodeStructValue));
199  std::memcpy(&nodes_elf[beg], &val, sizeof(NodeStructValue));
200  }
201  node_count += tree.num_nodes;
202  nodes_row_ptr << std::to_string(node_count);  // end of this tree == start of the next
203  }
205 
  // NOTE(review): make_pair copies nodes_elf here; std::move(nodes_elf) would avoid
  // duplicating the whole buffer.
206  return std::make_pair(nodes_elf, fmt::format("const int nodes_row_ptr[] = {{\n{}\n}};",
207  nodes_row_ptr.str()));
208 }
209 
210 // Get the comparison op used in the tree ensemble model
211 // If splits have more than one op, throw an error
212 inline std::string GetCommonOp(const treelite::ModelImpl<float, float>& model) {
213  std::set<treelite::Operator> ops;
214  for (const auto& tree : model.trees) {
215  for (int nid = 0; nid < tree.num_nodes; ++nid) {
216  if (!tree.IsLeaf(nid)) {
217  ops.insert(tree.ComparisonOp(nid));
218  }
219  }
220  }
221  // sanity check: all numerical splits must have identical comparison operators
222  TREELITE_CHECK_EQ(ops.size(), 1)
223  << "FailSafeCompiler only supports models where all splits use identical comparison operator.";
224  return treelite::OpName(*ops.begin());
225 }
226 
227 
// Returns true iff `str` terminates with `suffix`; an empty suffix matches any string.
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;  // a suffix longer than the string can never fit
  }
  const std::string::size_type tail_start = str.size() - suffix.size();
  return str.compare(tail_start, std::string::npos, suffix) == 0;
}
233 
234 } // anonymous namespace
235 
236 namespace treelite {
237 namespace compiler {
238 
240  public:
241  explicit FailSafeCompilerImpl(const CompilerParam& param) : param_(param) {}
242 
243  CompiledModel Compile(const Model& model_ptr) {
244  TREELITE_CHECK(model_ptr.GetThresholdType() == TypeInfo::kFloat32
245  && model_ptr.GetLeafOutputType() == TypeInfo::kFloat32)
246  << "Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs";
247  const auto& model = dynamic_cast<const ModelImpl<float, float>&>(model_ptr);
248 
249  CompiledModel cm;
250  cm.backend = "native";
251 
252  num_feature_ = model.num_feature;
253  num_class_ = model.task_param.num_class;
254  TREELITE_CHECK(!model.average_tree_output)
255  << "Averaging tree output is not supported in FailSafeCompiler";
256  TREELITE_CHECK(model.task_type == TaskType::kBinaryClfRegr
257  || model.task_type == TaskType::kMultiClfGrovePerClass)
258  << "Model task type unsupported by FailSafeCompiler";
259  TREELITE_CHECK_EQ(model.task_param.leaf_vector_size, 1)
260  << "Model with leaf vectors is not support by FailSafeCompiler";
261  pred_tranform_func_ = PredTransformFunction("native", model_ptr);
262  files_.clear();
263 
264  const char* predict_function_signature
265  = (num_class_ > 1) ?
266  "size_t predict_multiclass(union Entry* data, int pred_margin, float* result)"
267  : "float predict(union Entry* data, int pred_margin)";
268 
269  std::ostringstream main_program;
270  std::string accumulator_definition
271  = (num_class_ > 1
272  ? fmt::format("float sum[{num_class}] = {{0.0f}}",
273  "num_class"_a = num_class_)
274  : std::string("float sum = 0.0f"));
275 
276  std::string output_statement
277  = (num_class_ > 1
278  ? fmt::format("sum[tree_id % {num_class}] += tree[nid].info.leaf_value;",
279  "num_class"_a = num_class_)
280  : std::string("sum += tree[nid].info.leaf_value;"));
281 
282  std::string return_statement
283  = (num_class_ > 1
284  ? fmt::format(return_multiclass_template,
285  "num_class"_a = num_class_,
286  "global_bias"_a
287  = compiler::common_util::ToStringHighPrecision(model.param.global_bias))
288  : fmt::format(return_template,
289  "global_bias"_a
290  = compiler::common_util::ToStringHighPrecision(model.param.global_bias)));
291 
292  std::string nodes, nodes_row_ptr;
293  std::vector<char> nodes_elf;
294  if (param_.dump_array_as_elf > 0) {
295  if (param_.verbose > 0) {
296  TREELITE_LOG(INFO) << "Dumping arrays as an ELF relocatable object...";
297  }
298  std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
299  } else {
300  std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
301  }
302 
303  const ModelParam model_param = model.param;
304  const std::string query_functions_definition
305  = fmt::format(native::query_functions_definition_template,
306  "num_class"_a = num_class_,
307  "num_feature"_a = num_feature_,
308  "pred_transform"_a = model_param.pred_transform,
309  "sigmoid_alpha"_a = model_param.sigmoid_alpha,
310  "ratio_c"_a = model_param.ratio_c,
311  "global_bias"_a = model_param.global_bias,
312  "threshold_type_str"_a = TypeInfoToString(TypeToInfo<float>()),
313  "leaf_output_type_str"_a = TypeInfoToString(TypeToInfo<float>()));
314 
315  main_program << fmt::format(main_template,
316  "nodes_row_ptr"_a = nodes_row_ptr,
317  "query_functions_definition"_a = query_functions_definition,
318  "pred_transform_function"_a = pred_tranform_func_,
319  "predict_function_signature"_a = predict_function_signature,
320  "num_tree"_a = model.trees.size(),
321  "compare_op"_a = GetCommonOp(model),
322  "accumulator_definition"_a = accumulator_definition,
323  "output_statement"_a = output_statement,
324  "return_statement"_a = return_statement);
325 
326  files_["main.c"] = CompiledModel::FileEntry(main_program.str());
327 
328  if (param_.dump_array_as_elf > 0) {
329  files_["arrays.o"] = CompiledModel::FileEntry(std::move(nodes_elf));
330  } else {
331  files_["arrays.c"] = CompiledModel::FileEntry(fmt::format(arrays_template,
332  "nodes"_a = nodes));
333  }
334 
335  const std::string query_functions_prototype
336  = fmt::format(native::query_functions_prototype_template,
337  "dllexport"_a = DLLEXPORT_KEYWORD);
338  files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template,
339  "dllexport"_a = DLLEXPORT_KEYWORD,
340  "query_functions_prototype"_a = query_functions_prototype,
341  "predict_function_signature"_a = predict_function_signature));
342 
343  {
344  /* write recipe.json */
345  rapidjson::StringBuffer os;
346  rapidjson::Writer<rapidjson::StringBuffer> writer(os);
347 
348  writer.StartObject();
349  writer.Key("target");
350  writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
351  writer.Key("sources");
352  writer.StartArray();
353  std::vector<std::string> extra_file_list;
354  for (const auto& kv : files_) {
355  if (EndsWith(kv.first, ".c")) {
356  const size_t line_count
357  = std::count(kv.second.content.begin(), kv.second.content.end(), '\n');
358  writer.StartObject();
359  writer.Key("name");
360  std::string name = kv.first.substr(0, kv.first.length() - 2);
361  writer.String(name.data(), name.size());
362  writer.Key("length");
363  writer.Uint64(line_count);
364  writer.EndObject();
365  } else if (EndsWith(kv.first, ".o")) {
366  extra_file_list.push_back(kv.first);
367  }
368  }
369  writer.EndArray();
370  if (!extra_file_list.empty()) {
371  writer.Key("extra");
372  writer.StartArray();
373  for (const auto& extra_file : extra_file_list) {
374  writer.String(extra_file.data(), extra_file.size());
375  }
376  writer.EndArray();
377  }
378  writer.EndObject();
379 
380  files_["recipe.json"] = CompiledModel::FileEntry(os.GetString());
381  }
382  cm.files = std::move(files_);
383  return cm;
384  }
385 
386  CompilerParam QueryParam() const {
387  return param_;
388  }
389 
390  private:
391  CompilerParam param_;
392  int num_feature_;
393  unsigned int num_class_;
394  std::string pred_tranform_func_;
395  std::unordered_map<std::string, CompiledModel::FileEntry> files_;
396 };
397 
398 FailSafeCompiler::FailSafeCompiler(const CompilerParam& param)
399  : pimpl_(std::make_unique<FailSafeCompilerImpl>(param)) {
400  if (param.verbose > 0) {
401  TREELITE_LOG(INFO) << "Using FailSafeCompiler";
402  }
403  if (param.annotate_in != "NULL") {
404  TREELITE_LOG(INFO) << "Warning: 'annotate_in' parameter is not applicable for "
405  "FailSafeCompiler";
406  }
407  if (param.quantize > 0) {
408  TREELITE_LOG(INFO) << "Warning: 'quantize' parameter is not applicable for "
409  "FailSafeCompiler";
410  }
411  if (param.parallel_comp > 0) {
412  TREELITE_LOG(INFO) << "Warning: 'parallel_comp' parameter is not applicable for "
413  "FailSafeCompiler";
414  }
415  if (std::isfinite(param.code_folding_req)) {
416  TREELITE_LOG(INFO) << "Warning: 'code_folding_req' parameter is not applicable "
417  "for FailSafeCompiler";
418  }
419 }
420 
// Defaulted here, out of line, where FailSafeCompilerImpl is a complete type.
// NOTE(review): presumably required because `pimpl_` is a std::unique_ptr to a class that
// failsafe.h only forward-declares. Confirm in the header.
421 FailSafeCompiler::~FailSafeCompiler() = default;
422 
424 FailSafeCompiler::Compile(const Model& model) {
  // Forwards to FailSafeCompilerImpl::Compile(), which performs all of the code generation.
  // NOTE(review): the return-type line preceding this definition (source line 423) is
  // missing from this listing; the forwarded call returns CompiledModel.
425  return pimpl_->Compile(model);
426 }
427 
429 FailSafeCompiler::QueryParam() const {
  // Returns the parameters held by the implementation object.
  // NOTE(review): the return-type line preceding this definition (source line 428) is
  // missing from this listing; the forwarded call returns CompilerParam.
430  return pimpl_->QueryParam();
431 }
432 
433 } // namespace compiler
434 } // namespace treelite
Parameters for tree compiler.
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:43
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:614
model structure for tree ensemble
void AllocateELFHeader(std::vector< char > *elf_buffer)
Pre-allocate space in a buffer to fit an ELF header.
logging facility for Treelite
template for header
float global_bias
global bias of the model
Definition: tree.h:629
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:734
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: format_util.h:59
Definition: compiler.h:26
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
Definition: typeinfo.h:39
C code generator (fail-safe). The generated code will mimic prediction logic found in XGBoost...
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:622
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
thin wrapper for tree ensemble model
Definition: tree.h:655
int quantize
whether to quantize threshold points (0: no, >0: yes)
Formatting utilities.
int verbose
if >0, produce extra messages
Generate a relocatable object file containing a constant, read-only array.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:702
void FormatArrayAsELF(std::vector< char > *elf_buffer)
Format a relocatable ELF object file containing a constant, read-only array.
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:606