// Treelite: pred_transform.cc
#include "pred_transform.h"

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "./native/pred_transform.h"
13 
// Expands to a brace-enclosed registry entry {"func", &(func)}: the key is
// the stringized identifier and the value is the address of the entity it
// names, suitable as one element of an unordered_map initializer list.
#define PRED_TRANSFORM_FUNC(func) { #func, &(func) }
15 
16 namespace {
17 
18 using Model = treelite::Model;
19 using PredTransformFuncGenerator
20  = std::string (*)(const std::string&, const Model&);
21 
22 /* boilerplate */
23 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \
24 std::string \
25 FUNC_NAME(const std::string& backend, const Model& model) { \
26  if (backend == "native") { \
27  return treelite::compiler::native::pred_transform::FUNC_NAME(model); \
28  } else { \
29  LOG(FATAL) << "Unrecognized backend: " << backend; \
30  return std::string(); \
31  } \
32 }
33 
34 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(hinge)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
44 
45 const std::unordered_map<std::string, PredTransformFuncGenerator>
46 pred_transform_db = {
47  PRED_TRANSFORM_FUNC(identity),
48  PRED_TRANSFORM_FUNC(signed_square),
49  PRED_TRANSFORM_FUNC(hinge),
50  PRED_TRANSFORM_FUNC(sigmoid),
51  PRED_TRANSFORM_FUNC(exponential),
52  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
53 };
76 // prediction transform function for *multi-class classifiers* only
77 const std::unordered_map<std::string, PredTransformFuncGenerator>
78 pred_transform_multiclass_db = {
79  PRED_TRANSFORM_FUNC(identity_multiclass),
80  PRED_TRANSFORM_FUNC(max_index),
81  PRED_TRANSFORM_FUNC(softmax),
82  PRED_TRANSFORM_FUNC(multiclass_ova)
83 };
103 } // anonymous namespace
104 
105 std::string
106 treelite::compiler::PredTransformFunction(const std::string& backend,
107  const Model& model) {
108  ModelParam param = model.param;
109  if (model.task_param.num_class > 1) { // multi-class classification
110  auto it = pred_transform_multiclass_db.find(param.pred_transform);
111  if (it == pred_transform_multiclass_db.end()) {
112  std::ostringstream oss;
113  for (const auto& e : pred_transform_multiclass_db) {
114  oss << "'" << e.first << "', ";
115  }
116  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
117  << "For multi-class classification, you should set "
118  << "`pred_transform` to one of the following: "
119  << "{ " << oss.str() << " }";
120  }
121  return (it->second)(backend, model);
122  } else {
123  auto it = pred_transform_db.find(param.pred_transform);
124  if (it == pred_transform_db.end()) {
125  std::ostringstream oss;
126  for (const auto& e : pred_transform_db) {
127  oss << "'" << e.first << "', ";
128  }
129  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
130  << "For any task that is NOT multi-class classification, you "
131  << "should set `pred_transform` to one of the following: "
132  << "{ " << oss.str() << " }";
133  }
134  return (it->second)(backend, model);
135  }
136 }
// thin wrapper for tree ensemble model
// Definition: tree.h:632