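/*!
 * treelite
 * \file pred_transform.h
 * \brief Generators for the C source of the `pred_transform` routine used by
 *        the native compiler backend
 */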
8 #ifndef TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
9 #define TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_
10 
#include <treelite/common.h>
#include <fmt/format.h>
#include <cmath>
#include <string>
14 
15 using namespace fmt::literals;
16 
17 namespace treelite {
18 namespace compiler {
19 namespace native {
20 namespace pred_transform {
21 
22 inline std::string identity(const Model& model) {
23  return fmt::format(
24 R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{
25  return margin;
26 }})TREELITETEMPLATE");
27 }
28 
29 inline std::string sigmoid(const Model& model) {
30  const float alpha = model.param.sigmoid_alpha;
31  CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
32  return fmt::format(
33 R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{
34  const float alpha = (float){alpha};
35  return 1.0f / (1 + expf(-alpha * margin));
36 }})TREELITETEMPLATE",
37  "alpha"_a = alpha);
38 }
39 
40 inline std::string exponential(const Model& model) {
41  return fmt::format(
42 R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{
43  return expf(margin);
44 }})TREELITETEMPLATE");
45 }
46 
47 inline std::string logarithm_one_plus_exp(const Model& model) {
48  return fmt::format(
49 R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{
50  return log1pf(expf(margin));
51 }})TREELITETEMPLATE");
52 }
53 
54 inline std::string identity_multiclass(const Model& model) {
55  CHECK(model.num_output_group > 1)
56  << "identity_multiclass: model is not a proper multi-class classifier";
57  return fmt::format(
58 R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{
59  return {num_class};
60 }})TREELITETEMPLATE",
61  "num_class"_a = model.num_output_group);
62 }
63 
64 inline std::string max_index(const Model& model) {
65  CHECK(model.num_output_group > 1)
66  << "max_index: model is not a proper multi-class classifier";
67  return fmt::format(
68 R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{
69  const int num_class = {num_class};
70  int max_index = 0;
71  float max_margin = pred[0];
72  for (int k = 1; k < num_class; ++k) {{
73  if (pred[k] > max_margin) {{
74  max_margin = pred[k];
75  max_index = k;
76  }}
77  }}
78  pred[0] = (float)max_index;
79  return 1;
80 }})TREELITETEMPLATE",
81  "num_class"_a = model.num_output_group);
82 }
83 
84 inline std::string softmax(const Model& model) {
85  CHECK(model.num_output_group > 1)
86  << "softmax: model is not a proper multi-class classifier";
87  return fmt::format(
88 R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{
89  const int num_class = {num_class};
90  float max_margin = pred[0];
91  double norm_const = 0.0;
92  float t;
93  for (int k = 1; k < num_class; ++k) {{
94  if (pred[k] > max_margin) {{
95  max_margin = pred[k];
96  }}
97  }}
98  for (int k = 0; k < num_class; ++k) {{
99  t = expf(pred[k] - max_margin);
100  norm_const += t;
101  pred[k] = t;
102  }}
103  for (int k = 0; k < num_class; ++k) {{
104  pred[k] /= (float)norm_const;
105  }}
106  return (size_t)num_class;
107 }})TREELITETEMPLATE",
108  "num_class"_a = model.num_output_group);
109 }
110 
111 inline std::string multiclass_ova(const Model& model) {
112  CHECK(model.num_output_group > 1)
113  << "multiclass_ova: model is not a proper multi-class classifier";
114  const int num_class = model.num_output_group;
115  const float alpha = model.param.sigmoid_alpha;
116  CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
117  return fmt::format(
118 R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{
119  const float alpha = (float){alpha};
120  const int num_class = {num_class};
121  for (int k = 0; k < num_class; ++k) {{
122  pred[k] = 1.0f / (1.0f + expf(-alpha * pred[k]));
123  }}
124  return (size_t)num_class;
125 }})TREELITETEMPLATE",
126  "num_class"_a = model.num_output_group, "alpha"_a = alpha);
127 }
128 
129 } // namespace pred_transform
130 } // namespace native
131 } // namespace compiler
132 } // namespace treelite
133 #endif // TREELITE_COMPILER_NATIVE_PRED_TRANSFORM_H_