Treelite
xgboost_util.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <treelite/logging.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include "xgboost/xgboost.h"
12 
13 namespace {
14 
15 inline void SetPredTransformString(const char* value, treelite::ModelParam* param) {
16  std::strncpy(param->pred_transform, value, sizeof(param->pred_transform));
17 }
18 
19 } // anonymous namespace
20 
21 namespace treelite {
22 namespace details {
23 namespace xgboost {
24 
// XGBoost objective names that SetPredTransform maps to the "exponential"
// prediction transform.
const std::vector<std::string> exponential_objectives = {
  "count:poisson",
  "reg:gamma",
  "reg:tweedie",
  "survival:cox",
  "survival:aft",
};
28 
29 // set correct prediction transform function, depending on objective function
30 void SetPredTransform(const std::string& objective_name, ModelParam* param) {
31  if (objective_name == "multi:softmax") {
32  SetPredTransformString("max_index", param);
33  } else if (objective_name == "multi:softprob") {
34  SetPredTransformString("softmax", param);
35  } else if (objective_name == "reg:logistic" || objective_name == "binary:logistic") {
36  SetPredTransformString("sigmoid", param);
37  param->sigmoid_alpha = 1.0f;
38  } else if (std::find(exponential_objectives.cbegin(), exponential_objectives.cend(),
39  objective_name) != exponential_objectives.cend()) {
40  SetPredTransformString("exponential", param);
41  } else if (objective_name == "binary:hinge") {
42  SetPredTransformString("hinge", param);
43  } else if (objective_name == "reg:squarederror" || objective_name == "reg:linear"
44  || objective_name == "reg:squaredlogerror"
45  || objective_name == "reg:pseudohubererror"
46  || objective_name == "binary:logitraw"
47  || objective_name == "rank:pairwise"
48  || objective_name == "rank:ndcg"
49  || objective_name == "rank:map") {
50  SetPredTransformString("identity", param);
51  } else {
52  TREELITE_LOG(FATAL) << "Unrecognized XGBoost objective: " << objective_name;
53  }
54 }
55 
56 // Transform the global bias parameter from probability into margin score
57 void TransformGlobalBiasToMargin(ModelParam* param) {
58  std::string bias_transform{param->pred_transform};
59  if (bias_transform == "sigmoid") {
60  param->global_bias = ProbToMargin::Sigmoid(param->global_bias);
61  } else if (bias_transform == "exponential") {
62  param->global_bias = ProbToMargin::Exponential(param->global_bias);
63  }
64 }
65 
66 } // namespace xgboost
67 } // namespace details
68 } // namespace treelite
model structure for tree ensemble
logging facility for Treelite
Helper functions for loading XGBoost models.
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:585