8 #include "pred_transform.h" 10 #include <unordered_map> 12 #include "./native/pred_transform.h" 14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 19 using PredTransformFuncGenerator
20 = std::string (*)(
const std::string&,
const Model&);
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ 25 FUNC_NAME(const std::string& backend, const Model& model) { \ 26 if (backend == "native") { \ 27 return treelite::compiler::native::pred_transform::FUNC_NAME(model); \ 29 TREELITE_LOG(FATAL) << "Unrecognized backend: " << backend; \ 30 return std::string(); \ 34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(hinge)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential_standard_ratio)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
44 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
46 const std::unordered_map<std::string, PredTransformFuncGenerator>
48 PRED_TRANSFORM_FUNC(identity),
49 PRED_TRANSFORM_FUNC(signed_square),
50 PRED_TRANSFORM_FUNC(hinge),
51 PRED_TRANSFORM_FUNC(sigmoid),
52 PRED_TRANSFORM_FUNC(exponential),
53 PRED_TRANSFORM_FUNC(exponential_standard_ratio),
54 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
83 const std::unordered_map<std::string, PredTransformFuncGenerator>
84 pred_transform_multiclass_db = {
85 PRED_TRANSFORM_FUNC(identity_multiclass),
86 PRED_TRANSFORM_FUNC(max_index),
87 PRED_TRANSFORM_FUNC(softmax),
88 PRED_TRANSFORM_FUNC(multiclass_ova)
112 treelite::compiler::PredTransformFunction(
const std::string& backend,
113 const Model& model) {
114 ModelParam param = model.param;
115 if (model.task_param.num_class > 1) {
116 auto it = pred_transform_multiclass_db.find(param.pred_transform);
117 if (it == pred_transform_multiclass_db.end()) {
118 std::ostringstream oss;
119 for (
const auto& e : pred_transform_multiclass_db) {
120 oss <<
"'" << e.first <<
"', ";
122 TREELITE_LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 123 <<
"For multi-class classification, you should set " 124 <<
"`pred_transform` to one of the following: " 125 <<
"{ " << oss.str() <<
" }";
127 return (it->second)(backend, model);
129 auto it = pred_transform_db.find(param.pred_transform);
130 if (it == pred_transform_db.end()) {
131 std::ostringstream oss;
132 for (
const auto& e : pred_transform_db) {
133 oss <<
"'" << e.first <<
"', ";
135 TREELITE_LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 136 <<
"For any task that is NOT multi-class classification, you " 137 <<
"should set `pred_transform` to one of the following: " 138 <<
"{ " << oss.str() <<
" }";
140 return (it->second)(backend, model);
thin wrapper for tree ensemble model