9 #include <unordered_map> 11 #include "pred_transform.h" 13 #include "./native/pred_transform.h" 15 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 20 using PredTransformFuncGenerator
21 = std::string (*)(
const std::string&,
const Model&);
24 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ 26 FUNC_NAME(const std::string& backend, const Model& model) { \ 27 if (backend == "native") { \ 28 return treelite::compiler::native::pred_transform::FUNC_NAME(model); \ 30 LOG(FATAL) << "Unrecognized backend: " << backend; \ 31 return std::string(); \ 35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
44 const std::unordered_map<std::string, PredTransformFuncGenerator>
46 PRED_TRANSFORM_FUNC(identity),
47 PRED_TRANSFORM_FUNC(sigmoid),
48 PRED_TRANSFORM_FUNC(exponential),
49 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
68 const std::unordered_map<std::string, PredTransformFuncGenerator>
69 pred_transform_multiclass_db = {
70 PRED_TRANSFORM_FUNC(identity_multiclass),
71 PRED_TRANSFORM_FUNC(max_index),
72 PRED_TRANSFORM_FUNC(softmax),
73 PRED_TRANSFORM_FUNC(multiclass_ova)
97 treelite::compiler::PredTransformFunction(
const std::string& backend,
99 if (model.num_output_group > 1) {
100 auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
101 if (it == pred_transform_multiclass_db.end()) {
102 std::ostringstream oss;
103 for (
const auto& e : pred_transform_multiclass_db) {
104 oss <<
"'" << e.first <<
"', ";
106 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 107 <<
"For multi-class classification, you should set " 108 <<
"`pred_transform` to one of the following: " 109 <<
"{ " << oss.str() <<
" }";
111 return (it->second)(backend, model);
113 auto it = pred_transform_db.find(model.param.pred_transform);
114 if (it == pred_transform_db.end()) {
115 std::ostringstream oss;
116 for (
const auto& e : pred_transform_db) {
117 oss <<
"'" << e.first <<
"', ";
119 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 120 <<
"For any task that is NOT multi-class classification, you " 121 <<
"should set `pred_transform` to one of the following: " 122 <<
"{ " << oss.str() <<
" }";
124 return (it->second)(backend, model);
// Thin wrapper for tree ensemble model.