// PRED_TRANSFORM_FUNC(name) expands to the braced pair {"name", &(name)}:
// the stringified identifier followed by the address of the entity with that
// name. It is used below to fill the name -> generator tables.
//
// PredTransformFuncGenerator is a pointer to a function that takes
// (const Model&, bool) and returns std::vector<std::string>.
//
// NOTE(review): the embedded listing numbers jump 11 -> 14 -> 20, so some
// lines of the original file (other includes, declarations) appear to be
// absent from this excerpt.
11 #include <unordered_map> 14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)} 20 using PredTransformFuncGenerator
21 = std::vector<std::string> (*)(
const Model&, bool);
// Generator for the "identity" transform.
// The visible return yields a one-element vector holding the emitted
// statement "return ndata;".
// NOTE(review): the embedded listing numbers go 24 -> 26 -> 32, so the lines
// around this return (including whatever consults `batch`, and any second
// return) are not shown here; confirm them against the full file.
23 std::vector<std::string>
24 identity(
const Model& model,
bool batch) {
26 return {
"return ndata;"};
32 std::vector<std::string>
33 identity_multiclass(
const Model& model,
bool batch) {
34 CHECK(model.num_output_group > 1)
35 <<
"identity_multiclass: model is not a proper multi-class classifier";
38 return {std::string(
"return ndata * ") + std::to_string(num_class) +
";"};
40 return {std::string(
"return ") + std::to_string(num_class) +
";"};
// Generator for the "sigmoid" transform.
// Reads model.param.sigmoid_alpha and rejects non-positive values via
// CHECK_GT before emitting any code.
// NOTE(review): the embedded listing numbers skip 48-53, 58-62, 64 and 66-69,
// so the branch on `batch`, the rest of the alpha literal, the declarations
// and the return statements of the emitted code are not visible here.
44 std::vector<std::string>
45 sigmoid(
const Model& model,
bool batch) {
46 const float alpha = model.param.sigmoid_alpha;
47 CHECK_GT(alpha, 0.0f) <<
"sigmoid: alpha must be strictly positive";
// Emitted loop form: an OpenMP parallel-for over i in [0, ndata) that
// overwrites pred[i] with 1 / (1 + expf(-alpha * pred[i])). The trailing
// "\\" emits a backslash so the pragma continues onto the next emitted line.
54 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
55 " default(none) firstprivate(alpha, ndata) shared(pred) private(i)",
56 "for (i = 0; i < ndata; ++i) {",
57 " pred[i] = 1.0f / (1 + expf(-alpha * pred[i]));",
// Emitted single-element form: same formula applied to pred[0] only.
63 "const float alpha = (float)")
65 "pred[0] = 1.0f / (1 + expf(-alpha * pred[0]));",
// Generator for the "exponential" transform.
// NOTE(review): the embedded listing numbers skip 72-74, 79-81 and 83-85, so
// the branch on `batch`, the declaration of `i`, and the end of the loop form
// are not visible here.
70 std::vector<std::string>
71 exponential(
const Model& model,
bool batch) {
// Emitted loop form: an OpenMP parallel-for over i in [0, ndata) that
// overwrites pred[i] with expf(pred[i]).
75 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
76 " default(none) firstprivate(ndata) shared(pred) private(i)",
77 "for (i = 0; i < ndata; ++i) {",
78 " pred[i] = expf(pred[i]);",
// Emitted single-element form: transforms pred[0] in place and returns 1.
82 return {
"pred[0] = expf(pred[0]);",
"return 1;"};
// Generator for the "logarithm_one_plus_exp" transform, i.e. the emitted
// code replaces each value x with logf(1 + expf(x)).
// NOTE(review): the embedded listing numbers skip 88-90, 95-97 and 99-101, so
// the branch on `batch`, the declaration of `i`, and the end of the loop form
// are not visible here.
86 std::vector<std::string>
87 logarithm_one_plus_exp(
const Model& model,
bool batch) {
// Emitted loop form: an OpenMP parallel-for over i in [0, ndata).
91 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
92 " default(none) firstprivate(ndata) shared(pred) private(i)",
93 "for (i = 0; i < ndata; ++i) {",
94 " pred[i] = logf(1.0f + expf(pred[i]));",
// Emitted single-element form: transforms pred[0] in place and returns 1.
98 return {
"pred[0] = logf(1.0f + expf(pred[0]));",
"return 1;"};
// Generator for the "max_index" transform: the emitted code replaces each
// group of num_class margins with the index of the largest one, stored as a
// float.
// Requires model.num_output_group > 1 (enforced by CHECK); that value is baked
// into the emitted code as the literal num_class.
// NOTE(review): several listing numbers are skipped (107-110, 112-113,
// 115-116, 124, 129-131, 133, 135-139, 146-148, 150-153), so the branch on
// `batch`, the remaining declarations, the statements that update max_index,
// and the return statements are not visible here.
102 std::vector<std::string>
103 max_index(
const Model& model,
bool batch) {
104 CHECK(model.num_output_group > 1)
105 <<
"max_index: model is not a proper multi-class classifier";
106 const int num_class = model.num_output_group;
// Emitted loop form. Results go to a scratch buffer `tmp` of ndata floats and
// are copied back into the front of `pred` afterwards, since row i is read
// from pred[i * num_class ...] while its result belongs at position i.
111 "const int num_class = ") + std::to_string(num_class) +
";",
114 "const float* margin_;",
// NOTE(review): the malloc result is used without a NULL check in the visible
// lines, and no matching free(tmp) is visible in this excerpt (listing line
// 135 is absent) -- confirm the emitted code releases the buffer.
117 "tmp = (float*)malloc(ndata * sizeof(float));",
118 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
119 " default(none) firstprivate(num_class, ndata) \\",
120 " private(max_index, max_margin, margin_, i) \\",
121 " shared(pred, tmp)",
122 "for (i = 0; i < ndata; ++i) {",
123 " margin_ = &pred[i * num_class];",
// Linear scan for the largest margin; starts from class 0, so k begins at 1.
125 " max_margin = margin_[0];",
126 " for (int k = 1; k < num_class; ++k) {",
127 " if (margin_[k] > max_margin) {",
128 " max_margin = margin_[k];",
132 " tmp[i] = (float)max_index;",
134 "memcpy(pred, tmp, ndata * sizeof(float));",
// Emitted single-row form: same scan over pred[0 .. num_class), result
// written to pred[0].
140 "const int num_class = ") + std::to_string(num_class) +
";",
141 "int max_index = 0;",
142 "float max_margin = pred[0];",
143 "for (int k = 1; k < num_class; ++k) {",
144 " if (pred[k] > max_margin) {",
145 " max_margin = pred[k];",
149 "pred[0] = (float)max_index;",
// Generator for the "softmax" transform: the emitted code exponentiates each
// margin after subtracting the row maximum and then divides by norm_const.
// Requires model.num_output_group > 1 (enforced by CHECK); that value is baked
// into the emitted code as the literal num_class.
// NOTE(review): several listing numbers are skipped (159-162, 164, 167-170,
// 184-185, 188, 190, 193-194, 196, 198-200, 204, 208-209, 212-214, 217,
// 219-221), so the branch on `batch`, the remaining declarations, the
// statement that accumulates norm_const, and the closing braces are not
// visible here.
154 std::vector<std::string>
155 softmax(
const Model& model,
bool batch) {
156 CHECK(model.num_output_group > 1)
157 <<
"softmax: model is not a proper multi-class classifier";
158 const int num_class = model.num_output_group;
// Emitted loop form. Each row i reads num_class margins from `pred` and
// writes its outputs to the same offsets in the scratch buffer `tmp`, which
// is copied back over `pred` after the loop.
163 "const int num_class = ") + std::to_string(num_class) +
";",
165 "double norm_const;",
166 "const float* margin_;",
// NOTE(review): the malloc result is used without a NULL check in the visible
// lines, and no matching free(tmp) is visible in this excerpt (listing line
// 196 is absent) -- confirm the emitted code releases the buffer.
171 "tmp = (float*)malloc(ndata * num_class * sizeof(float));",
172 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
173 " default(none) firstprivate(num_class, ndata) \\",
174 " private(max_margin, norm_const, margin_, out_pred_, i, t) \\",
175 " shared(pred, tmp)",
176 "for (i = 0; i < ndata; ++i) {",
177 " margin_ = &pred[i * num_class];",
178 " out_pred_ = &tmp[i * num_class];",
179 " max_margin = margin_[0];",
180 " norm_const = 0.0;",
// Pass 1: find the row maximum.
181 " for (int k = 1; k < num_class; ++k) {",
182 " if (margin_[k] > max_margin) {",
183 " max_margin = margin_[k];",
// Pass 2: exponentiate the shifted margins (presumably shifted to keep expf
// from overflowing).
186 " for (int k = 0; k < num_class; ++k) {",
187 " t = expf(margin_[k] - max_margin);",
189 " out_pred_[k] = t;",
// Pass 3: normalise by norm_const.
191 " for (int k = 0; k < num_class; ++k) {",
192 " out_pred_[k] /= (float)norm_const;",
195 "memcpy(pred, tmp, ndata * num_class * sizeof(float));",
197 "return ndata * num_class;"};
// Emitted single-row form: the same three passes applied in place to
// pred[0 .. num_class).
201 "const int num_class = ") + std::to_string(num_class) +
";",
202 "float max_margin = pred[0];",
203 "double norm_const = 0.0;",
205 "for (int k = 1; k < num_class; ++k) {",
206 " if (pred[k] > max_margin) {",
207 " max_margin = pred[k];",
210 "for (int k = 0; k < num_class; ++k) {",
211 " t = expf(pred[k] - max_margin);",
215 "for (int k = 0; k < num_class; ++k) {",
216 " pred[k] /= (float)norm_const;",
218 "return num_class;"};
// Generator for the "multiclass_ova" transform: the emitted code applies
// 1 / (1 + expf(-alpha * x)) to every one of the num_class values of each row.
// Requires model.num_output_group > 1 (CHECK) and a strictly positive
// model.param.sigmoid_alpha (CHECK_GT).
// NOTE(review): several listing numbers are skipped (229-232, 234-235,
// 237-238, 246-247, 249-251, 253-254, 258, 260-261), so the branch on `batch`,
// the rest of the alpha literal, the declarations of pred_ / i and the closing
// braces are not visible here.
222 std::vector<std::string>
223 multiclass_ova(
const Model& model,
bool batch) {
224 CHECK(model.num_output_group > 1)
225 <<
"multiclass_ova: model is not a proper multi-class classifier";
226 const int num_class = model.num_output_group;
227 const float alpha = model.param.sigmoid_alpha;
228 CHECK_GT(alpha, 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
// Emitted loop form: rows are transformed in place, one row per iteration.
233 "const float alpha = (float)")
236 "const int num_class = ") + std::to_string(num_class) +
";",
239 "#pragma omp parallel for schedule(static) num_threads(nthread) \\",
240 " default(none) firstprivate(alpha, num_class, ndata) \\",
241 " private(pred_, i) shared(pred)",
242 "for (i = 0; i < ndata; ++i) {",
// NOTE(review): there is no comma between the next two string literals, so
// the compiler concatenates them into ONE vector element and the two emitted
// statements end up on a single output line. The emitted C is still valid,
// but a "," after the first literal looks intended -- confirm and fix.
243 " pred_ = &pred[i * num_class];" 244 " for (int k = 0; k < num_class; ++k) {",
245 " pred_[k] = 1.0f / (1 + expf(-alpha * pred_[k]));",
248 "return ndata * num_class;"};
// Emitted single-row form: same formula applied to pred[0 .. num_class).
252 "const float alpha = (float)")
255 "const int num_class = ") + std::to_string(num_class) +
";",
256 "for (int k = 0; k < num_class; ++k) {",
257 " pred[k] = 1.0f / (1 + expf(-alpha * pred[k]));",
259 "return num_class;"};
263 const std::unordered_map<std::string, PredTransformFuncGenerator>
264 pred_transform_db = {
265 PRED_TRANSFORM_FUNC(identity),
266 PRED_TRANSFORM_FUNC(sigmoid),
267 PRED_TRANSFORM_FUNC(exponential),
268 PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
287 const std::unordered_map<std::string, PredTransformFuncGenerator>
288 pred_transform_multiclass_db = {
289 PRED_TRANSFORM_FUNC(identity_multiclass),
290 PRED_TRANSFORM_FUNC(max_index),
291 PRED_TRANSFORM_FUNC(softmax),
292 PRED_TRANSFORM_FUNC(multiclass_ova)
316 std::vector<std::string>
317 treelite::compiler::PredTransformFunction(
const Model& model,
bool batch) {
318 if (model.num_output_group > 1) {
319 auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
320 if (it == pred_transform_multiclass_db.end()) {
321 std::ostringstream oss;
322 for (
const auto& e : pred_transform_multiclass_db) {
323 oss <<
"'" << e.first <<
"', ";
325 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 326 <<
"For multi-class classification, you should set " 327 <<
"`pred_transform` to one of the following: " 328 <<
"{ " << oss.str() <<
" }";
330 return (it->second)(model, batch);
332 auto it = pred_transform_db.find(model.param.pred_transform);
333 if (it == pred_transform_db.end()) {
334 std::ostringstream oss;
335 for (
const auto& e : pred_transform_db) {
336 oss <<
"'" << e.first <<
"', ";
338 LOG(FATAL) <<
"Invalid argument given for `pred_transform` parameter. " 339 <<
"For any task that is NOT multi-class classification, you " 340 <<
"should set `pred_transform` to one of the following: " 341 <<
"{ " << oss.str() <<
" }";
343 return (it->second)(model, batch);
int num_output_group
number of output groups – used for multi-class classification; set to 1 for everything else ...
plain code block containing one or more lines of code
thin wrapper for tree ensemble model
std::string ToString(T value)
obtain a string representation of primitive type using ostringstream
Building blocks for semantic model of tree prediction code.