treelite
pred_transform.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "pred_transform.h"

#include "./native/pred_transform.h"
14 
/* Expands to a brace-initialized registry entry { "name", &(name) }, so the
 * lookup key string is always spelled exactly like the function it maps to. */
#define PRED_TRANSFORM_FUNC(name) { #name, &(name) }
16 
17 namespace {
18 
19 using Model = treelite::Model;
20 using PredTransformFuncGenerator
21  = std::string (*)(const std::string&, const Model&);
22 
23 /* boilerplate */
24 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \
25 std::string \
26 FUNC_NAME(const std::string& backend, const Model& model) { \
27  if (backend == "native") { \
28  return treelite::compiler::native::pred_transform::FUNC_NAME(model); \
29  } else { \
30  LOG(FATAL) << "Unrecognized backend: " << backend; \
31  return std::string(); \
32  } \
33 }
34 
35 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
36 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
43 
44 const std::unordered_map<std::string, PredTransformFuncGenerator>
45 pred_transform_db = {
46  PRED_TRANSFORM_FUNC(identity),
47  PRED_TRANSFORM_FUNC(sigmoid),
48  PRED_TRANSFORM_FUNC(exponential),
49  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
50 };
67 // prediction transform function for *multi-class classifiers* only
68 const std::unordered_map<std::string, PredTransformFuncGenerator>
69 pred_transform_multiclass_db = {
70  PRED_TRANSFORM_FUNC(identity_multiclass),
71  PRED_TRANSFORM_FUNC(max_index),
72  PRED_TRANSFORM_FUNC(softmax),
73  PRED_TRANSFORM_FUNC(multiclass_ova)
74 };
94 } // anonymous namespace
95 
96 std::string
97 treelite::compiler::PredTransformFunction(const std::string& backend,
98  const Model& model) {
99  if (model.num_output_group > 1) { // multi-class classification
100  auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
101  if (it == pred_transform_multiclass_db.end()) {
102  std::ostringstream oss;
103  for (const auto& e : pred_transform_multiclass_db) {
104  oss << "'" << e.first << "', ";
105  }
106  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
107  << "For multi-class classification, you should set "
108  << "`pred_transform` to one of the following: "
109  << "{ " << oss.str() << " }";
110  }
111  return (it->second)(backend, model);
112  } else {
113  auto it = pred_transform_db.find(model.param.pred_transform);
114  if (it == pred_transform_db.end()) {
115  std::ostringstream oss;
116  for (const auto& e : pred_transform_db) {
117  oss << "'" << e.first << "', ";
118  }
119  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
120  << "For any task that is NOT multi-class classification, you "
121  << "should set `pred_transform` to one of the following: "
122  << "{ " << oss.str() << " }";
123  }
124  return (it->second)(backend, model);
125  }
126 }
thin wrapper for tree ensemble model
Definition: tree.h:415
model structure for tree