Treelite
pred_transform.cc
1 
#include "pred_transform.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "./native/pred_transform.h"
13 
14 #define PRED_TRANSFORM_FUNC(name) {#name, &(name)}
15 
16 namespace {
17 
18 using Model = treelite::Model;
19 using PredTransformFuncGenerator
20  = std::string (*)(const std::string&, const Model&);
21 
22 /* boilerplate */
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \
24 std::string \
25 FUNC_NAME(const std::string& backend, const Model& model) { \
26  if (backend == "native") { \
27  return treelite::compiler::native::pred_transform::FUNC_NAME(model); \
28  } else { \
29  TREELITE_LOG(FATAL) << "Unrecognized backend: " << backend; \
30  return std::string(); \
31  } \
32 }
33 
34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(hinge)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential_standard_ratio)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
44 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
45 
46 const std::unordered_map<std::string, PredTransformFuncGenerator>
47 pred_transform_db = {
48  PRED_TRANSFORM_FUNC(identity),
49  PRED_TRANSFORM_FUNC(signed_square),
50  PRED_TRANSFORM_FUNC(hinge),
51  PRED_TRANSFORM_FUNC(sigmoid),
52  PRED_TRANSFORM_FUNC(exponential),
53  PRED_TRANSFORM_FUNC(exponential_standard_ratio),
54  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
55 };
82 // prediction transform function for *multi-class classifiers* only
83 const std::unordered_map<std::string, PredTransformFuncGenerator>
84 pred_transform_multiclass_db = {
85  PRED_TRANSFORM_FUNC(identity_multiclass),
86  PRED_TRANSFORM_FUNC(max_index),
87  PRED_TRANSFORM_FUNC(softmax),
88  PRED_TRANSFORM_FUNC(multiclass_ova)
89 };
109 } // anonymous namespace
110 
111 std::string
112 treelite::compiler::PredTransformFunction(const std::string& backend,
113  const Model& model) {
114  ModelParam param = model.param;
115  if (model.task_param.num_class > 1) { // multi-class classification
116  auto it = pred_transform_multiclass_db.find(param.pred_transform);
117  if (it == pred_transform_multiclass_db.end()) {
118  std::ostringstream oss;
119  for (const auto& e : pred_transform_multiclass_db) {
120  oss << "'" << e.first << "', ";
121  }
122  TREELITE_LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
123  << "For multi-class classification, you should set "
124  << "`pred_transform` to one of the following: "
125  << "{ " << oss.str() << " }";
126  }
127  return (it->second)(backend, model);
128  } else {
129  auto it = pred_transform_db.find(param.pred_transform);
130  if (it == pred_transform_db.end()) {
131  std::ostringstream oss;
132  for (const auto& e : pred_transform_db) {
133  oss << "'" << e.first << "', ";
134  }
135  TREELITE_LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
136  << "For any task that is NOT multi-class classification, you "
137  << "should set `pred_transform` to one of the following: "
138  << "{ " << oss.str() << " }";
139  }
140  return (it->second)(backend, model);
141  }
142 }
thin wrapper for tree ensemble model
Definition: tree.h:667