25 const std::vector<std::string> exponential_objectives{
26 "count:poisson",
"reg:gamma",
"reg:tweedie",
"survival:cox",
"survival:aft" 30 void SetPredTransform(
const std::string& objective_name, ModelParam* param) {
31 if (objective_name ==
"multi:softmax") {
32 SetPredTransformString(
"max_index", param);
33 }
else if (objective_name ==
"multi:softprob") {
34 SetPredTransformString(
"softmax", param);
35 }
else if (objective_name ==
"reg:logistic" || objective_name ==
"binary:logistic") {
36 SetPredTransformString(
"sigmoid", param);
37 param->sigmoid_alpha = 1.0f;
38 }
else if (std::find(exponential_objectives.cbegin(), exponential_objectives.cend(),
39 objective_name) != exponential_objectives.cend()) {
40 SetPredTransformString(
"exponential", param);
41 }
else if (objective_name ==
"binary:hinge") {
42 SetPredTransformString(
"hinge", param);
43 }
else if (objective_name ==
"reg:squarederror" || objective_name ==
"reg:linear" 44 || objective_name ==
"reg:squaredlogerror" 45 || objective_name ==
"reg:pseudohubererror" 46 || objective_name ==
"binary:logitraw" 47 || objective_name ==
"rank:pairwise" 48 || objective_name ==
"rank:ndcg" 49 || objective_name ==
"rank:map") {
50 SetPredTransformString(
"identity", param);
52 TREELITE_LOG(FATAL) <<
"Unrecognized XGBoost objective: " << objective_name;
57 void TransformGlobalBiasToMargin(ModelParam* param) {
58 std::string bias_transform{param->pred_transform};
59 if (bias_transform ==
"sigmoid") {
60 param->global_bias = ProbToMargin::Sigmoid(param->global_bias);
61 }
else if (bias_transform ==
"exponential") {
62 param->global_bias = ProbToMargin::Exponential(param->global_bias);
Model structure for a tree ensemble.
Logging facility for Treelite.
Helper functions for loading XGBoost models.
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
Name of the prediction transform function.