8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_ 12 #include <fmt/format.h> 21 namespace pred_transform {
23 inline std::string identity(
const Model& model) {
25 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 28 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
31 inline std::string signed_square(
const Model& model) {
32 const TypeInfo threshold_type = model.GetThresholdType();
34 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 35 return {copysign}(margin * margin, margin); 37 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
38 "copysign"_a = native::CCopySignForTypeInfo(threshold_type));
41 inline std::string hinge(
const Model& model) {
42 const TypeInfo threshold_type = model.GetThresholdType();
44 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 46 return ({threshold_type})(1); 48 return ({threshold_type})(0); 51 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
54 inline std::string sigmoid(
const Model& model) {
55 const float alpha = model.param.sigmoid_alpha;
56 const TypeInfo threshold_type = model.GetThresholdType();
57 TREELITE_CHECK_GT(alpha, 0.0f) <<
"sigmoid: alpha must be strictly positive";
59 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 60 const {threshold_type} alpha = ({threshold_type}){alpha}; 61 return ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * margin)); 64 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
65 "exp"_a = native::CExpForTypeInfo(threshold_type));
68 inline std::string exponential(
const Model& model) {
69 const TypeInfo threshold_type = model.GetThresholdType();
71 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 74 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
75 "exp"_a = native::CExpForTypeInfo(threshold_type));
78 inline std::string exponential_standard_ratio(
const Model& model) {
79 const float ratio_c = model.param.ratio_c;
80 const TypeInfo threshold_type = model.GetThresholdType();
82 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 83 return {exp2}(-margin / ({threshold_type}){ratio_c}); 85 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
86 "ratio_c"_a = ratio_c,
87 "exp2"_a = native::CExp2ForTypeInfo(threshold_type));
90 inline std::string logarithm_one_plus_exp(
const Model& model) {
91 const TypeInfo threshold_type = model.GetThresholdType();
93 R
"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ 94 return {log1p}({exp}(margin)); 96 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
97 "exp"_a = native::CExpForTypeInfo(threshold_type),
98 "log1p"_a = native::CLog1PForTypeInfo(threshold_type));
101 inline std::string identity_multiclass(
const Model& model) {
102 TREELITE_CHECK_GT(model.task_param.num_class, 1)
103 <<
"identity_multiclass: model is not a proper multi-class classifier";
105 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 107 }})TREELITETEMPLATE", 108 "num_class"_a = model.task_param.num_class,
109 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
112 inline std::string max_index(
const Model& model) {
113 TREELITE_CHECK_GT(model.task_param.num_class, 1)
114 <<
"max_index: model is not a proper multi-class classifier";
115 const TypeInfo threshold_type = model.GetThresholdType();
117 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 118 const int num_class = {num_class}; 120 {threshold_type} max_margin = pred[0]; 121 for (int k = 1; k < num_class; ++k) {{ 122 if (pred[k] > max_margin) {{ 123 max_margin = pred[k]; 127 pred[0] = ({threshold_type})max_index; 129 }})TREELITETEMPLATE", 130 "num_class"_a = model.task_param.num_class,
131 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
134 inline std::string softmax(
const Model& model) {
135 TREELITE_CHECK_GT(model.task_param.num_class, 1)
136 <<
"softmax: model is not a proper multi-class classifier";
137 const TypeInfo threshold_type = model.GetThresholdType();
139 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 140 const int num_class = {num_class}; 141 {threshold_type} max_margin = pred[0]; 142 double norm_const = 0.0; 144 for (int k = 1; k < num_class; ++k) {{ 145 if (pred[k] > max_margin) {{ 146 max_margin = pred[k]; 149 for (int k = 0; k < num_class; ++k) {{ 150 t = {exp}(pred[k] - max_margin); 154 for (int k = 0; k < num_class; ++k) {{ 155 pred[k] /= ({threshold_type})norm_const; 157 return (size_t)num_class; 158 }})TREELITETEMPLATE", 159 "num_class"_a = model.task_param.num_class,
160 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
161 "exp"_a = native::CExpForTypeInfo(threshold_type));
164 inline std::string multiclass_ova(
const Model& model) {
165 TREELITE_CHECK(model.task_param.num_class > 1)
166 <<
"multiclass_ova: model is not a proper multi-class classifier";
167 const unsigned int num_class = model.task_param.num_class;
168 const float alpha = model.param.sigmoid_alpha;
169 const TypeInfo threshold_type = model.GetThresholdType();
170 TREELITE_CHECK_GT(alpha, 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
172 R
"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ 173 const {threshold_type} alpha = ({threshold_type}){alpha}; 174 const int num_class = {num_class}; 175 for (int k = 0; k < num_class; ++k) {{ 176 pred[k] = ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * pred[k])); 178 return (size_t)num_class; 179 }})TREELITETEMPLATE", 180 "num_class"_a = num_class,
"alpha"_a = alpha,
181 "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
182 "exp"_a = native::CExpForTypeInfo(threshold_type));
189 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
// Logging facility for Treelite.
// TypeInfo:
//   Types used by thresholds and leaf outputs.
// Look up C symbols corresponding to TypeInfo.