8 #include "./pred_transform.h" 11 #include <dmlc/logging.h> 13 #include <unordered_map> 19 namespace pred_transform {
// Prediction transform registered under "identity" in pred_transform_func.
// Like every transform in this namespace it reads raw margin(s) from `in`,
// writes transformed value(s) to `out`, and returns a std::size_t.
// NOTE(review): the return value presumably is the number of outputs written;
// confirm against the hidden return statements.
// The body is not visible in this excerpt; by its name it presumably copies
// *in to *out unchanged. TODO confirm.
21 std::size_t identity(
const treelite::Model&,
const float* in,
float* out) {
// Signed-square transform for a single margin: writes sign(x) * x^2 to *out.
// std::copysign transfers the sign of the original margin onto its
// (always non-negative) square.
26 std::size_t signed_square(
const treelite::Model&,
const float* in,
float* out) {
27 const float margin = *in;
28 *out = std::copysign(margin * margin, margin);
// Hinge transform for a single margin: hard 0/1 decision.
// Writes 1.0f when the margin is strictly positive. It writes 0.0f otherwise;
// that includes a margin of exactly 0 and NaN, since NaN > 0 is false.
32 std::size_t hinge(
const treelite::Model&,
const float* in,
float* out) {
33 *out = (*in > 0 ? 1.0f : 0.0f);
// Logistic transform for a single margin: *out = 1 / (1 + exp(-alpha * x)).
// `alpha` is declared on a line not shown in this excerpt. It is presumably
// model.param.sigmoid_alpha, the sigmoid scaling parameter; TODO confirm.
// A dmlc CHECK fails if alpha is not strictly positive.
37 std::size_t sigmoid(
const treelite::Model& model,
const float* in,
float* out) {
39 CHECK(alpha > 0.0f) <<
"sigmoid: alpha must be strictly positive";
40 *out = 1.0f / (1.0f + std::exp(-alpha * *in));
// Prediction transform registered under "exponential" in pred_transform_func.
// The body is not visible in this excerpt; by its name it presumably writes
// exp(*in) to *out. TODO confirm.
44 std::size_t exponential(
const treelite::Model&,
const float* in,
float* out) {
// Softplus transform for a single margin: *out = log(1 + exp(x)).
// It is computed as log1p(exp(x)); log1p keeps precision when exp(x) is tiny
// (very negative x).
// NOTE(review): *in is a float, so std::exp uses its float overload and
// overflows to +inf once x exceeds roughly 88.7. The result is then +inf,
// while the true value is approximately x. A numerically stable equivalent is
//   std::max(x, 0.0f) + std::log1p(std::exp(-std::abs(x))).
// Left as-is here; confirm whether other prediction paths rely on this exact
// formula before changing it.
49 std::size_t logarithm_one_plus_exp(
const treelite::Model&,
const float* in,
float* out) {
50 *out = std::log1p(std::exp(*in));
// Multi-class identity transform: processes num_class per-class margins.
// A dmlc CHECK fails unless num_class > 1.
// `num_class` is declared on a line not shown here. It is presumably
// model.task_param.num_class; TODO confirm.
// The loop body is not visible in this excerpt; by the function's name it
// presumably copies in[i] to out[i].
54 std::size_t identity_multiclass(
const treelite::Model& model,
const float* in,
float* out) {
56 CHECK(num_class > 1) <<
"model must be a multi-class classifier";
57 for (std::size_t i = 0; i < num_class; ++i) {
// Argmax transform: scans the num_class margins in `in` and writes the index
// of the largest one to out[0], stored as a float.
// A dmlc CHECK fails unless num_class > 1.
// `num_class` is declared on a hidden line; it is presumably
// model.task_param.num_class (TODO confirm).
63 std::size_t max_index(
const treelite::Model& model,
const float* in,
float* out) {
65 CHECK(num_class > 1) <<
"model must be a multi-class classifier";
// Start with class 0 as the running maximum and compare the remaining classes
// against it.
66 std::size_t max_index = 0;
67 float max_margin = in[0];
68 for (std::size_t i = 1; i < num_class; ++i) {
// The comparison is strict, so ties keep the earlier (lower) index.
// The branch body is not shown here; it presumably records in[i] and i as the
// new maximum.
69 if (in[i] > max_margin) {
74 out[0] =
static_cast<float>(max_index);
// Softmax over num_class margins: out[i] = exp(in[i]) / sum_j exp(in[j]).
// A dmlc CHECK fails unless num_class > 1.
// `num_class` (presumably model.task_param.num_class) and the float
// temporary `t` are declared on lines not shown in this excerpt.
78 std::size_t softmax(
const treelite::Model& model,
const float* in,
float* out) {
80 CHECK(num_class > 1) <<
"model must be a multi-class classifier";
81 float max_margin = in[0];
// The normalizing sum is held in double, which presumably limits rounding
// error while accumulating.
82 double norm_const = 0.0;
// Pass 1: find the largest margin. The update inside the `if` is not shown.
84 for (std::size_t i = 1; i < num_class; ++i) {
85 if (in[i] > max_margin) {
// Pass 2: exponentiate relative to the maximum, so every exponent is <= 0 and
// std::exp cannot overflow.
// The accumulation into norm_const and the store into out[i] are on hidden
// lines; presumably `norm_const += t; out[i] = t;`. TODO confirm.
89 for (std::size_t i = 0; i < num_class; ++i) {
90 t = std::exp(in[i] - max_margin);
// Pass 3: normalize so the outputs sum to 1.
94 for (std::size_t i = 0; i < num_class; ++i) {
95 out[i] /=
static_cast<float>(norm_const);
// One-vs-all transform: applies an independent sigmoid,
// 1 / (1 + exp(-alpha * in[i])), to each of the num_class margins.
// The outputs are not normalized across classes.
// A dmlc CHECK fails unless num_class > 1, and another fails unless alpha is
// strictly positive.
// `num_class` and `alpha` are declared on hidden lines. They are presumably
// model.task_param.num_class and model.param.sigmoid_alpha; TODO confirm.
100 std::size_t multiclass_ova(
const treelite::Model& model,
const float* in,
float* out) {
102 CHECK(num_class > 1) <<
"model must be a multi-class classifier";
104 CHECK(alpha > 0.0f) <<
"multiclass_ova: alpha must be strictly positive";
105 for (std::size_t i = 0; i < num_class; ++i) {
106 out[i] = 1.0f / (1.0f + std::exp(-alpha * in[i]));
// Registry mapping each transform's name string to its function in the
// pred_transform namespace. LookupPredTransform resolves names through this
// map. The key strings are runtime data; presumably they must match the
// transform name stored in the model (TODO confirm).
// The closing brace of the initializer is not visible in this excerpt.
113 const std::unordered_map<std::string, PredTransformFuncType> pred_transform_func{
114 {
"identity", pred_transform::identity},
115 {
"signed_square", pred_transform::signed_square},
116 {
"hinge", pred_transform::hinge},
117 {
"sigmoid", pred_transform::sigmoid},
118 {
"exponential", pred_transform::exponential},
119 {
"logarithm_one_plus_exp", pred_transform::logarithm_one_plus_exp},
120 {
"identity_multiclass", pred_transform::identity_multiclass},
121 {
"max_index", pred_transform::max_index},
122 {
"softmax", pred_transform::softmax},
123 {
"multiclass_ova", pred_transform::multiclass_ova}
// Returns the prediction transform function registered under `name` in
// pred_transform_func.
// std::unordered_map::at throws std::out_of_range when `name` is not
// registered.
126 PredTransformFuncType LookupPredTransform(
const std::string& name) {
127 return pred_transform_func.at(name);
ModelParam param
extra parameters
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
General Tree Inference Library (GTIL), providing a reference implementation for predicting with decision tree models.
TaskParameter task_param
Group of parameters that are specific to the particular task type.
unsigned int num_class
The number of classes in the target label.
thin wrapper for tree ensemble model