Treelite
pred_transform.h
1 
8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
10 
11 #include <treelite/logging.h>
12 #include <fmt/format.h>
13 #include <string>
14 #include "./typeinfo_ctypes.h"
15 
16 using namespace fmt::literals;
17 
18 namespace treelite {
19 namespace compiler {
20 namespace native {
21 namespace pred_transform {
22 
23 inline std::string identity(const Model& model) {
24  return fmt::format(
25 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
26  return margin;
27 }})TREELITETEMPLATE",
28 "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
29 }
30 
31 inline std::string signed_square(const Model& model) {
32  const TypeInfo threshold_type = model.GetThresholdType();
33  return fmt::format(
34  R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
35  return {copysign}(margin * margin, margin);
36 }})TREELITETEMPLATE",
37  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
38  "copysign"_a = native::CCopySignForTypeInfo(threshold_type));
39 }
40 
41 inline std::string hinge(const Model& model) {
42  const TypeInfo threshold_type = model.GetThresholdType();
43  return fmt::format(
44  R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
45  if (margin > 0) {{
46  return ({threshold_type})(1);
47  }} else {{
48  return ({threshold_type})(0);
49  }}
50 }})TREELITETEMPLATE",
51  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
52 }
53 
54 inline std::string sigmoid(const Model& model) {
55  const float alpha = model.param.sigmoid_alpha;
56  const TypeInfo threshold_type = model.GetThresholdType();
57  TREELITE_CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
58  return fmt::format(
59 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
60  const {threshold_type} alpha = ({threshold_type}){alpha};
61  return ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * margin));
62 }})TREELITETEMPLATE",
63  "alpha"_a = alpha,
64  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
65  "exp"_a = native::CExpForTypeInfo(threshold_type));
66 }
67 
68 inline std::string exponential(const Model& model) {
69  const TypeInfo threshold_type = model.GetThresholdType();
70  return fmt::format(
71 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
72  return {exp}(margin);
73 }})TREELITETEMPLATE",
74  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
75  "exp"_a = native::CExpForTypeInfo(threshold_type));
76 }
77 
78 inline std::string logarithm_one_plus_exp(const Model& model) {
79  const TypeInfo threshold_type = model.GetThresholdType();
80  return fmt::format(
81 R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
82  return {log1p}({exp}(margin));
83 }})TREELITETEMPLATE",
84  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
85  "exp"_a = native::CExpForTypeInfo(threshold_type),
86  "log1p"_a = native::CLog1PForTypeInfo(threshold_type));
87 }
88 
89 inline std::string identity_multiclass(const Model& model) {
90  TREELITE_CHECK_GT(model.task_param.num_class, 1)
91  << "identity_multiclass: model is not a proper multi-class classifier";
92  return fmt::format(
93 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
94  return {num_class};
95 }})TREELITETEMPLATE",
96  "num_class"_a = model.task_param.num_class,
97  "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
98 }
99 
100 inline std::string max_index(const Model& model) {
101  TREELITE_CHECK_GT(model.task_param.num_class, 1)
102  << "max_index: model is not a proper multi-class classifier";
103  const TypeInfo threshold_type = model.GetThresholdType();
104  return fmt::format(
105 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
106  const int num_class = {num_class};
107  int max_index = 0;
108  {threshold_type} max_margin = pred[0];
109  for (int k = 1; k < num_class; ++k) {{
110  if (pred[k] > max_margin) {{
111  max_margin = pred[k];
112  max_index = k;
113  }}
114  }}
115  pred[0] = ({threshold_type})max_index;
116  return 1;
117 }})TREELITETEMPLATE",
118  "num_class"_a = model.task_param.num_class,
119  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type));
120 }
121 
122 inline std::string softmax(const Model& model) {
123  TREELITE_CHECK_GT(model.task_param.num_class, 1)
124  << "softmax: model is not a proper multi-class classifier";
125  const TypeInfo threshold_type = model.GetThresholdType();
126  return fmt::format(
127 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
128  const int num_class = {num_class};
129  {threshold_type} max_margin = pred[0];
130  double norm_const = 0.0;
131  {threshold_type} t;
132  for (int k = 1; k < num_class; ++k) {{
133  if (pred[k] > max_margin) {{
134  max_margin = pred[k];
135  }}
136  }}
137  for (int k = 0; k < num_class; ++k) {{
138  t = {exp}(pred[k] - max_margin);
139  norm_const += t;
140  pred[k] = t;
141  }}
142  for (int k = 0; k < num_class; ++k) {{
143  pred[k] /= ({threshold_type})norm_const;
144  }}
145  return (size_t)num_class;
146 }})TREELITETEMPLATE",
147  "num_class"_a = model.task_param.num_class,
148  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
149  "exp"_a = native::CExpForTypeInfo(threshold_type));
150 }
151 
152 inline std::string multiclass_ova(const Model& model) {
153  TREELITE_CHECK(model.task_param.num_class > 1)
154  << "multiclass_ova: model is not a proper multi-class classifier";
155  const unsigned int num_class = model.task_param.num_class;
156  const float alpha = model.param.sigmoid_alpha;
157  const TypeInfo threshold_type = model.GetThresholdType();
158  TREELITE_CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
159  return fmt::format(
160 R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{
161  const {threshold_type} alpha = ({threshold_type}){alpha};
162  const int num_class = {num_class};
163  for (int k = 0; k < num_class; ++k) {{
164  pred[k] = ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * pred[k]));
165  }}
166  return (size_t)num_class;
167 }})TREELITETEMPLATE",
168  "num_class"_a = num_class, "alpha"_a = alpha,
169  "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
170  "exp"_a = native::CExpForTypeInfo(threshold_type));
171 }
172 
173 } // namespace pred_transform
174 } // namespace native
175 } // namespace compiler
176 } // namespace treelite
177 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
logging facility for Treelite
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Look up C symbols corresponding to TypeInfo.