treelite
pred_transform.cc
Go to the documentation of this file.
1 
8 #include <treelite/semantic.h>
9 #include <treelite/tree.h>
10 #include <string>
11 #include <unordered_map>
12 #include "pred_transform.h"
13 
14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)}
15 
16 namespace {
17 
18 using PlainBlock = treelite::semantic::PlainBlock;
19 using Model = treelite::Model;
20 using PredTransformFuncGenerator
21  = std::vector<std::string> (*)(const Model&, bool);
22 
23 std::vector<std::string>
24 identity(const Model& model, bool batch) {
25  if (batch) {
26  return {"return ndata;"};
27  } else {
28  return {"return 1;"};
29  }
30 }
31 
32 std::vector<std::string>
33 identity_multiclass(const Model& model, bool batch) {
34  CHECK(model.num_output_group > 1)
35  << "identity_multiclass: model is not a proper multi-class classifier";
36  const int num_class = model.num_output_group;
37  if (batch) {
38  return {std::string("return ndata * ") + std::to_string(num_class) + ";"};
39  } else {
40  return {std::string("return ") + std::to_string(num_class) + ";"};
41  }
42 }
43 
44 std::vector<std::string>
45 sigmoid(const Model& model, bool batch) {
46  const float alpha = model.param.sigmoid_alpha;
47  CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive";
48 
49  if (batch) {
50  return {
51  std::string(
52  "const float alpha = (float)") + treelite::common::ToString(alpha) + ";",
53  "int64_t i;",
54  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
55  " default(none) firstprivate(alpha, ndata) shared(pred) private(i)",
56  "for (i = 0; i < ndata; ++i) {",
57  " pred[i] = 1.0f / (1 + expf(-alpha * pred[i]));",
58  "}",
59  "return ndata;"};
60  } else {
61  return {
62  std::string(
63  "const float alpha = (float)")
64  + treelite::common::ToString(alpha) + ";",
65  "pred[0] = 1.0f / (1 + expf(-alpha * pred[0]));",
66  "return 1;"};
67  }
68 }
69 
70 std::vector<std::string>
71 exponential(const Model& model, bool batch) {
72  if (batch) {
73  return {
74  "int64_t i;",
75  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
76  " default(none) firstprivate(ndata) shared(pred) private(i)",
77  "for (i = 0; i < ndata; ++i) {",
78  " pred[i] = expf(pred[i]);",
79  "}",
80  "return ndata;"};
81  } else {
82  return {"pred[0] = expf(pred[0]);", "return 1;"};
83  }
84 }
85 
86 std::vector<std::string>
87 logarithm_one_plus_exp(const Model& model, bool batch) {
88  if (batch) {
89  return {
90  "int64_t i;",
91  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
92  " default(none) firstprivate(ndata) shared(pred) private(i)",
93  "for (i = 0; i < ndata; ++i) {",
94  " pred[i] = logf(1.0f + expf(pred[i]));",
95  "}",
96  "return ndata;"};
97  } else {
98  return {"pred[0] = logf(1.0f + expf(pred[0]));", "return 1;"};
99  }
100 }
101 
102 std::vector<std::string>
103 max_index(const Model& model, bool batch) {
104  CHECK(model.num_output_group > 1)
105  << "max_index: model is not a proper multi-class classifier";
106  const int num_class = model.num_output_group;
107 
108  if (batch) {
109  return {
110  std::string(
111  "const int num_class = ") + std::to_string(num_class) + ";",
112  "int max_index;",
113  "float max_margin;",
114  "const float* margin_;",
115  "float* tmp;",
116  "int64_t i;",
117  "tmp = (float*)malloc(ndata * sizeof(float));",
118  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
119  " default(none) firstprivate(num_class, ndata) \\",
120  " private(max_index, max_margin, margin_, i) \\",
121  " shared(pred, tmp)",
122  "for (i = 0; i < ndata; ++i) {",
123  " margin_ = &pred[i * num_class];",
124  " max_index = 0;",
125  " max_margin = margin_[0];",
126  " for (int k = 1; k < num_class; ++k) {",
127  " if (margin_[k] > max_margin) {",
128  " max_margin = margin_[k];",
129  " max_index = k;",
130  " }",
131  " }",
132  " tmp[i] = (float)max_index;",
133  "}",
134  "memcpy(pred, tmp, ndata * sizeof(float));",
135  "free(tmp);",
136  "return ndata;"};
137  } else {
138  return {
139  std::string(
140  "const int num_class = ") + std::to_string(num_class) + ";",
141  "int max_index = 0;",
142  "float max_margin = pred[0];",
143  "for (int k = 1; k < num_class; ++k) {",
144  " if (pred[k] > max_margin) {",
145  " max_margin = pred[k];",
146  " max_index = k;",
147  " }",
148  "}",
149  "pred[0] = (float)max_index;",
150  "return 1;"};
151  }
152 }
153 
154 std::vector<std::string>
155 softmax(const Model& model, bool batch) {
156  CHECK(model.num_output_group > 1)
157  << "softmax: model is not a proper multi-class classifier";
158  const int num_class = model.num_output_group;
159 
160  if (batch) {
161  return {
162  std::string(
163  "const int num_class = ") + std::to_string(num_class) + ";",
164  "float max_margin;",
165  "double norm_const;",
166  "const float* margin_;",
167  "float* out_pred_;",
168  "float* tmp;",
169  "float t;",
170  "int64_t i;",
171  "tmp = (float*)malloc(ndata * num_class * sizeof(float));",
172  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
173  " default(none) firstprivate(num_class, ndata) \\",
174  " private(max_margin, norm_const, margin_, out_pred_, i, t) \\",
175  " shared(pred, tmp)",
176  "for (i = 0; i < ndata; ++i) {",
177  " margin_ = &pred[i * num_class];",
178  " out_pred_ = &tmp[i * num_class];",
179  " max_margin = margin_[0];",
180  " norm_const = 0.0;",
181  " for (int k = 1; k < num_class; ++k) {",
182  " if (margin_[k] > max_margin) {",
183  " max_margin = margin_[k];",
184  " }",
185  " }",
186  " for (int k = 0; k < num_class; ++k) {",
187  " t = expf(margin_[k] - max_margin);",
188  " norm_const += t;",
189  " out_pred_[k] = t;",
190  " }",
191  " for (int k = 0; k < num_class; ++k) {",
192  " out_pred_[k] /= (float)norm_const;",
193  " }",
194  "}",
195  "memcpy(pred, tmp, ndata * num_class * sizeof(float));",
196  "free(tmp);",
197  "return ndata * num_class;"};
198  } else {
199  return {
200  std::string(
201  "const int num_class = ") + std::to_string(num_class) + ";",
202  "float max_margin = pred[0];",
203  "double norm_const = 0.0;",
204  "float t;",
205  "for (int k = 1; k < num_class; ++k) {",
206  " if (pred[k] > max_margin) {",
207  " max_margin = pred[k];",
208  " }",
209  "}",
210  "for (int k = 0; k < num_class; ++k) {",
211  " t = expf(pred[k] - max_margin);",
212  " norm_const += t;",
213  " pred[k] = t;",
214  "}",
215  "for (int k = 0; k < num_class; ++k) {",
216  " pred[k] /= (float)norm_const;",
217  "}",
218  "return num_class;"};
219  }
220 }
221 
222 std::vector<std::string>
223 multiclass_ova(const Model& model, bool batch) {
224  CHECK(model.num_output_group > 1)
225  << "multiclass_ova: model is not a proper multi-class classifier";
226  const int num_class = model.num_output_group;
227  const float alpha = model.param.sigmoid_alpha;
228  CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive";
229 
230  if (batch) {
231  return {
232  std::string(
233  "const float alpha = (float)")
234  + treelite::common::ToString(alpha) + ";",
235  std::string(
236  "const int num_class = ") + std::to_string(num_class) + ";",
237  "float* pred_;",
238  "int64_t i;",
239  "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
240  " default(none) firstprivate(alpha, num_class, ndata) \\",
241  " private(pred_, i) shared(pred)",
242  "for (i = 0; i < ndata; ++i) {",
243  " pred_ = &pred[i * num_class];"
244  " for (int k = 0; k < num_class; ++k) {",
245  " pred_[k] = 1.0f / (1 + expf(-alpha * pred_[k]));",
246  " }",
247  "}",
248  "return ndata * num_class;"};
249  } else {
250  return {
251  std::string(
252  "const float alpha = (float)")
253  + treelite::common::ToString(alpha) + ";",
254  std::string(
255  "const int num_class = ") + std::to_string(num_class) + ";",
256  "for (int k = 0; k < num_class; ++k) {",
257  " pred[k] = 1.0f / (1 + expf(-alpha * pred[k]));",
258  "}",
259  "return num_class;"};
260  }
261 }
262 
263 const std::unordered_map<std::string, PredTransformFuncGenerator>
264 pred_transform_db = {
265  PRED_TRANSFORM_FUNC(identity),
266  PRED_TRANSFORM_FUNC(sigmoid),
267  PRED_TRANSFORM_FUNC(exponential),
268  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
269 };
286 // prediction transform function for *multi-class classifiers* only
287 const std::unordered_map<std::string, PredTransformFuncGenerator>
288 pred_transform_multiclass_db = {
289  PRED_TRANSFORM_FUNC(identity_multiclass),
290  PRED_TRANSFORM_FUNC(max_index),
291  PRED_TRANSFORM_FUNC(softmax),
292  PRED_TRANSFORM_FUNC(multiclass_ova)
293 };
313 } // namespace anonymous
314 
315 
316 std::vector<std::string>
317 treelite::compiler::PredTransformFunction(const Model& model, bool batch) {
318  if (model.num_output_group > 1) { // multi-class classification
319  auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
320  if (it == pred_transform_multiclass_db.end()) {
321  std::ostringstream oss;
322  for (const auto& e : pred_transform_multiclass_db) {
323  oss << "'" << e.first << "', ";
324  }
325  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
326  << "For multi-class classification, you should set "
327  << "`pred_transform` to one of the following: "
328  << "{ " << oss.str() << " }";
329  }
330  return (it->second)(model, batch);
331  } else {
332  auto it = pred_transform_db.find(model.param.pred_transform);
333  if (it == pred_transform_db.end()) {
334  std::ostringstream oss;
335  for (const auto& e : pred_transform_db) {
336  oss << "'" << e.first << "', ";
337  }
338  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
339  << "For any task that is NOT multi-class classification, you "
340  << "should set `pred_transform` to one of the following: "
341  << "{ " << oss.str() << " }";
342  }
343  return (it->second)(model, batch);
344  }
345 }
int num_output_group
number of output groups for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:361
plain code block containing one or more lines of code
Definition: semantic.h:118
thin wrapper for tree ensemble model
Definition: tree.h:351
model structure for tree
std::string ToString(T value)
obtain a string representation of primitive type using ostringstream
Definition: common.h:163
tools to define prediction transform function
Building blocks for semantic model of tree prediction code.