9 #include <unordered_map> 16 #include <treelite/common.h> 17 #include <fmt/format.h> 19 #include "./pred_transform.h" 22 #if defined(_MSC_VER) || defined(_WIN32) 23 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 25 #define DLLEXPORT_KEYWORD "" 32 struct NodeStructValue {
39 const char* header_template = R
"TREELITETEMPLATE( 62 extern const struct Node nodes[]; 63 extern const int nodes_row_ptr[]; 65 {dllexport}size_t get_num_output_group(void); 66 {dllexport}size_t get_num_feature(void); 67 {dllexport}{predict_function_signature}; 70 const char* main_template = R
"TREELITETEMPLATE( 75 size_t get_num_output_group(void) {{ 76 return {num_output_group}; 79 size_t get_num_feature(void) {{ 83 {pred_transform_function} 85 {predict_function_signature} {{ 86 {accumulator_definition}; 88 for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{ 90 const struct Node* tree = &nodes[nodes_row_ptr[tree_id]]; 91 while (tree[nid].cleft != -1) {{ 92 const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U); 93 const unsigned char default_left = (tree[nid].sindex >> 31) != 0; 94 if (data[feature_id].missing == -1) {{ 95 nid = (default_left ? tree[nid].cleft : tree[nid].cright); 97 nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold 98 ? tree[nid].cleft : tree[nid].cright); 107 const char* return_multiclass_template =
109 for (int i = 0; i < {num_output_group}; ++i) {{ 110 result[i] = sum[i] + (float)({global_bias}); 113 return pred_transform(result); 115 return {num_output_group}; 119 const char* return_template =
121 sum += (float)({global_bias}); 123 return pred_transform(sum); 129 const char* arrays_template = R
"TREELITETEMPLATE( 139 inline std::pair<std::string, std::string> FormatNodesArray(
const treelite::Model& model) {
143 nodes_row_ptr <<
"0";
144 for (
const auto& tree : model.
trees) {
145 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
146 const auto& node = tree[nid];
147 if (node.is_leaf()) {
148 CHECK(!node.has_leaf_vector())
149 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
150 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
152 "info"_a = treelite::common::ToStringHighPrecision(node.leaf_value()),
156 CHECK(node.split_type() == treelite::SplitFeatureType::kNumerical
157 && node.left_categories().empty())
158 <<
"categorical splits are not supported in FailSafeCompiler";
159 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
160 "sindex"_a = (node.split_index() | (
static_cast<uint32_t
>(node.default_left()) << 31)),
161 "info"_a = treelite::common::ToStringHighPrecision(node.threshold()),
162 "cleft"_a = node.cleft(),
163 "cright"_a = node.cright());
166 node_count += tree.num_nodes;
167 nodes_row_ptr << std::to_string(node_count);
169 return std::make_pair(fmt::format(
"const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170 fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
171 nodes_row_ptr.str()));
// Build the fail-safe node table in binary form.
// Walks every node of every tree in model.trees, packs each node into a
// NodeStructValue (val) and appends its raw bytes to nodes_elf; alongside that
// it streams cumulative node counts into nodes_row_ptr.
// Returns a pair: (byte buffer holding the packed nodes, source text that
// declares the nodes_row_ptr int array).
// CHECK-fails on leaves that carry a leaf vector and on splits that are not
// purely numerical.
// NOTE(review): this listing has dropped lines (the declarations of val,
// node_count and nodes_row_ptr, the else keyword and the closing braces are
// not visible), so the comments below describe only the statements shown.
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
const treelite::Model& model) {
176 std::vector<char> nodes_elf;
// Offsets start at 0; one cumulative node count is appended after each tree.
182 nodes_row_ptr <<
"0";
183 for (
const auto& tree : model.
trees) {
184 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
185 const auto& node = tree[nid];
186 if (node.is_leaf()) {
// Leaves carrying a leaf vector are rejected.
187 CHECK(!node.has_leaf_vector())
188 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
// Leaf record: first field 0, leaf value narrowed to float, both child ids -1.
// Field order presumably mirrors sindex, info, cleft, cright -- TODO confirm
// against the NodeStructValue definition.
189 val = {0,
static_cast<float>(node.leaf_value()), -1, -1};
// Split record: only numerical tests with an empty category list are accepted.
191 CHECK(node.split_type() == treelite::SplitFeatureType::kNumerical
192 && node.left_categories().empty())
193 <<
"categorical splits are not supported in FailSafeCompiler";
// Bit 31 of the first field stores default_left, OR-ed with the split index;
// the threshold is narrowed to float.
194 val = {(node.split_index() | (
static_cast<uint32_t
>(node.default_left()) << 31)),
195 static_cast<float>(node.threshold()), node.cleft(), node.cright()};
// Grow the buffer by one record and copy the bytes of val into the new tail.
197 const size_t beg = nodes_elf.size();
198 nodes_elf.resize(beg +
sizeof(NodeStructValue));
199 std::memcpy(&nodes_elf[beg], &val,
sizeof(NodeStructValue));
// Advance the running node total and record it as the next row offset.
201 node_count += tree.num_nodes;
202 nodes_row_ptr << std::to_string(node_count);
206 return std::make_pair(nodes_elf, fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
207 nodes_row_ptr.str()));
213 std::set<treelite::Operator> ops;
214 for (
const auto& tree : model.
trees) {
215 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
216 const auto& node = tree[nid];
217 if (!node.is_leaf()) {
218 ops.insert(node.comparison_op());
223 CHECK_EQ(ops.size(), 1)
224 <<
"FailSafeCompiler only supports models where all splits use identical comparison operator.";
// Report whether `str` ends with `suffix`.
//   str    - string to inspect
//   suffix - candidate trailing substring
// Returns true when `str` is at least as long as `suffix` and its last
// suffix.size() characters are equal to `suffix`. An empty suffix matches
// every string, including the empty one.
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  // Reject early so the unsigned subtraction below can never wrap around.
  if (str.size() < suffix.size()) {
    return false;
  }
  return str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}
240 DMLC_REGISTRY_FILE_TAG(failsafe);
247 LOG(INFO) <<
"Using FailSafeCompiler";
250 LOG(INFO) <<
"Warning: 'annotate_in' parameter is not applicable for " 254 LOG(INFO) <<
"Warning: 'quantize' parameter is not applicable for " 258 LOG(INFO) <<
"Warning: 'parallel_comp' parameter is not applicable for " 262 LOG(INFO) <<
"Warning: 'code_folding_req' parameter is not applicable " 263 "for FailSafeCompiler";
269 cm.backend =
"native";
274 <<
"Only gradient boosted trees supported in FailSafeCompiler";
275 pred_tranform_func_ = PredTransformFunction(
"native", model);
278 const char* predict_function_signature
279 = (num_output_group_ > 1) ?
280 "size_t predict_multiclass(union Entry* data, int pred_margin, " 282 :
"float predict(union Entry* data, int pred_margin)";
284 std::ostringstream main_program;
285 std::string accumulator_definition
286 = (num_output_group_ > 1
287 ? fmt::format(
"float sum[{num_output_group}] = {{0.0f}}",
288 "num_output_group"_a = num_output_group_)
289 : std::string(
"float sum = 0.0f"));
291 std::string output_statement
292 = (num_output_group_ > 1
293 ? fmt::format(
"sum[tree_id % {num_output_group}] += tree[nid].info.leaf_value;",
294 "num_output_group"_a = num_output_group_)
295 : std::string(
"sum += tree[nid].info.leaf_value;"));
297 std::string return_statement
298 = (num_output_group_ > 1
299 ? fmt::format(return_multiclass_template,
300 "num_output_group"_a = num_output_group_,
302 : fmt::format(return_template,
305 std::string nodes, nodes_row_ptr;
306 std::vector<char> nodes_elf;
309 LOG(INFO) <<
"Dumping arrays as an ELF relocatable object...";
311 std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
313 std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
316 main_program << fmt::format(main_template,
317 "nodes_row_ptr"_a = nodes_row_ptr,
318 "pred_transform_function"_a = pred_tranform_func_,
319 "predict_function_signature"_a = predict_function_signature,
320 "num_output_group"_a = num_output_group_,
321 "num_feature"_a = num_feature_,
322 "num_tree"_a = model.
trees.size(),
323 "compare_op"_a = GetCommonOp(model),
324 "accumulator_definition"_a = accumulator_definition,
325 "output_statement"_a = output_statement,
326 "return_statement"_a = return_statement);
338 "dllexport"_a = DLLEXPORT_KEYWORD,
339 "predict_function_signature"_a = predict_function_signature));
343 std::vector<std::unordered_map<std::string, std::string>> source_list;
344 std::vector<std::string> extra_file_list;
345 for (
const auto& kv : files_) {
346 if (EndsWith(kv.first,
".c")) {
347 const size_t line_count
348 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
349 source_list.push_back({ {
"name",
350 kv.first.substr(0, kv.first.length() - 2)},
351 {
"length", std::to_string(line_count)} });
352 }
else if (EndsWith(kv.first,
".o")) {
353 extra_file_list.push_back(kv.first);
356 std::ostringstream oss;
357 auto writer = common::make_unique<dmlc::JSONWriter>(&oss);
358 writer->BeginObject();
360 writer->WriteObjectKeyValue(
"sources", source_list);
361 if (!extra_file_list.empty()) {
362 writer->WriteObjectKeyValue(
"extra", extra_file_list);
367 cm.files = std::move(files_);
374 int num_output_group_;
375 std::string pred_tranform_func_;
376 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
380 .describe(
"Simple compiler to express trees as a tight for-loop")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
CompiledModel Compile(const Model &model) override
convert tree ensemble model
thin wrapper for tree ensemble model
std::string OpName(Operator op)
get string representation of comparison operator
std::vector< Tree > trees
member trees
parameters for tree compiler
ModelParam param
extra parameters
float global_bias
global bias of the model
Parameters for tree compiler.
Interface of compiler that compiles a tree ensemble model.
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...