13 #include <fmt/format.h> 14 #include <rapidjson/stringbuffer.h> 15 #include <rapidjson/writer.h> 16 #include <unordered_map> 22 #include "./pred_transform.h" 28 #if defined(_MSC_VER) || defined(_WIN32) 29 #define DLLEXPORT_KEYWORD "__declspec(dllexport) " 31 #define DLLEXPORT_KEYWORD "" 38 struct NodeStructValue {
// Code templates for the generated C sources. They are expanded with
// fmt::format named arguments ({num_class}, {global_bias}, {compare_op}, ...),
// so literal braces in the emitted C code are written as {{ and }}.
// main_template walks each tree from nodes[nodes_row_ptr[tree_id]], following
// cleft/cright until it reaches a node whose cleft == -1 (a leaf). The feature
// index is the low 31 bits of `sindex` and the default-left flag is bit 31; a
// feature whose `missing` field is -1 takes the default direction, otherwise
// the comparison `fvalue {compare_op} threshold` picks the child.
45 const char*
const header_template = R
"TREELITETEMPLATE( 68 extern const struct Node nodes[]; 69 extern const int nodes_row_ptr[]; 71 {query_functions_prototype} 72 {dllexport}{predict_function_signature}; 75 const char*
const main_template = R
"TREELITETEMPLATE( 80 {query_functions_definition} 82 {pred_transform_function} 84 {predict_function_signature} {{ 85 {accumulator_definition}; 87 for (int tree_id = 0; tree_id < {num_tree}; ++tree_id) {{ 89 const struct Node* tree = &nodes[nodes_row_ptr[tree_id]]; 90 while (tree[nid].cleft != -1) {{ 91 const unsigned feature_id = tree[nid].sindex & ((1U << 31) - 1U); 92 const unsigned char default_left = (tree[nid].sindex >> 31) != 0; 93 if (data[feature_id].missing == -1) {{ 94 nid = (default_left ? tree[nid].cleft : tree[nid].cright); 96 nid = (data[feature_id].fvalue {compare_op} tree[nid].info.threshold 97 ? tree[nid].cleft : tree[nid].cright); 106 const char*
const return_multiclass_template =
108 for (int i = 0; i < {num_class}; ++i) {{ 109 result[i] = sum[i] + (float)({global_bias}); 112 return pred_transform(result); 118 const char*
const return_template =
120 sum += (float)({global_bias}); 122 return pred_transform(sum); 128 const char*
const arrays_template = R
"TREELITETEMPLATE(
"TREELITETEMPLATE( 138 inline std::pair<std::string, std::string> FormatNodesArray(
143 nodes_row_ptr <<
"0";
144 for (
const auto& tree : model.
trees) {
145 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
146 if (tree.IsLeaf(nid)) {
147 TREELITE_CHECK(!tree.HasLeafVector(nid))
148 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
149 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
151 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.LeafValue(nid)),
155 TREELITE_CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
156 && !tree.HasMatchingCategories(nid))
157 <<
"categorical splits are not supported in FailSafeCompiler";
158 nodes << fmt::format(
"{{ 0x{sindex:X}, {info}, {cleft}, {cright} }}",
160 = (tree.SplitIndex(nid) |(
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31U)),
161 "info"_a = treelite::compiler::common_util::ToStringHighPrecision(tree.Threshold(nid)),
162 "cleft"_a = tree.LeftChild(nid),
163 "cright"_a = tree.RightChild(nid));
166 node_count += tree.num_nodes;
167 nodes_row_ptr << std::to_string(node_count);
169 return std::make_pair(fmt::format(
"const struct Node nodes[] = {{\n{}\n}};", nodes.str()),
170 fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
171 nodes_row_ptr.str()));
// Binary counterpart of FormatNodesArray, used when dump_array_as_elf > 0.
// Each node is packed as a NodeStructValue
//   {sindex, float threshold-or-leaf-value, cleft, cright}
// and its raw bytes are appended to a byte vector. Leaves use sindex = 0 and
// cleft = cright = -1. The row-pointer table is still returned as C source
// text, with the same cumulative-offset layout as FormatNodesArray.
// NOTE(review): the bytes are a memcpy of the host's NodeStructValue layout.
// They presumably must match the generated code's `struct Node` layout
// (padding, endianness) on the target; confirm against the ELF writer.
175 inline std::pair<std::vector<char>, std::string> FormatNodesArrayELF(
177 std::vector<char> nodes_elf;
183 nodes_row_ptr <<
"0";
184 for (
const auto& tree : model.
trees) {
185 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
186 if (tree.IsLeaf(nid)) {
187 TREELITE_CHECK(!tree.HasLeafVector(nid))
188 <<
"multi-class random forest classifier is not supported in FailSafeCompiler";
189 val = {0,
static_cast<float>(tree.LeafValue(nid)), -1, -1};
191 TREELITE_CHECK(tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical
192 && !tree.HasMatchingCategories(nid))
193 <<
"categorical splits are not supported in FailSafeCompiler";
194 val = {(tree.SplitIndex(nid) | (
static_cast<uint32_t
>(tree.DefaultLeft(nid)) << 31)),
195 static_cast<float>(tree.Threshold(nid)), tree.LeftChild(nid), tree.RightChild(nid)};
// Append this node's raw bytes to the output buffer.
197 const size_t beg = nodes_elf.size();
198 nodes_elf.resize(beg +
sizeof(NodeStructValue));
199 std::memcpy(&nodes_elf[beg], &val,
sizeof(NodeStructValue));
// Advance the running offset so the next tree's nodes start after this tree's.
201 node_count += tree.num_nodes;
202 nodes_row_ptr << std::to_string(node_count);
// NOTE(review): passing `nodes_elf` as an lvalue copies the whole buffer into
// the pair; std::move(nodes_elf) would avoid that copy.
206 return std::make_pair(nodes_elf, fmt::format(
"const int nodes_row_ptr[] = {{\n{}\n}};",
207 nodes_row_ptr.str()));
// Body of the helper invoked as GetCommonOp(model) when filling {compare_op}.
// Its signature and return statement are not visible in this view.
// It gathers the comparison operator of every internal node across all trees.
// main_template bakes a single operator into the generated code, so exactly
// one distinct operator is required.
// NOTE(review): if every tree is a single leaf, `ops` is empty and the check
// below fails with a message about mixed operators. Confirm whether that case
// can occur.
213 std::set<treelite::Operator> ops;
214 for (
const auto& tree : model.
trees) {
215 for (
int nid = 0; nid < tree.num_nodes; ++nid) {
216 if (!tree.IsLeaf(nid)) {
217 ops.insert(tree.ComparisonOp(nid));
// Exactly one distinct operator must appear across all splits.
222 TREELITE_CHECK_EQ(ops.size(), 1)
223 <<
"FailSafeCompiler only supports models where all splits use identical comparison operator.";
// Reports whether `str` finishes with `suffix`.
// An empty suffix matches any string; a suffix longer than `str` never matches.
// Used to classify generated files by extension (".c" sources, ".o" objects).
inline bool EndsWith(const std::string& str, const std::string& suffix) {
  if (suffix.size() > str.size()) {
    return false;
  }
  // Compare the trailing suffix.size() characters of `str` against `suffix`.
  const std::string::size_type tail_start = str.size() - suffix.size();
  return str.compare(tail_start, std::string::npos, suffix) == 0;
}
// Body of the model-compilation routine. Its signature and several lines are
// not visible in this view. The routine:
//   - rejects models the fail-safe backend cannot handle;
//   - renders the C program and header from the code templates;
//   - writes a JSON build recipe describing the generated files.
// Only float32 thresholds and float32 leaf outputs are accepted.
244 TREELITE_CHECK(model_ptr.GetThresholdType() == TypeInfo::kFloat32
245 && model_ptr.GetLeafOutputType() == TypeInfo::kFloat32)
246 <<
"Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs";
250 cm.backend =
"native";
// Only plain sum-of-trees models are supported:
//   - no averaging of tree outputs;
//   - task type kBinaryClfRegr or kMultiClfGrovePerClass;
//   - scalar leaves (leaf_vector_size == 1).
253 num_class_ = model.task_param.num_class;
254 TREELITE_CHECK(!model.average_tree_output)
255 <<
"Averaging tree output is not supported in FailSafeCompiler";
256 TREELITE_CHECK(model.task_type == TaskType::kBinaryClfRegr
257 || model.task_type == TaskType::kMultiClfGrovePerClass)
258 <<
"Model task type unsupported by FailSafeCompiler";
// NOTE(review): the user-visible message below reads "is not support by".
// It should read "is not supported by".
259 TREELITE_CHECK_EQ(model.task_param.leaf_vector_size, 1)
260 <<
"Model with leaf vectors is not support by FailSafeCompiler";
261 pred_tranform_func_ = PredTransformFunction(
"native", model_ptr);
// Multi-class models get predict_multiclass(), which writes per-class outputs
// into `result`; otherwise a scalar predict() is generated. The selecting
// condition is not visible here; presumably it is num_class_ > 1 (confirm).
264 const char* predict_function_signature
266 "size_t predict_multiclass(union Entry* data, int pred_margin, float* result)" 267 :
"float predict(union Entry* data, int pred_margin)";
269 std::ostringstream main_program;
// The accumulator declaration, per-leaf update and return statement below are
// spliced into main_template. In the multi-class branch, tree `tree_id` adds
// its leaf value to class `tree_id % num_class`.
270 std::string accumulator_definition
272 ? fmt::format(
"float sum[{num_class}] = {{0.0f}}",
273 "num_class"_a = num_class_)
274 : std::string(
"float sum = 0.0f"));
276 std::string output_statement
278 ? fmt::format(
"sum[tree_id % {num_class}] += tree[nid].info.leaf_value;",
279 "num_class"_a = num_class_)
280 : std::string(
"sum += tree[nid].info.leaf_value;"));
282 std::string return_statement
284 ? fmt::format(return_multiclass_template,
285 "num_class"_a = num_class_,
287 = compiler::common_util::ToStringHighPrecision(model.param.global_bias))
288 : fmt::format(return_template,
290 = compiler::common_util::ToStringHighPrecision(model.param.global_bias)));
// Node tables are emitted in one of two forms:
//   - when dump_array_as_elf > 0, as raw bytes destined for an ELF object;
//   - otherwise, as C source text.
// nodes_row_ptr is C source text in both cases.
292 std::string nodes, nodes_row_ptr;
293 std::vector<char> nodes_elf;
294 if (param_.dump_array_as_elf > 0) {
295 if (param_.verbose > 0) {
296 TREELITE_LOG(INFO) <<
"Dumping arrays as an ELF relocatable object...";
298 std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model);
300 std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model);
304 const std::string query_functions_definition
305 = fmt::format(native::query_functions_definition_template,
306 "num_class"_a = num_class_,
307 "num_feature"_a = num_feature_,
314 main_program << fmt::format(main_template,
315 "nodes_row_ptr"_a = nodes_row_ptr,
316 "query_functions_definition"_a = query_functions_definition,
317 "pred_transform_function"_a = pred_tranform_func_,
318 "predict_function_signature"_a = predict_function_signature,
319 "num_tree"_a = model.trees.size(),
320 "compare_op"_a = GetCommonOp(model),
321 "accumulator_definition"_a = accumulator_definition,
322 "output_statement"_a = output_statement,
323 "return_statement"_a = return_statement);
327 if (param_.dump_array_as_elf > 0) {
334 const std::string query_functions_prototype
335 = fmt::format(native::query_functions_prototype_template,
336 "dllexport"_a = DLLEXPORT_KEYWORD);
338 "dllexport"_a = DLLEXPORT_KEYWORD,
339 "query_functions_prototype"_a = query_functions_prototype,
340 "predict_function_signature"_a = predict_function_signature));
// Build recipe (JSON). "target" is native_lib_name. Each generated .c file is
// listed under "sources" by its name with the 2-character extension stripped,
// and "length" is set to its number of '\n' characters. .o files are
// collected separately in extra_file_list.
344 rapidjson::StringBuffer os;
345 rapidjson::Writer<rapidjson::StringBuffer> writer(os);
347 writer.StartObject();
348 writer.Key(
"target");
349 writer.String(param_.native_lib_name.data(), param_.native_lib_name.size());
350 writer.Key(
"sources");
352 std::vector<std::string> extra_file_list;
353 for (
const auto& kv : files_) {
354 if (EndsWith(kv.first,
".c")) {
355 const size_t line_count
356 = std::count(kv.second.content.begin(), kv.second.content.end(),
'\n');
357 writer.StartObject();
359 std::string name = kv.first.substr(0, kv.first.length() - 2);
360 writer.String(name.data(), name.size());
361 writer.Key(
"length");
362 writer.Uint64(line_count);
364 }
else if (EndsWith(kv.first,
".o")) {
365 extra_file_list.push_back(kv.first);
369 if (!extra_file_list.empty()) {
372 for (
const auto& extra_file : extra_file_list) {
373 writer.String(extra_file.data(), extra_file.size());
// Hand the generated files to the result; files_ is left moved-from afterwards.
381 cm.files = std::move(files_);
// Number of output classes; set from model.task_param.num_class during compilation.
392 unsigned int num_class_;
// C source of the prediction transform, produced by PredTransformFunction("native", ...).
// (The identifier's spelling "tranform" is kept; it is referenced elsewhere.)
393 std::string pred_tranform_func_;
// Generated files keyed by file name. Names ending in ".c" are listed as build
// sources and names ending in ".o" as extra objects. Moved into the
// CompiledModel at the end of compilation.
394 std::unordered_map<std::string, CompiledModel::FileEntry> files_;
// Creates the implementation object from `param` and logs which compiler is in use.
// It also logs warnings for parameters this compiler ignores: annotate_in,
// quantize, parallel_comp and code_folding_req. The conditions guarding each
// log line are not visible in this view; presumably they fire only when
// verbose output is enabled or the parameter is set (confirm).
397 FailSafeCompiler::FailSafeCompiler(
const CompilerParam& param)
398 : pimpl_(std::make_unique<FailSafeCompilerImpl>(param)) {
400 TREELITE_LOG(INFO) <<
"Using FailSafeCompiler";
403 TREELITE_LOG(INFO) <<
"Warning: 'annotate_in' parameter is not applicable for " 407 TREELITE_LOG(INFO) <<
"Warning: 'quantize' parameter is not applicable for " 411 TREELITE_LOG(INFO) <<
"Warning: 'parallel_comp' parameter is not applicable for " 415 TREELITE_LOG(INFO) <<
"Warning: 'code_folding_req' parameter is not applicable " 416 "for FailSafeCompiler";
// Defaulted out of line. Presumably this is so FailSafeCompilerImpl is a
// complete type where pimpl_'s deleter is instantiated (confirm).
420 FailSafeCompiler::~FailSafeCompiler() =
default;
// Compiles `model` by forwarding to the implementation object.
423 FailSafeCompiler::Compile(
const Model& model) {
424 return pimpl_->Compile(model);
// Returns the compiler parameters by forwarding to the implementation object.
428 FailSafeCompiler::QueryParam()
const {
429 return pimpl_->QueryParam();
Parameters for tree compiler.
std::string OpName(Operator op)
get string representation of comparison operator
parameters for tree compiler
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
logging facility for Treelite
float global_bias
global bias of the model
Interface of compiler that compiles a tree ensemble model.
template for main function
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
std::string annotate_in
name of model annotation file. Use the class treelite.Annotator to generate this file.
double code_folding_req
parameter for folding rarely visited subtrees (no if/else blocks); all nodes whose data counts are lo...
std::string TypeInfoToString(treelite::TypeInfo type)
Get string representation of type info.
C code generator (fail-safe). The generated code will mimic prediction logic found in XGBoost...
int parallel_comp
option to enable parallel compilation; if set to nonzero, the trees will be evenly distributed into [p...
thin wrapper for tree ensemble model
int quantize
whether to quantize threshold points (0: no, >0: yes)
int verbose
if >0, produce extra messages
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function