8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 11 #include <dmlc/logging.h> 12 #include <fmt/format.h> 20 namespace pred_transform {
22 inline std::string identity(
const Model& model) {
24 R
"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ 26 }})TREELITETEMPLATE"); 29 inline std::string sigmoid(
const Model& model) {
30 const float alpha = model.param.sigmoid_alpha;
31 CHECK_GT(alpha, 0.0f) <<
"sigmoid: alpha must be strictly positive";
33 R
"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ 34 const float alpha = (float){alpha}; 35 return 1.0f / (1 + expf(-alpha * margin)); 40 inline std::string exponential(
const Model& model) {
42 R
"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ 44 }})TREELITETEMPLATE"); 47 inline std::string logarithm_one_plus_exp(
const Model& model) {
49 R
"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ 50 return log1pf(expf(margin)); 51 }})TREELITETEMPLATE"); 54 inline std::string identity_multiclass(
const Model& model) {
55 CHECK(model.num_output_group > 1)
56 <<
"identity_multiclass: model is not a proper multi-class classifier";
58 R
"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ 61 "num_class"_a = model.num_output_group);
64 inline std::string max_index(
const Model& model) {
65 CHECK(model.num_output_group > 1)
66 <<
"max_index: model is not a proper multi-class classifier";
68 R
"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ 69 const int num_class = {num_class}; 71 float max_margin = pred[0]; 72 for (int k = 1; k < num_class; ++k) {{ 73 if (pred[k] > max_margin) {{ 78 pred[0] = (float)max_index; 81 "num_class"_a = model.num_output_group);
84 inline std::string softmax(
const Model& model) {
85 CHECK(model.num_output_group > 1)
86 <<
"softmax: model is not a proper multi-class classifier";
88 R
"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ 89 const int num_class = {num_class}; 90 float max_margin = pred[0]; 91 double norm_const = 0.0; 93 for (int k = 1; k < num_class; ++k) {{ 94 if (pred[k] > max_margin) {{ 98 for (int k = 0; k < num_class; ++k) {{ 99 t = expf(pred[k] - max_margin); 103 for (int k = 0; k < num_class; ++k) {{ 104 pred[k] /= (float)norm_const; 106 return (size_t)num_class; 107 }})TREELITETEMPLATE", 108 "num_class"_a = model.num_output_group);
111 inline std::string multiclass_ova(
const Model& model) {
112 CHECK(model.num_output_group > 1)
113 <<
"multiclass_ova: model is not a proper multi-class classifier";
114 const int num_class = model.num_output_group;
115 const float alpha = model.param.sigmoid_alpha;
116 CHECK_GT(alpha, 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
118 R
"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ 119 const float alpha = (float){alpha}; 120 const int num_class = {num_class}; 121 for (int k = 0; k < num_class; ++k) {{ 122 pred[k] = 1.0f / (1.0f + expf(-alpha * pred[k])); 124 return (size_t)num_class; 125 }})TREELITETEMPLATE", 126 "num_class"_a = model.num_output_group,
"alpha"_a = alpha);
133 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_