12 #include <fmt/format.h> 13 #include <unordered_map> 18 #include "./pred_transform.h" 22 #if defined(_MSC_VER) || defined(_WIN32) 23 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 25 #define DLLEXPORT_KEYWORD "" 32 struct NodeStructValue {
39 const char* header_template = R
"TREELITETEMPLATE( 62 extern const struct Node nodes[]; 63 extern const int nodes_row_ptr[]; 65 {dllexport}size_t get_num_output_group(void); 66 {dllexport}size_t get_num_feature(void); 67 {dllexport}{predict_function_signature}; 70 const char* main_template = R
"TREELITETEMPLATE( 75 size_t get_num_output_group(void) {{ 76 return {num_output_group}; 79 size_t get_num_feature(void) {{ 83 {pred_transform_function} 85 {predict_function_signature} {{ 86 {accumulator_definition}; 88 for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{ 90 const struct Node* tree = &nodes[nodes_row_ptr[tree_id]]; 91 while (tree[nid].cleft != -1) {{ 92 const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U); 93 const unsigned char default_left = (tree[nid].sindex >> 31) != 0; 94 if (data[feature_id].missing == -1) {{ 95 nid = (default_left ? tree[nid].cleft : tree[nid].cright); 97 nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold 98 ? tree[nid].cleft : tree[nid].cright); 107 const char* return_multiclass_template =
109 for (int i = 0; i < {num_output_group}; ++i) {{ 110 result[i] = sum[i] + (float)({global_bias}); 113 return pred_transform(result); 115 return {num_output_group}; 119 const char* return_template =
121 sum += (float)({global_bias}); 123 return pred_transform(sum); 129 const char* arrays_template = R
"TREELITETEMPLATE( 139 inline std::pair<std::string, std::string> FormatNodesArray(
const treelite::Model& model) {
// FormatNodesArray: walks every node of every tree in `model` and renders the
// ensemble as C source text. The returned pair holds, in order, the definition of
// `const struct Node nodes[]` and the definition of `const int nodes_row_ptr[]`
// (see the return statement at the end).
// NOTE(review): this listing omits several original lines (the declarations of
// `nodes`, `nodes_row_ptr` and `node_count`, some named format arguments, and the
// closing braces), so the comments below cover only the statements that are visible.
// The row-pointer list always starts with offset 0.
143 nodes_row_ptr <<
"0";
144 for (
const auto& tree : model.
trees) {
145 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
146 if (tree.IsLeaf(nid)) {
// Leaf node: a leaf that carries a leaf vector is rejected with the message below.
147 CHECK(!tree.HasLeafVector(nid))
148 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
// Emit one `{ sindex, info, cleft, cright }` initializer; for a leaf, `info` is the
// leaf value rendered by ToStringHighPrecision.
// NOTE(review): the sindex/cleft/cright arguments of this call are not visible in
// this excerpt; presumably cleft is -1, since the generated predict loop stops on
// `cleft != -1` being false -- TODO confirm.
149 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
151 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
// Internal node: only numerical splits with an empty left-category list are accepted.
155 CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
156 && tree.LeftCategories(nid).empty())
157 <<
"categorical splits are not supported in FailSafeCompiler";
158 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
// The value below ORs the split feature index with the default-left flag shifted
// into bit 31. The generated predict code decodes the same layout from `sindex`
// (mask with (1U << 31) - 1U for the feature id, shift right by 31 for the flag).
// NOTE(review): the argument name on the left of `=` is missing from this excerpt;
// presumably it is "sindex"_a -- TODO confirm.
160 = (tree.SplitIndex(nid) |(
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31U)),
// For an internal node, `info` is the split threshold and cleft/cright are the
// child node ids.
161 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
162 "cleft"_a = tree.LeftChild(nid),
163 "cright"_a = tree.RightChild(nid));
// Advance the running node total by this tree's node count and append it to the
// row-pointer list.
// NOTE(review): presumably runs once per tree, after the node loop; the closing
// braces and any separator output are not visible here -- TODO confirm.
166 node_count += tree.num_nodes;
167 nodes_row_ptr << std::to_string(node_count);
// Wrap both accumulated texts in their C array definitions.
169 return std::make_pair(fmt::format(
"const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170 fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
171 nodes_row_ptr.str()));
/*!
 * \brief Variant of the node-array formatter that emits node records as raw bytes.
 *
 * Returns a pair: the first element is a byte buffer holding one
 * sizeof(NodeStructValue)-sized record per node, the second is the C source text
 * defining `const int nodes_row_ptr[]`.
 * NOTE(review): this listing omits some original lines (the declarations of `val`,
 * `nodes_row_ptr` and `node_count`, and the closing braces); the comments below
 * cover only the statements that are visible.
 */
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
const treelite::Model& model) {
// Byte buffer that receives the node records, appended back to back.
176 std::vector<char> nodes_elf;
// The row-pointer list always starts with offset 0.
182 nodes_row_ptr <<
"0";
183 for (
const auto& tree : model.
trees) {
184 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
185 if (tree.IsLeaf(nid)) {
// Leaf node: a leaf that carries a leaf vector is rejected with the message below.
186 CHECK(!tree.HasLeafVector(nid))
187 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
// Leaf record: first field 0, the leaf value narrowed to float, and -1 for both
// child slots (the generated predict loop stops when cleft is -1).
// NOTE(review): `val` is presumably a NodeStructValue; its declaration is not
// visible in this excerpt -- TODO confirm.
188 val = {0,
static_cast<float>(tree.LeafValue(nid)), -1, -1};
// Internal node: only numerical splits with an empty left-category list are accepted.
190 CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
191 && tree.LeftCategories(nid).empty())
192 <<
"categorical splits are not supported in FailSafeCompiler";
// Internal record: split feature index ORed with the default-left flag in bit 31,
// then the threshold narrowed to float, then the left and right child ids.
193 val = {(tree.SplitIndex(nid) | (
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31)),
194 static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
// Grow the buffer by one record and copy the bytes of `val` into the new tail.
196 const size_t beg = nodes_elf.size();
197 nodes_elf.resize(beg +
sizeof(NodeStructValue));
198 std::memcpy(&nodes_elf[beg], &val,
sizeof(NodeStructValue));
// Advance the running node total by this tree's node count and append it to the
// row-pointer list.
// NOTE(review): presumably runs once per tree, after the node loop; the closing
// braces and any separator output are not visible here -- TODO confirm.
200 node_count += tree.num_nodes;
201 nodes_row_ptr << std::to_string(node_count);
// NOTE(review): std::make_pair receives `nodes_elf` as an lvalue here, so the byte
// buffer is copied into the returned pair.
205 return std::make_pair(nodes_elf, fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
206 nodes_row_ptr.str()));
212 std::set<treelite::Operator> ops;
213 for (
const auto& tree : model.
trees) {
214 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
215 if (!tree.IsLeaf(nid)) {
216 ops.insert(tree.ComparisonOp(nid));
221 CHECK_EQ(ops.size(), 1)
222 <<
"FailSafeCompiler only supports models where all splits use identical comparison operator.";
/*!
 * \brief Test whether a string ends with a given suffix.
 * \param str string to inspect
 * \param suffix candidate ending; an empty suffix matches every string
 * \return true if the last suffix.size() characters of str are equal to suffix,
 *         false otherwise (including when suffix is longer than str)
 */
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  // A suffix longer than the string can never match. Checking this first also
  // keeps the unsigned offset computation below from wrapping around.
  if (str.size() < suffix.size()) {
    return false;
  }
  const std::string::size_type offset = str.size() - suffix.size();
  return str.compare(offset, suffix.size(), suffix) == 0;
}
238 DMLC_REGISTRY_FILE_TAG(failsafe);
245 LOG(INFO) <<
"Using FailSafeCompiler";
248 LOG(INFO) <<
"Warning: 'annotate_in' parameter is not applicable for " 252 LOG(INFO) <<
"Warning: 'quantize' parameter is not applicable for " 256 LOG(INFO) <<
"Warning: 'parallel_comp' parameter is not applicable for " 260 LOG(INFO) <<
"Warning: 'code_folding_req' parameter is not applicable " 261 "for FailSafeCompiler";
267 cm.backend =
"native";
272 <<
"Only gradient boosted trees supported in FailSafeCompiler";
273 pred_tranform_func_ = PredTransformFunction(
"native", model);
276 const char* predict_function_signature
277 = (num_output_group_ > 1) ?
278 "size_t predict_multiclass(union Entry* data, int pred_margin, " 280 :
"float predict(union Entry* data, int pred_margin)";
282 std::ostringstream main_program;
283 std::string accumulator_definition
284 = (num_output_group_ > 1
285 ? fmt::format(
"float sum[{num_output_group}] = {{0.0f}}",
286 "num_output_group"_a = num_output_group_)
287 : std::string(
"float sum = 0.0f"));
289 std::string output_statement
290 = (num_output_group_ > 1
291 ? fmt::format(
"sum[tree_id % {num_output_group}] += tree[nid].info.leaf_value;",
292 "num_output_group"_a = num_output_group_)
293 : std::string(
"sum += tree[nid].info.leaf_value;"));
295 std::string return_statement
296 = (num_output_group_ > 1
297 ? fmt::format(return_multiclass_template,
298 "num_output_group"_a = num_output_group_,
301 : fmt::format(return_template,
305 std::string nodes, nodes_row_ptr;
306 std::vector<char> nodes_elf;
309 LOG(INFO) <<
"Dumping arrays as an ELF relocatable object...";
311 std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
313 std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
316 main_program << fmt::format(main_template,
317 "nodes_row_ptr"_a = nodes_row_ptr,
318 "pred_transform_function"_a = pred_tranform_func_,
319 "predict_function_signature"_a = predict_function_signature,
320 "num_output_group"_a = num_output_group_,
321 "num_feature"_a = num_feature_,
322 "num_tree"_a = model.
trees.size(),
323 "compare_op"_a = GetCommonOp(model),
324 "accumulator_definition"_a = accumulator_definition,
325 "output_statement"_a = output_statement,
326 "return_statement"_a = return_statement);
338 "dllexport"_a = DLLEXPORT_KEYWORD,
339 "predict_function_signature"_a = predict_function_signature));
343 std::vector<std::unordered_map<std::string, std::string>> source_list;
344 std::vector<std::string> extra_file_list;
345 for (
const auto& kv : files_) {
346 if (EndsWith(kv.first,
".c")) {
347 const size_t line_count
348 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
349 source_list.push_back({ {
"name",
350 kv.first.substr(0, kv.first.length() - 2)},
351 {
"length", std::to_string(line_count)} });
352 }
else if (EndsWith(kv.first,
".o")) {
353 extra_file_list.push_back(kv.first);
356 std::ostringstream oss;
357 std::unique_ptr<dmlc::JSONWriter> writer(
new dmlc::JSONWriter(&oss));
358 writer->BeginObject();
360 writer->WriteObjectKeyValue(
"sources", source_list);
361 if (!extra_file_list.empty()) {
362 writer->WriteObjectKeyValue(
"extra", extra_file_list);
367 cm.files = std::move(files_);
374 int num_output_group_;
375 std::string pred_tranform_func_;
376 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
380 .describe(
"Simple compiler to express trees as a tight for-loop")
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Parameters for tree compiler.
CompiledModel Compile(const Model &model) override
convert tree ensemble model
thin wrapper for tree ensemble model
std::string OpName(Operator op)
get string representation of comparison operator
std::vector< Tree > trees
member trees
parameters for tree compiler
ModelParam param
extra parameters
model structure for tree ensemble
float global_bias
global bias of the model
Interface of compiler that compiles a tree ensemble model.
std::string native_lib_name
native lib name (without extension)
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
int dump_array_as_elf
Only applicable when compiler is set to failsafe. If set to a positive value, the fail-safe compiler ...
#define TREELITE_REGISTER_COMPILER(UniqueId, Name)
Macro to register compiler.
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...