treelite
pred_transform.h
1 #include <treelite/common.h>
2 
3 namespace treelite {
4 namespace compiler {
5 namespace pred_transform {
6 namespace native {
7 
8 inline std::string identity(const Model& model) {
9  return
10  "static inline float pred_transform(float margin) {\n"
11  " return margin;\n"
12  "}\n";
13 }
14 
15 inline std::string sigmoid(const Model& model) {
16  const float alpha = model.param.sigmoid_alpha;
17  CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
18  std::ostringstream oss;
19  oss << "static inline float pred_transform(float margin) {\n"
20  << " const float alpha = (float)" << common::ToString(alpha) << ";\n"
21  << " return 1.0f / (1 + expf(-alpha * margin));\n"
22  << "}\n";
23  return oss.str();
24 }
25 
26 inline std::string exponential(const Model& model) {
27  return
28  "static inline float pred_transform(float margin) {\n"
29  " return expf(margin);\n"
30  "}\n";
31 }
32 
33 inline std::string logarithm_one_plus_exp(const Model& model) {
34  return
35  "static inline float pred_transform(float margin) {\n"
36  " return log1pf(expf(margin));\n"
37  "}\n";
38 }
39 
40 inline std::string identity_multiclass(const Model& model) {
41  CHECK(model.num_output_group > 1)
42  << "identity_multiclass: model is not a proper multi-class classifier";
43  const int num_class = model.num_output_group;
44  std::ostringstream oss;
45  oss << "static inline size_t pred_transform(float* pred) {\n"
46  << " const size_t num_class = " << num_class << ";\n"
47  << " return num_class;\n"
48  << "}\n";
49  return oss.str();
50 }
51 
52 inline std::string max_index(const Model& model) {
53  CHECK(model.num_output_group > 1)
54  << "max_index: model is not a proper multi-class classifier";
55  const int num_class = model.num_output_group;
56  std::ostringstream oss;
57  oss << "static inline size_t pred_transform(float* pred) {\n"
58  << " const int num_class = " << num_class << ";\n"
59  << " int max_index = 0;\n"
60  << " float max_margin = pred[0];\n"
61  << " for (int k = 1; k < num_class; ++k) {\n"
62  << " if (pred[k] > max_margin) {\n"
63  << " max_margin = pred[k];\n"
64  << " max_index = k;\n"
65  << " }\n"
66  << " }\n"
67  << " pred[0] = (float)max_index;\n"
68  << " return 1;\n"
69  << "}\n";
70  return oss.str();
71 }
72 
73 inline std::string softmax(const Model& model) {
74  CHECK(model.num_output_group > 1)
75  << "softmax: model is not a proper multi-class classifier";
76  const int num_class = model.num_output_group;
77  std::ostringstream oss;
78  oss << "static inline size_t pred_transform(float* pred) {\n"
79  << " const int num_class = " << num_class << ";\n"
80  << " float max_margin = pred[0];\n"
81  << " double norm_const = 0.0;\n"
82  << " float t;\n"
83  << " for (int k = 1; k < num_class; ++k) {\n"
84  << " if (pred[k] > max_margin) {\n"
85  << " max_margin = pred[k];\n"
86  << " }\n"
87  << " }\n"
88  << " for (int k = 0; k < num_class; ++k) {\n"
89  << " t = expf(pred[k] - max_margin);\n"
90  << " norm_const += t;\n"
91  << " pred[k] = t;\n"
92  << " }\n"
93  << " for (int k = 0; k < num_class; ++k) {\n"
94  << " pred[k] /= (float)norm_const;\n"
95  << " }\n"
96  << " return (size_t)num_class;\n"
97  << "}\n";
98  return oss.str();
99 }
100 
101 inline std::string multiclass_ova(const Model& model) {
102  CHECK(model.num_output_group > 1)
103  << "multiclass_ova: model is not a proper multi-class classifier";
104  const int num_class = model.num_output_group;
105  const float alpha = model.param.sigmoid_alpha;
106  CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
107  std::ostringstream oss;
108  oss << "static inline size_t pred_transform(float* pred) {\n"
109  << " const float alpha = (float)" << common::ToString(alpha) << ";\n"
110  << " const int num_class = " << num_class << ";\n"
111  << " for (int k = 0; k < num_class; ++k) {\n"
112  << " pred[k] = 1.0f / (1.0f + expf(-alpha * pred[k]));\n"
113  << " }\n"
114  << " return (size_t)num_class;\n"
115  << "}\n";
116  return oss.str();
117 }
118 
119 } // namespace native
120 } // namespace pred_transform
121 } // namespace compiler
122 } // namespace treelite
Some useful utilities.