treelite
pred_transform.cc
Go to the documentation of this file.
1 
8 #include <treelite/tree.h>
9 #include <string>
10 #include <unordered_map>
11 #include "pred_transform.h"
12 
13 #include "./native/pred_transform.h"
14 #include "./java/pred_transform.h"
15 
// Expands to one brace-initialized registry entry: the stringized identifier
// paired with the address of the entity carrying that identifier.
#define PRED_TRANSFORM_FUNC(fn) {#fn, &(fn)}
17 
18 namespace {
19 
20 using Model = treelite::Model;
21 using PredTransformFuncGenerator
22  = std::string (*)(const std::string&, const Model&);
23 
24 /* boilerplate */
25 #define TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(FUNC_NAME) \
26 std::string \
27 FUNC_NAME(const std::string& backend, const Model& model) { \
28  if (backend == "native") { \
29  return treelite::compiler::pred_transform::native::FUNC_NAME(model); \
30  } else if (backend == "java") { \
31  return treelite::compiler::pred_transform::java::FUNC_NAME(model); \
32  } else { \
33  LOG(FATAL) << "Unrecognized backend: " << backend; \
34  } \
35 }
36 
37 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
38 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
39 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
40 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
41 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity_multiclass)
42 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(max_index)
43 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(softmax)
44 TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
45 
46 const std::unordered_map<std::string, PredTransformFuncGenerator>
47 pred_transform_db = {
48  PRED_TRANSFORM_FUNC(identity),
49  PRED_TRANSFORM_FUNC(sigmoid),
50  PRED_TRANSFORM_FUNC(exponential),
51  PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
52 };
69 // prediction transform function for *multi-class classifiers* only
70 const std::unordered_map<std::string, PredTransformFuncGenerator>
71 pred_transform_multiclass_db = {
72  PRED_TRANSFORM_FUNC(identity_multiclass),
73  PRED_TRANSFORM_FUNC(max_index),
74  PRED_TRANSFORM_FUNC(softmax),
75  PRED_TRANSFORM_FUNC(multiclass_ova)
76 };
96 } // namespace anonymous
97 
98 std::string
99 treelite::compiler::PredTransformFunction(const std::string& backend,
100  const Model& model) {
101  if (model.num_output_group > 1) { // multi-class classification
102  auto it = pred_transform_multiclass_db.find(model.param.pred_transform);
103  if (it == pred_transform_multiclass_db.end()) {
104  std::ostringstream oss;
105  for (const auto& e : pred_transform_multiclass_db) {
106  oss << "'" << e.first << "', ";
107  }
108  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
109  << "For multi-class classification, you should set "
110  << "`pred_transform` to one of the following: "
111  << "{ " << oss.str() << " }";
112  }
113  return (it->second)(backend, model);
114  } else {
115  auto it = pred_transform_db.find(model.param.pred_transform);
116  if (it == pred_transform_db.end()) {
117  std::ostringstream oss;
118  for (const auto& e : pred_transform_db) {
119  oss << "'" << e.first << "', ";
120  }
121  LOG(FATAL) << "Invalid argument given for `pred_transform` parameter. "
122  << "For any task that is NOT multi-class classification, you "
123  << "should set `pred_transform` to one of the following: "
124  << "{ " << oss.str() << " }";
125  }
126  return (it->second)(backend, model);
127  }
128 }
thin wrapper for tree ensemble model
Definition: tree.h:351
model structure for tree