8 #include "pred_transform.h" 10 #include <unordered_map> 12 #include "./native/pred_transform.h" 14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 19 using PredTransformFuncGenerator
20 = std::string (*)(
const std::string&,
const Model&);
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ 25 FUNC_NAME(const std::string& backend, const Model& model) { \ 26 if (backend == "native") { \ 27 return treelite::compiler::native::pred_transform::FUNC_NAME(model); \ 29 TREELITE_LOG(FATAL) << "Unrecognized backend: " << backend; \ 30 return std::string(); \ 34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(hinge)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
45 const std::unordered_map<std::string, PredTransformFuncGenerator>
47 PRED_TRANSFORM_FUNC(identity),
48 PRED_TRANSFORM_FUNC(signed_square),
49 PRED_TRANSFORM_FUNC(hinge),
50 PRED_TRANSFORM_FUNC(sigmoid),
51 PRED_TRANSFORM_FUNC(exponential),
52 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
77 const std::unordered_map<std::string, PredTransformFuncGenerator>
78 pred_transform_multiclass_db = {
79 PRED_TRANSFORM_FUNC(identity_multiclass),
80 PRED_TRANSFORM_FUNC(max_index),
81 PRED_TRANSFORM_FUNC(softmax),
82 PRED_TRANSFORM_FUNC(multiclass_ova)
106 treelite::compiler::PredTransformFunction(
const std::string& backend,
107 const Model& model) {
108 ModelParam param = model.param;
109 if (model.task_param.num_class > 1) {
110 auto it = pred_transform_multiclass_db.find(param.pred_transform);
111 if (it == pred_transform_multiclass_db.end()) {
112 std::ostringstream oss;
113 for (
const auto& e : pred_transform_multiclass_db) {
114 oss <<
"'" << e.first <<
"', ";
116 TREELITE_LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 117 <<
"For multi-class classification, you should set " 118 <<
"`pred_transform` to one of the following: " 119 <<
"{ " << oss.str() <<
" }";
121 return (it->second)(backend, model);
123 auto it = pred_transform_db.find(param.pred_transform);
124 if (it == pred_transform_db.end()) {
125 std::ostringstream oss;
126 for (
const auto& e : pred_transform_db) {
127 oss <<
"'" << e.first <<
"', ";
129 TREELITE_LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 130 <<
"For any task that is NOT multi-class classification, you " 131 <<
"should set `pred_transform` to one of the following: " 132 <<
"{ " << oss.str() <<
" }";
134 return (it->second)(backend, model);
// Thin wrapper for a tree ensemble model.