10 #include <unordered_map> 11 #include "pred_transform.h" 13 #include "./native/pred_transform.h" 14 #include "./java/pred_transform.h" 16 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 21 using PredTransformFuncGenerator
22 = std::string (*)(
const std::string&,
const Model&);
25 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \ 27 FUNC_NAME(const std::string& backend, const Model& model) { \ 28 if (backend == "native") { \ 29 return treelite::compiler::pred_transform::native::FUNC_NAME(model); \ 30 } else if (backend == "java") { \ 31 return treelite::compiler::pred_transform::java::FUNC_NAME(model); \ 33 LOG(FATAL) << "Unrecognized backend: " << backend; \ 37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
44 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
46 const std::unordered_map<std::
string, PredTransformFuncGenerator>
48 PRED_TRANSFORM_FUNC(identity),
49 PRED_TRANSFORM_FUNC(sigmoid),
50 PRED_TRANSFORM_FUNC(exponential),
51 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
70 const std::unordered_map<std::string, PredTransformFuncGenerator>
71 pred_transform_multiclass_db = {
72 PRED_TRANSFORM_FUNC(identity_multiclass),
73 PRED_TRANSFORM_FUNC(max_index),
74 PRED_TRANSFORM_FUNC(softmax),
75 PRED_TRANSFORM_FUNC(multiclass_ova)
99 treelite::compiler::PredTransformFunction(
const std::string& backend,
100 const Model& model) {
101 if (model.num_output_group > 1) {
102 auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
103 if (it == pred_transform_multiclass_db.end()) {
104 std::ostringstream oss;
105 for (
const auto& e : pred_transform_multiclass_db) {
106 oss <<
"'" << e.first <<
"', ";
108 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 109 <<
"For multi-class classification, you should set " 110 <<
"`pred_transform` to one of the following: " 111 <<
"{ " << oss.str() <<
" }";
113 return (it->second)(backend, model);
115 auto it = pred_transform_db.find(model.param.pred_transform);
116 if (it == pred_transform_db.end()) {
117 std::ostringstream oss;
118 for (
const auto& e : pred_transform_db) {
119 oss <<
"'" << e.first <<
"', ";
121 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 122 <<
"For any task that is NOT multi-class classification, you " 123 <<
"should set `pred_transform` to one of the following: " 124 <<
"{ " << oss.str() <<
" }";
126 return (it->second)(backend, model);
// Thin wrapper for a tree ensemble model.