8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 11 #include <dmlc/logging.h> 12 #include <fmt/format.h> 21 namespace pred_transform {
23 inline std::string identity(
const Model& model) {
25 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 28 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
31 inline std::string signed_square(
const Model& model) {
32 const TypeInfo threshold_type = model.GetThresholdType();
34 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 35 return {copysign}(margin * margin, margin); 37 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
38 "copysign"_a = native::CCopySignForTypeInfo(threshold_type));
41 inline std::string hinge(
const Model& model) {
42 const TypeInfo threshold_type = model.GetThresholdType();
44 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 46 return ({threshold_type})(1); 48 return ({threshold_type})(0); 51 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
54 inline std::string sigmoid(
const Model& model) {
55 const float alpha = model.param.sigmoid_alpha;
56 const TypeInfo threshold_type = model.GetThresholdType();
57 CHECK_GT(alpha, 0.0f) <<
"sigmoid: alpha must be strictly positive";
59 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 60 const {threshold_type} alpha = ({threshold_type}){alpha}; 61 return ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * margin)); 64 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
65 "exp"_a = native::CExpForTypeInfo(threshold_type));
68 inline std::string exponential(
const Model& model) {
69 const TypeInfo threshold_type = model.GetThresholdType();
71 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 74 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
75 "exp"_a = native::CExpForTypeInfo(threshold_type));
78 inline std::string logarithm_one_plus_exp(
const Model& model) {
79 const TypeInfo threshold_type = model.GetThresholdType();
81 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 82 return {log1p}({exp}(margin)); 84 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
85 "exp"_a = native::CExpForTypeInfo(threshold_type),
86 "log1p"_a = native::CLog1PForTypeInfo(threshold_type));
89 inline std::string identity_multiclass(
const Model& model) {
90 CHECK_GT(model.task_param.num_class, 1)
91 <<
"identity_multiclass: model is not a proper multi-class classifier";
93 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 96 "num_class"_a = model.task_param.num_class,
97 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
100 inline std::string max_index(
const Model& model) {
101 CHECK_GT(model.task_param.num_class, 1)
102 <<
"max_index: model is not a proper multi-class classifier";
103 const TypeInfo threshold_type = model.GetThresholdType();
105 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 106 const int num_class = {num_class}; 108 {threshold_type} max_margin = pred[0]; 109 for (int k = 1; k < num_class; ++k) {{ 110 if (pred[k] > max_margin) {{ 111 max_margin = pred[k]; 115 pred[0] = ({threshold_type})max_index; 117 }})TREELITETEMPLATE", 118 "num_class"_a = model.task_param.num_class,
119 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
122 inline std::string softmax(
const Model& model) {
123 CHECK_GT(model.task_param.num_class, 1)
124 <<
"softmax: model is not a proper multi-class classifier";
125 const TypeInfo threshold_type = model.GetThresholdType();
127 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 128 const int num_class = {num_class}; 129 {threshold_type} max_margin = pred[0]; 130 double norm_const = 0.0; 132 for (int k = 1; k < num_class; ++k) {{ 133 if (pred[k] > max_margin) {{ 134 max_margin = pred[k]; 137 for (int k = 0; k < num_class; ++k) {{ 138 t = {exp}(pred[k] - max_margin); 142 for (int k = 0; k < num_class; ++k) {{ 143 pred[k] /= ({threshold_type})norm_const; 145 return (size_t)num_class; 146 }})TREELITETEMPLATE", 147 "num_class"_a = model.task_param.num_class,
148 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
149 "exp"_a = native::CExpForTypeInfo(threshold_type));
152 inline std::string multiclass_ova(
const Model& model) {
153 CHECK(model.task_param.num_class > 1)
154 <<
"multiclass_ova: model is not a proper multi-class classifier";
155 const unsigned int num_class = model.task_param.num_class;
156 const float alpha = model.param.sigmoid_alpha;
157 const TypeInfo threshold_type = model.GetThresholdType();
158 CHECK_GT(alpha, 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
160 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 161 const {threshold_type} alpha = ({threshold_type}){alpha}; 162 const int num_class = {num_class}; 163 for (int k = 0; k < num_class; ++k) {{ 164 pred[k] = ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * pred[k])); 166 return (size_t)num_class; 167 }})TREELITETEMPLATE", 168 "num_class"_a = num_class,
"alpha"_a = alpha,
169 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
170 "exp"_a = native::CExpForTypeInfo(threshold_type));
177 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
// TypeInfo: types used by thresholds and leaf outputs.
// The native::*ForTypeInfo / TypeInfoToCTypeString helpers look up the C symbols
// corresponding to a TypeInfo.