Treelite
pred_transform.cc
Go to the documentation of this file.
1 
#include "pred_transform.h"
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "./native/pred_transform.h"
13 
/*! \brief Expands to a brace-initializer `{"name", &(name)}`: the stringized
 *         identifier paired with the address of the entity of that name.
 *         Used to build registry entries whose key always matches the
 *         function it points to. */
#define PRED_TRANSFORM_FUNC(name) { #name, &(name) }
15 
16 namespace {
17 
18 using Model = treelite::Model;
19 using PredTransformFuncGenerator
20  = std::string (*)(const std::string&, const Model&);
21 
22 /* boilerplate */
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \
24 std::string \
25 FUNC_NAME(const std::string& backend, const Model& model) { \
26  if (backend == "native") { \
27  return treelite::compiler::native::pred_transform::FUNC_NAME(model); \
28  } else { \
29  LOG(FATAL) << "Unrecognized backend: " << backend; \
30  return std::string(); \
31  } \
32 }
33 
34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
42 
43 const std::unordered_map<std::string, PredTransformFuncGenerator>
44 pred_transform_db = {
45  PRED_TRANSFORM_FUNC(identity),
46  PRED_TRANSFORM_FUNC(sigmoid),
47  PRED_TRANSFORM_FUNC(exponential),
48  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
49 };
66 // prediction transform function for *multi-class classifiers* only
67 const std::unordered_map<std::string, PredTransformFuncGenerator>
68 pred_transform_multiclass_db = {
69  PRED_TRANSFORM_FUNC(identity_multiclass),
70  PRED_TRANSFORM_FUNC(max_index),
71  PRED_TRANSFORM_FUNC(softmax),
72  PRED_TRANSFORM_FUNC(multiclass_ova)
73 };
93 } // anonymous namespace
94 
95 std::string
96 treelite::compiler::PredTransformFunction(const std::string& backend,
97  const Model& model) {
98  if (model.num_output_group > 1) { // multi-class classification
99  auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
100  if (it == pred_transform_multiclass_db.end()) {
101  std::ostringstream oss;
102  for (const auto& e : pred_transform_multiclass_db) {
103  oss << "'" << e.first << "', ";
104  }
105  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
106  << "For multi-class classification, you should set "
107  << "`pred_transform` to one of the following: "
108  << "{ " << oss.str() << " }";
109  }
110  return (it->second)(backend, model);
111  } else {
112  auto it = pred_transform_db.find(model.param.pred_transform);
113  if (it == pred_transform_db.end()) {
114  std::ostringstream oss;
115  for (const auto& e : pred_transform_db) {
116  oss << "'" << e.first << "', ";
117  }
118  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
119  << "For any task that is NOT multi-class classification, you "
120  << "should set `pred_transform` to one of the following: "
121  << "{ " << oss.str() << " }";
122  }
123  return (it->second)(backend, model);
124  }
125 }
thin wrapper for tree ensemble model
Definition: tree.h:409