Treelite
pred_transform.cc
1 
#include "./pred_transform.h"
#include <treelite/gtil.h>
#include <treelite/tree.h>
#include <dmlc/logging.h>
#include <string>
#include <unordered_map>
#include <cmath>
#include <cstddef>
#include <stdexcept>
16 
17 namespace treelite {
18 namespace gtil {
19 namespace pred_transform {
20 
21 std::size_t identity(const treelite::Model&, const float* in, float* out) {
22  *out = *in;
23  return 1;
24 }
25 
26 std::size_t signed_square(const treelite::Model&, const float* in, float* out) {
27  const float margin = *in;
28  *out = std::copysign(margin * margin, margin);
29  return 1;
30 }
31 
32 std::size_t hinge(const treelite::Model&, const float* in, float* out) {
33  *out = (*in > 0 ? 1.0f : 0.0f);
34  return 1;
35 }
36 
37 std::size_t sigmoid(const treelite::Model& model, const float* in, float* out) {
38  const float alpha = model.param.sigmoid_alpha;
39  CHECK(alpha > 0.0f) << "sigmoid: alpha must be strictly positive";
40  *out = 1.0f / (1.0f + std::exp(-alpha * *in));
41  return 1;
42 }
43 
44 std::size_t exponential(const treelite::Model&, const float* in, float* out) {
45  *out = std::exp(*in);
46  return 1;
47 }
48 
49 std::size_t logarithm_one_plus_exp(const treelite::Model&, const float* in, float* out) {
50  *out = std::log1p(std::exp(*in));
51  return 1;
52 }
53 
54 std::size_t identity_multiclass(const treelite::Model& model, const float* in, float* out) {
55  auto num_class = static_cast<std::size_t>(model.task_param.num_class);
56  CHECK(num_class > 1) << "model must be a multi-class classifier";
57  for (std::size_t i = 0; i < num_class; ++i) {
58  out[i] = in[i];
59  }
60  return num_class;
61 }
62 
63 std::size_t max_index(const treelite::Model& model, const float* in, float* out) {
64  auto num_class = static_cast<std::size_t>(model.task_param.num_class);
65  CHECK(num_class > 1) << "model must be a multi-class classifier";
66  std::size_t max_index = 0;
67  float max_margin = in[0];
68  for (std::size_t i = 1; i < num_class; ++i) {
69  if (in[i] > max_margin) {
70  max_margin = in[i];
71  max_index = i;
72  }
73  }
74  out[0] = static_cast<float>(max_index);
75  return 1;
76 }
77 
78 std::size_t softmax(const treelite::Model& model, const float* in, float* out) {
79  auto num_class = static_cast<std::size_t>(model.task_param.num_class);
80  CHECK(num_class > 1) << "model must be a multi-class classifier";
81  float max_margin = in[0];
82  double norm_const = 0.0;
83  float t;
84  for (std::size_t i = 1; i < num_class; ++i) {
85  if (in[i] > max_margin) {
86  max_margin = in[i];
87  }
88  }
89  for (std::size_t i = 0; i < num_class; ++i) {
90  t = std::exp(in[i] - max_margin);
91  norm_const += t;
92  out[i] = t;
93  }
94  for (std::size_t i = 0; i < num_class; ++i) {
95  out[i] /= static_cast<float>(norm_const);
96  }
97  return num_class;
98 }
99 
100 std::size_t multiclass_ova(const treelite::Model& model, const float* in, float* out) {
101  auto num_class = static_cast<std::size_t>(model.task_param.num_class);
102  CHECK(num_class > 1) << "model must be a multi-class classifier";
103  const float alpha = model.param.sigmoid_alpha;
104  CHECK(alpha > 0.0f) << "multiclass_ova: alpha must be strictly positive";
105  for (std::size_t i = 0; i < num_class; ++i) {
106  out[i] = 1.0f / (1.0f + std::exp(-alpha * in[i]));
107  }
108  return num_class;
109 }
110 
111 } // namespace pred_transform
112 
113 const std::unordered_map<std::string, PredTransformFuncType> pred_transform_func{
114  {"identity", pred_transform::identity},
115  {"signed_square", pred_transform::signed_square},
116  {"hinge", pred_transform::hinge},
117  {"sigmoid", pred_transform::sigmoid},
118  {"exponential", pred_transform::exponential},
119  {"logarithm_one_plus_exp", pred_transform::logarithm_one_plus_exp},
120  {"identity_multiclass", pred_transform::identity_multiclass},
121  {"max_index", pred_transform::max_index},
122  {"softmax", pred_transform::softmax},
123  {"multiclass_ova", pred_transform::multiclass_ova}
124 };
125 
126 PredTransformFuncType LookupPredTransform(const std::string& name) {
127  return pred_transform_func.at(name);
128 }
129 
130 } // namespace gtil
131 } // namespace treelite
ModelParam param
extra parameters
Definition: tree.h:681
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:599
model structure for tree ensemble
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decision trees.
TaskParameter task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:679
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:183
thin wrapper for tree ensemble model
Definition: tree.h:632