8 #include "pred_transform.h" 10 #include <unordered_map> 12 #include "./native/pred_transform.h" 14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 19 using PredTransformFuncGenerator
20 = std::string (*)(
const std::string&,
const Model&);
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ 25 FUNC_NAME(const std::string& backend, const Model& model) { \ 26 if (backend == "native") { \ 27 return treelite::compiler::native::pred_transform::FUNC_NAME(model); \ 29 LOG(FATAL) << "Unrecognized backend: " << backend; \ 30 return std::string(); \ 34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
43 const std::unordered_map<std::string, PredTransformFuncGenerator>
45 PRED_TRANSFORM_FUNC(identity),
46 PRED_TRANSFORM_FUNC(sigmoid),
47 PRED_TRANSFORM_FUNC(exponential),
48 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
67 const std::unordered_map<std::string, PredTransformFuncGenerator>
68 pred_transform_multiclass_db = {
69 PRED_TRANSFORM_FUNC(identity_multiclass),
70 PRED_TRANSFORM_FUNC(max_index),
71 PRED_TRANSFORM_FUNC(softmax),
72 PRED_TRANSFORM_FUNC(multiclass_ova)
96 treelite::compiler::PredTransformFunction(
const std::string& backend,
98 if (model.num_output_group > 1) {
99 auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
100 if (it == pred_transform_multiclass_db.end()) {
101 std::ostringstream oss;
102 for (
const auto& e : pred_transform_multiclass_db) {
103 oss <<
"'" << e.first <<
"', ";
105 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 106 <<
"For multi-class classification, you should set " 107 <<
"`pred_transform` to one of the following: " 108 <<
"{ " << oss.str() <<
" }";
110 return (it->second)(backend, model);
112 auto it = pred_transform_db.find(model.param.pred_transform);
113 if (it == pred_transform_db.end()) {
114 std::ostringstream oss;
115 for (
const auto& e : pred_transform_db) {
116 oss <<
"'" << e.first <<
"', ";
118 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 119 <<
"For any task that is NOT multi-class classification, you " 120 <<
"should set `pred_transform` to one of the following: " 121 <<
"{ " << oss.str() <<
" }";
123 return (it->second)(backend, model);
// thin wrapper for tree ensemble model (stray note; presumably describes the Model type — confirm or remove)