Treelite
pred_transform.h
1 
8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
10 
11 #include <treelite/logging.h>
12 #include <fmt/format.h>
13 #include <string>
14 #include "./typeinfo_ctypes.h"
15 
16 using namespace fmt::literals;
17 
18 namespace treelite {
19 namespace compiler {
20 namespace native {
21 namespace pred_transform {
22 
23 inline std::string identity(const Model& model) {
24  return fmt::format(
25 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
26  return margin;
27 }})TREELITETEMPLATE",
28 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
29 }
30 
31 inline std::string signed_square(const Model& model) {
32  const TypeInfo threshold_type = model.GetThresholdType();
33  return fmt::format(
34  R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
35  return {copysign}(margin * margin, margin);
36 }})TREELITETEMPLATE",
37  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
38  "copysign"_a = native::CCopySignForTypeInfo(threshold_type));
39 }
40 
41 inline std::string hinge(const Model& model) {
42  const TypeInfo threshold_type = model.GetThresholdType();
43  return fmt::format(
44  R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
45  if (margin > 0) {{
46  return ({threshold_type})(1);
47  }} else {{
48  return ({threshold_type})(0);
49  }}
50 }})TREELITETEMPLATE",
51  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
52 }
53 
54 inline std::string sigmoid(const Model& model) {
55  const float alpha = model.param.sigmoid_alpha;
56  const TypeInfo threshold_type = model.GetThresholdType();
57  TREELITE_CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
58  return fmt::format(
59 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
60  const {threshold_type} alpha = ({threshold_type}){alpha};
61  return ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * margin));
62 }})TREELITETEMPLATE",
63  "alpha"_a = alpha,
64  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
65  "exp"_a = native::CExpForTypeInfo(threshold_type));
66 }
67 
68 inline std::string exponential(const Model& model) {
69  const TypeInfo threshold_type = model.GetThresholdType();
70  return fmt::format(
71 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
72  return {exp}(margin);
73 }})TREELITETEMPLATE",
74  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
75  "exp"_a = native::CExpForTypeInfo(threshold_type));
76 }
77 
78 inline std::string exponential_standard_ratio(const Model& model) {
79  const float ratio_c = model.param.ratio_c;
80  const TypeInfo threshold_type = model.GetThresholdType();
81  return fmt::format(
82 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
83  return {exp2}(-margin / ({threshold_type}){ratio_c});
84 }})TREELITETEMPLATE",
85  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
86  "ratio_c"_a = ratio_c,
87  "exp2"_a = native::CExp2ForTypeInfo(threshold_type));
88 }
89 
90 inline std::string logarithm_one_plus_exp(const Model& model) {
91  const TypeInfo threshold_type = model.GetThresholdType();
92  return fmt::format(
93 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
94  return {log1p}({exp}(margin));
95 }})TREELITETEMPLATE",
96  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
97  "exp"_a = native::CExpForTypeInfo(threshold_type),
98  "log1p"_a = native::CLog1PForTypeInfo(threshold_type));
99 }
100 
101 inline std::string identity_multiclass(const Model& model) {
102  TREELITE_CHECK_GT(model.task_param.num_class, 1)
103  << "identity_multiclass: model is not a proper multi-class classifier";
104  return fmt::format(
105 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
106  return {num_class};
107 }})TREELITETEMPLATE",
108  "num_class"_a = model.task_param.num_class,
109  "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
110 }
111 
112 inline std::string max_index(const Model& model) {
113  TREELITE_CHECK_GT(model.task_param.num_class, 1)
114  << "max_index: model is not a proper multi-class classifier";
115  const TypeInfo threshold_type = model.GetThresholdType();
116  return fmt::format(
117 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
118  const int num_class = {num_class};
119  int max_index = 0;
120  {threshold_type} max_margin = pred[0];
121  for (int k = 1; k < num_class; ++k) {{
122  if (pred[k] > max_margin) {{
123  max_margin = pred[k];
124  max_index = k;
125  }}
126  }}
127  pred[0] = ({threshold_type})max_index;
128  return 1;
129 }})TREELITETEMPLATE",
130  "num_class"_a = model.task_param.num_class,
131  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
132 }
133 
134 inline std::string softmax(const Model& model) {
135  TREELITE_CHECK_GT(model.task_param.num_class, 1)
136  << "softmax: model is not a proper multi-class classifier";
137  const TypeInfo threshold_type = model.GetThresholdType();
138  return fmt::format(
139 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
140  const int num_class = {num_class};
141  {threshold_type} max_margin = pred[0];
142  double norm_const = 0.0;
143  {threshold_type} t;
144  for (int k = 1; k < num_class; ++k) {{
145  if (pred[k] > max_margin) {{
146  max_margin = pred[k];
147  }}
148  }}
149  for (int k = 0; k < num_class; ++k) {{
150  t = {exp}(pred[k] - max_margin);
151  norm_const += t;
152  pred[k] = t;
153  }}
154  for (int k = 0; k < num_class; ++k) {{
155  pred[k] /= ({threshold_type})norm_const;
156  }}
157  return (size_t)num_class;
158 }})TREELITETEMPLATE",
159  "num_class"_a = model.task_param.num_class,
160  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
161  "exp"_a = native::CExpForTypeInfo(threshold_type));
162 }
163 
164 inline std::string multiclass_ova(const Model& model) {
165  TREELITE_CHECK(model.task_param.num_class > 1)
166  << "multiclass_ova: model is not a proper multi-class classifier";
167  const unsigned int num_class = model.task_param.num_class;
168  const float alpha = model.param.sigmoid_alpha;
169  const TypeInfo threshold_type = model.GetThresholdType();
170  TREELITE_CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
171  return fmt::format(
172 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
173  const {threshold_type} alpha = ({threshold_type}){alpha};
174  const int num_class = {num_class};
175  for (int k = 0; k < num_class; ++k) {{
176  pred[k] = ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * pred[k]));
177  }}
178  return (size_t)num_class;
179 }})TREELITETEMPLATE",
180  "num_class"_a = num_class, "alpha"_a = alpha,
181  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
182  "exp"_a = native::CExpForTypeInfo(threshold_type));
183 }
184 
185 } // namespace pred_transform
186 } // namespace native
187 } // namespace compiler
188 } // namespace treelite
189 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
logging facility for Treelite
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
Look up C symbols corresponding to TypeInfo.