8 #include "./pred_transform.h" 12 #include <unordered_map> 18 namespace pred_transform {
20 std::size_t identity(
const treelite::Model&,
const float* in,
float* out) {
25 std::size_t signed_square(
const treelite::Model&,
const float* in,
float* out) {
26 const float margin = *in;
27 *out = std::copysign(margin * margin, margin);
31 std::size_t hinge(
const treelite::Model&,
const float* in,
float* out) {
32 *out = (*in > 0 ? 1.0f : 0.0f);
36 std::size_t sigmoid(
const treelite::Model& model,
const float* in,
float* out) {
38 TREELITE_CHECK(alpha > 0.0f) <<
"sigmoid: alpha must be strictly positive";
39 *out = 1.0f / (1.0f + std::exp(-alpha * *in));
43 std::size_t exponential(
const treelite::Model&,
const float* in,
float* out) {
48 std::size_t logarithm_one_plus_exp(
const treelite::Model&,
const float* in,
float* out) {
49 *out = std::log1p(std::exp(*in));
53 std::size_t identity_multiclass(
const treelite::Model& model,
const float* in,
float* out) {
55 TREELITE_CHECK(num_class > 1) <<
"model must be a multi-class classifier";
56 for (std::size_t i = 0; i < num_class; ++i) {
62 std::size_t max_index(
const treelite::Model& model,
const float* in,
float* out) {
64 TREELITE_CHECK(num_class > 1) <<
"model must be a multi-class classifier";
65 std::size_t max_index = 0;
66 float max_margin = in[0];
67 for (std::size_t i = 1; i < num_class; ++i) {
68 if (in[i] > max_margin) {
73 out[0] =
static_cast<float>(max_index);
77 std::size_t softmax(
const treelite::Model& model,
const float* in,
float* out) {
79 TREELITE_CHECK(num_class > 1) <<
"model must be a multi-class classifier";
80 float max_margin = in[0];
81 double norm_const = 0.0;
83 for (std::size_t i = 1; i < num_class; ++i) {
84 if (in[i] > max_margin) {
88 for (std::size_t i = 0; i < num_class; ++i) {
89 t = std::exp(in[i] - max_margin);
93 for (std::size_t i = 0; i < num_class; ++i) {
94 out[i] /=
static_cast<float>(norm_const);
99 std::size_t multiclass_ova(
const treelite::Model& model,
const float* in,
float* out) {
101 TREELITE_CHECK(num_class > 1) <<
"model must be a multi-class classifier";
103 TREELITE_CHECK(alpha > 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
104 for (std::size_t i = 0; i < num_class; ++i) {
105 out[i] = 1.0f / (1.0f + std::exp(-alpha * in[i]));
112 const std::unordered_map<std::string, PredTransformFuncType> pred_transform_func{
113 {
"identity", pred_transform::identity},
114 {
"signed_square", pred_transform::signed_square},
115 {
"hinge", pred_transform::hinge},
116 {
"sigmoid", pred_transform::sigmoid},
117 {
"exponential", pred_transform::exponential},
118 {
"logarithm_one_plus_exp", pred_transform::logarithm_one_plus_exp},
119 {
"identity_multiclass", pred_transform::identity_multiclass},
120 {
"max_index", pred_transform::max_index},
121 {
"softmax", pred_transform::softmax},
122 {
"multiclass_ova", pred_transform::multiclass_ova}
125 PredTransformFuncType LookupPredTransform(
const std::string& name) {
126 return pred_transform_func.at(name);
ModelParam param
extra parameters
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
TaskParam task_param
Group of parameters that are specific to the particular task type.
thin wrapper for tree ensemble model