treelite
pred_transform.h
1 #include <treelite/common.h>
2 
3 namespace treelite {
4 namespace compiler {
5 namespace pred_transform {
6 namespace java {
7 
8 inline std::string identity(const Model& model) {
9  return
10  " private static float pred_transform(float margin) {\n"
11  " return margin;\n"
12  " }\n";
13 }
14 
15 inline std::string sigmoid(const Model& model) {
16  const float alpha = model.param.sigmoid_alpha;
17  CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
18  std::ostringstream oss;
19  oss << " private static float pred_transform(float margin) {\n"
20  << " final double alpha = " << common::ToString(alpha) << ";\n"
21  << " return (float)(1.0 / (1.0 + Math.exp(-alpha * margin)));\n"
22  << " }\n";
23  return oss.str();
24 }
25 
26 inline std::string exponential(const Model& model) {
27  return
28  " private static float pred_transform(float margin) {\n"
29  " return (float)Math.exp(margin);\n"
30  " }\n";
31 }
32 
33 inline std::string logarithm_one_plus_exp(const Model& model) {
34  return
35  " private static float pred_transform(float margin) {\n"
36  " return (float)Math.log1p(Math.exp(margin));\n"
37  " }\n";
38 }
39 
40 inline std::string identity_multiclass(const Model& model) {
41  CHECK(model.num_output_group > 1)
42  << "identity_multiclass: model is not a proper multi-class classifier";
43  const int num_class = model.num_output_group;
44  std::ostringstream oss;
45  oss << " private static long pred_transform(float[] pred) {\n"
46  << " final long num_class = " << num_class << ";\n"
47  << " return num_class;\n"
48  << " }\n";
49  return oss.str();
50 }
51 
52 inline std::string max_index(const Model& model) {
53  CHECK(model.num_output_group > 1)
54  << "max_index: model is not a proper multi-class classifier";
55  const int num_class = model.num_output_group;
56  std::ostringstream oss;
57  oss << " private static long pred_transform(float[] pred) {\n"
58  << " final int num_class = " << num_class << ";\n"
59  << " int max_index = 0;\n"
60  << " float max_margin = pred[0];\n"
61  << " for (int k = 1; k < num_class; ++k) {\n"
62  << " if (pred[k] > max_margin) {\n"
63  << " max_margin = pred[k];\n"
64  << " max_index = k;\n"
65  << " }\n"
66  << " }\n"
67  << " pred[0] = (float)max_index;\n"
68  << " return 1;\n"
69  << " }\n";
70  return oss.str();
71 }
72 
73 inline std::string softmax(const Model& model) {
74  CHECK(model.num_output_group > 1)
75  << "softmax: model is not a proper multi-class classifier";
76  const int num_class = model.num_output_group;
77  std::ostringstream oss;
78  oss << " private static long pred_transform(float[] pred) {\n"
79  << " final int num_class = " << num_class << ";\n"
80  << " float max_margin = pred[0];\n"
81  << " double norm_const = 0.0;\n"
82  << " double t;\n"
83  << " for (int k = 1; k < num_class; ++k) {\n"
84  << " if (pred[k] > max_margin) {\n"
85  << " max_margin = pred[k];\n"
86  << " }\n"
87  << " }\n"
88  << " for (int k = 0; k < num_class; ++k) {\n"
89  << " t = Math.exp(pred[k] - max_margin);\n"
90  << " norm_const += t;\n"
91  << " pred[k] = (float)t;\n"
92  << " }\n"
93  << " for (int k = 0; k < num_class; ++k) {\n"
94  << " pred[k] /= (float)norm_const;\n"
95  << " }\n"
96  << " return (long)num_class;\n"
97  << " }\n";
98  return oss.str();
99 }
100 
101 inline std::string multiclass_ova(const Model& model) {
102  CHECK(model.num_output_group > 1)
103  << "multiclass_ova: model is not a proper multi-class classifier";
104  const int num_class = model.num_output_group;
105  const float alpha = model.param.sigmoid_alpha;
106  CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
107  std::ostringstream oss;
108  oss << " private static long pred_transform(float[] pred) {\n"
109  << " final float alpha = (float)" << common::ToString(alpha) << ";\n"
110  << " final int num_class = " << num_class << ";\n"
111  << " for (int k = 0; k < num_class; ++k) {\n"
112  << " pred[k] = (float)(1.0 / (1.0 + Math.exp(-alpha * pred[k])));\n"
113  << " }\n"
114  << " return (long)num_class;\n"
115  << " }\n";
116  return oss.str();
117 }
118 
119 } // namespace java
120 } // namespace pred_transform
121 } // namespace compiler
122 } // namespace treelite
Some useful utilities.