12 #include <fmt/format.h> 13 #include <unordered_map> 18 #include "./pred_transform.h" 24 #if defined(_MSC_VER) || defined(_WIN32) 25 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 27 #define DLLEXPORT_KEYWORD "" 34 struct NodeStructValue {
// Code-generation templates for the fail-safe compiler (text kept verbatim).
// Each template is expanded with fmt::format using named arguments: {name} is a
// placeholder, and a doubled brace ({{ or }}) yields a literal brace in the
// emitted C code.
//  - header_template: declares the extern `nodes` / `nodes_row_ptr` arrays, the
//    query-function prototypes and the {dllexport}-decorated predict signature.
//  - main_template: the prediction loop. For each tree it starts at
//    &nodes[nodes_row_ptr[tree_id]] and descends while cleft != -1 (cleft == -1
//    marks a leaf). Bit 31 of `sindex` is the default-left flag and the low 31
//    bits are the feature index. A feature whose `missing` field is -1 follows
//    the default direction; otherwise `fvalue {compare_op} threshold` being
//    true selects the left child, false the right child.
//  - return_multiclass_template / return_template: add the global bias to the
//    accumulated sum(s) and pass the result through pred_transform.
//  - arrays_template: its body is not visible in this view.
41 const char*
const header_template = R
"TREELITETEMPLATE( 64 extern const struct Node nodes[]; 65 extern const int nodes_row_ptr[]; 67 {query_functions_prototype} 68 {dllexport}{predict_function_signature}; 71 const char*
const main_template = R
"TREELITETEMPLATE( 76 {query_functions_definition} 78 {pred_transform_function} 80 {predict_function_signature} {{ 81 {accumulator_definition}; 83 for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{ 85 const struct Node* tree = &nodes[nodes_row_ptr[tree_id]]; 86 while (tree[nid].cleft != -1) {{ 87 const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U); 88 const unsigned char default_left = (tree[nid].sindex >> 31) != 0; 89 if (data[feature_id].missing == -1) {{ 90 nid = (default_left ? tree[nid].cleft : tree[nid].cright); 92 nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold 93 ? tree[nid].cleft : tree[nid].cright); 102 const char*
const return_multiclass_template =
104 for (int i = 0; i < {num_class}; ++i) {{ 105 result[i] = sum[i] + (float)({global_bias}); 108 return pred_transform(result); 114 const char*
const return_template =
116 sum += (float)({global_bias}); 118 return pred_transform(sum); 124 const char*
const arrays_template = R
"TREELITETEMPLATE( 134 inline std::pair<std::string, std::string> FormatNodesArray(
// FormatNodesArray(model): renders every node of every tree as a C initializer
// "{ 0x<sindex>, <info>, <cleft>, <cright> }" and returns a pair of C source
// definitions: (const struct Node nodes[], const int nodes_row_ptr[]).
// nodes_row_ptr starts at "0" and receives the running node count after each
// tree, i.e. it is a prefix sum of per-tree node counts. main_template uses
// nodes_row_ptr[tree_id] as the offset of tree `tree_id` inside `nodes`.
// (How entries are separated in the output is handled on lines not visible here.)
139 nodes_row_ptr <<
"0";
140 for (
const auto& tree : model.
trees) {
141 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
142 if (tree.IsLeaf(nid)) {
// Leaf: only scalar leaf values are supported; `info` carries the leaf value,
// printed with ToStringHighPrecision.
143 CHECK(!tree.HasLeafVector(nid))
144 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
145 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
147 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
// Internal node: only numerical splits without category matching are accepted.
// sindex packs the feature index with the default-left flag in bit 31; `info`
// carries the threshold.
// NOTE(review): assumes SplitIndex fits in 31 bits, otherwise it would collide
// with the default-left flag — TODO confirm an upstream bound on feature count.
151 CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
152 && !tree.HasMatchingCategories(nid))
153 <<
"categorical splits are not supported in FailSafeCompiler";
154 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
156 = (tree.SplitIndex(nid) |(
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31U)),
157 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
158 "cleft"_a = tree.LeftChild(nid),
159 "cright"_a = tree.RightChild(nid));
// After each tree, record where the next tree begins in `nodes`.
162 node_count += tree.num_nodes;
163 nodes_row_ptr << std::to_string(node_count);
// Wrap both accumulated lists into complete C array definitions.
165 return std::make_pair(fmt::format(
"const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
166 fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
167 nodes_row_ptr.str()));
// FormatNodesArrayELF(model): binary counterpart of FormatNodesArray, used by
// Compile when arrays are dumped as an ELF relocatable object. Each node is
// packed into a NodeStructValue and its raw bytes are appended to nodes_elf.
// nodes_row_ptr is still produced as C source text: a prefix sum of per-tree
// node counts starting at 0.
// Returns (nodes_elf bytes, nodes_row_ptr C definition).
171 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
173 std::vector<char> nodes_elf;
179 nodes_row_ptr <<
"0";
180 for (
const auto& tree : model.
trees) {
181 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
182 if (tree.IsLeaf(nid)) {
// Leaf: sindex 0, leaf value narrowed to float, both children -1. The
// generated traversal loop stops on cleft == -1.
183 CHECK(!tree.HasLeafVector(nid))
184 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
185 val = {0,
static_cast<float>(tree.LeafValue(nid)), -1, -1};
// Internal node: numerical splits only. Bit 31 of sindex holds the
// default-left flag, and the threshold is narrowed to float.
187 CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
188 && !tree.HasMatchingCategories(nid))
189 <<
"categorical splits are not supported in FailSafeCompiler";
190 val = {(tree.SplitIndex(nid) | (
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31)),
191 static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
// Append the struct's bytes verbatim to the blob.
// NOTE(review): this copies the host in-memory layout of NodeStructValue
// (field order, padding, endianness). It presumably must match `struct Node`
// as declared for the generated C code — confirm.
193 const size_t beg = nodes_elf.size();
194 nodes_elf.resize(beg +
sizeof(NodeStructValue));
195 std::memcpy(&nodes_elf[beg], &val,
sizeof(NodeStructValue));
// After each tree, record where the next tree begins.
197 node_count += tree.num_nodes;
198 nodes_row_ptr << std::to_string(node_count);
202 return std::make_pair(nodes_elf, fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
203 nodes_row_ptr.str()));
// Presumably the body of GetCommonOp(model), which Compile uses to fill the
// {compare_op} placeholder. Its signature and return statement are on lines
// not visible in this view.
// It collects the comparison operator of every non-leaf node across all trees
// and requires exactly one distinct operator, because main_template has a
// single {compare_op} slot shared by every split.
// NOTE(review): a model whose trees are all single leaves yields an empty set
// and also fails this check, with a message about differing operators —
// confirm whether that case should be reported separately.
209 std::set<treelite::Operator> ops;
210 for (
const auto& tree : model.
trees) {
211 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
212 if (!tree.IsLeaf(nid)) {
213 ops.insert(tree.ComparisonOp(nid));
218 CHECK_EQ(ops.size(), 1)
219 <<
"FailSafeCompiler only supports models where all splits use identical comparison operator.";
/*!
 * \brief check whether a string terminates with the given suffix
 * \param str string to inspect (e.g. a generated file name)
 * \param suffix candidate suffix (e.g. ".c" or ".o")
 * \return true if the trailing suffix.size() characters of str equal suffix;
 *         an empty suffix matches every string
 */
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  const std::size_t suffix_len = suffix.size();
  if (str.size() < suffix_len) {
    return false;  // a suffix longer than the string can never match
  }
  const std::size_t tail_start = str.size() - suffix_len;
  return str.compare(tail_start, suffix_len, suffix) == 0;
}
// Registry file tag for this compiler's translation unit (dmlc registry).
235 DMLC_REGISTRY_FILE_TAG(failsafe);
// ---- FailSafeCompiler fragments. The class header and the Compile()
// signature are on lines not visible in this view.
// Compile() logs that the fail-safe compiler is in use and warns that
// annotate_in, quantize, parallel_comp and code_folding_req do not apply to
// it. The conditions guarding these warnings are not visible here.
242 LOG(INFO) <<
"Using FailSafeCompiler";
245 LOG(INFO) <<
"Warning: 'annotate_in' parameter is not applicable for " 249 LOG(INFO) <<
"Warning: 'quantize' parameter is not applicable for " 253 LOG(INFO) <<
"Warning: 'parallel_comp' parameter is not applicable for " 257 LOG(INFO) <<
"Warning: 'code_folding_req' parameter is not applicable " 258 "for FailSafeCompiler";
// Only float32 thresholds and float32 leaf outputs are accepted.
263 CHECK(model_ptr.GetThresholdType() == TypeInfo::kFloat32
264 && model_ptr.GetLeafOutputType() == TypeInfo::kFloat32)
265 <<
"Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs";
269 cm.backend =
"native";
272 num_class_ = model.task_param.num_class;
// Reject models the flat traversal loop cannot express: averaged tree outputs,
// task types other than kBinaryClfRegr / kMultiClfGrovePerClass, and leaf
// vectors.
273 CHECK(!model.average_tree_output)
274 <<
"Averaging tree output is not supported in FailSafeCompiler";
275 CHECK(model.task_type == TaskType::kBinaryClfRegr
276 || model.task_type == TaskType::kMultiClfGrovePerClass)
277 <<
"Model task type unsupported by FailSafeCompiler";
// NOTE(review): the message below reads "is not support by"; "is not supported
// by" was presumably intended. It is a runtime string and is left unchanged here.
278 CHECK_EQ(model.task_param.leaf_vector_size, 1)
279 <<
"Model with leaf vectors is not support by FailSafeCompiler";
// C source of the prediction transform, generated for the "native" backend.
280 pred_tranform_func_ = PredTransformFunction(
"native", model_ptr);
// Signature of the generated entry point. Multiclass models get
// predict_multiclass, which writes per-class scores into `result`; otherwise a
// scalar predict() is emitted. The selecting condition is on a line not
// visible here — presumably num_class_ > 1.
283 const char* predict_function_signature
285 "size_t predict_multiclass(union Entry* data, int pred_margin, float* result)" 286 :
"float predict(union Entry* data, int pred_margin)";
288 std::ostringstream main_program;
// Snippets substituted into main_template:
//  - the accumulator: a per-class array or a scalar;
//  - the per-leaf update: tree tree_id contributes to class
//    tree_id % num_class (grove-per-class layout);
//  - the return epilogue that adds global_bias.
289 std::string accumulator_definition
291 ? fmt::format(
"float sum[{num_class}] = {{0.0f}}",
292 "num_class"_a = num_class_)
293 : std::string(
"float sum = 0.0f"));
295 std::string output_statement
297 ? fmt::format(
"sum[tree_id % {num_class}] += tree[nid].info.leaf_value;",
298 "num_class"_a = num_class_)
299 : std::string(
"sum += tree[nid].info.leaf_value;"));
301 std::string return_statement
303 ? fmt::format(return_multiclass_template,
304 "num_class"_a = num_class_,
306 = compiler::common_util::ToStringHighPrecision(model.param.global_bias))
307 : fmt::format(return_template,
309 = compiler::common_util::ToStringHighPrecision(model.param.global_bias)));
// Node arrays are produced either as a raw byte blob (ELF dump path) or as C
// initializer text. Which path runs is decided on lines not visible here,
// presumably by the dump_array_as_elf parameter.
311 std::string nodes, nodes_row_ptr;
312 std::vector<char> nodes_elf;
315 LOG(INFO) <<
"Dumping arrays as an ELF relocatable object...";
317 std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
319 std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
323 const std::string query_functions_definition
324 = fmt::format(native::query_functions_definition_template,
325 "num_class"_a = num_class_,
326 "num_feature"_a = num_feature_,
// Assemble the main C translation unit from main_template.
333 main_program << fmt::format(main_template,
334 "nodes_row_ptr"_a = nodes_row_ptr,
335 "query_functions_definition"_a = query_functions_definition,
336 "pred_transform_function"_a = pred_tranform_func_,
337 "predict_function_signature"_a = predict_function_signature,
338 "num_tree"_a = model.trees.size(),
339 "compare_op"_a = GetCommonOp(model),
340 "accumulator_definition"_a = accumulator_definition,
341 "output_statement"_a = output_statement,
342 "return_statement"_a = return_statement);
// Header pieces: query-function prototypes plus the exported predict signature.
353 const std::string query_functions_prototype
354 = fmt::format(native::query_functions_prototype_template,
355 "dllexport"_a = DLLEXPORT_KEYWORD);
357 "dllexport"_a = DLLEXPORT_KEYWORD,
358 "query_functions_prototype"_a = query_functions_prototype,
359 "predict_function_signature"_a = predict_function_signature));
// Build recipe metadata:
//  - each generated .c file goes under "sources", with its name stripped of
//    the ".c" extension and its newline count as "length";
//  - each .o file goes under "extra".
// Both are written with a dmlc::JSONWriter into `oss`. Where that JSON ends up
// is on lines not visible here.
363 std::vector<std::unordered_map<std::string, std::string>> source_list;
364 std::vector<std::string> extra_file_list;
365 for (
const auto& kv : files_) {
366 if (EndsWith(kv.first,
".c")) {
367 const size_t line_count
368 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
369 source_list.push_back({ {
"name",
370 kv.first.substr(0, kv.first.length() - 2)},
371 {
"length", std::to_string(line_count)} });
372 }
else if (EndsWith(kv.first,
".o")) {
373 extra_file_list.push_back(kv.first);
376 std::ostringstream oss;
377 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&oss));
378 writer->BeginObject();
380 writer->WriteObjectKeyValue(
"sources", source_list);
381 if (!extra_file_list.empty()) {
382 writer->WriteObjectKeyValue(
"extra", extra_file_list);
// Hand the generated files over to the CompiledModel; files_ is moved from.
387 cm.files = std::move(files_);
// Compiler state:
//  num_class_          - copied from model.task_param.num_class in Compile().
//  pred_tranform_func_ - C source of the prediction transform.
//                        NOTE(review): the name misspells "transform".
//  files_              - generated files keyed by file name.
394 unsigned int num_class_;
395 std::string pred_tranform_func_;
396 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
// Description attached to this compiler's registry entry. The registration
// macro invocation itself is on a line not visible here.
400 .describe(
"Simple compiler to express trees as a tight for-loop")
Parameters for tree compiler.
CompiledModel Compile(const Model &model_ptr) override
convert tree ensemble model
std::string OpName(Operator op)
get string representation of comparison operator
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
float global_bias
global bias of the model
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
thin wrapper for tree ensemble model
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function