12 #include <dmlc/data.h> 13 #include <dmlc/memory_io.h> 26 DMLC_REGISTRY_FILE_TAG(xgboost);
// NOTE(review): only the body of a loader survives here; its signature
// (original source lines 27-28) was lost in extraction. The tooltip text at the
// end of the file lists `Model LoadXGBoostModel(const char *filename)`, which
// this body appears to implement -- confirm against upstream.
// Opens `filename` for reading via dmlc::Stream::Create and parses it with
// ParseStream. The unique_ptr owns the stream and deletes it when it goes out
// of scope.
29 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
30 return ParseStream(fi.get());
// NOTE(review): body of the in-memory loader; its signature and closing brace
// were lost in extraction (presumably LoadXGBoostModel(const void* buf,
// size_t len) -- confirm).
// Exposes the caller's `len`-byte buffer `buf` as a dmlc stream and parses it
// with ParseStream. The const_cast is presumably needed because
// MemoryFixedSizeStream's constructor takes a non-const pointer.
34 dmlc::MemoryFixedSizeStream fs(const_cast<void*>(buf), len);
35 return ParseStream(&fs);
/*! \brief Floating-point type of values read from the XGBoost model
 *         (base score, leaf values, split thresholds, leaf vectors). */
using bst_float = float;
/*!
 * \brief Map a base score from probability space to margin (logit) space,
 *        i.e. apply the inverse of the logistic sigmoid.
 * \param global_bias base score as a probability, expected in (0, 1)
 * \return -log(1 / global_bias - 1), which equals log(p / (1 - p)).
 *         The result is -inf at 0, +inf at 1 and NaN outside [0, 1].
 */
static float Sigmoid(float global_bias) {
  return -logf(1.0f / global_bias - 1.0f);
}
/*!
 * \brief Map a base score from the response (mean) scale to margin space by
 *        taking its natural logarithm.
 * \param global_bias base score on the response scale, expected > 0
 * \return log(global_bias). The result is -inf at 0 and NaN for negative input.
 */
static float Exponential(float global_bias) {
  return logf(global_bias);
}
// Stream wrapper that allows looking ahead up to MAX_PEEK_WINDOW bytes
// (PeekRead) without consuming them. A subsequent Read returns the peeked
// bytes first and then continues from the underlying dmlc::Stream.
// Peeked bytes are held in the circular buffer buf_ of MAX_PEEK_WINDOW + 1
// bytes. The unread bytes run from begin_ptr_ up to end_ptr_, wrapping around
// the end of buf_ when begin_ptr_ > end_ptr_.
// NOTE(review): this extraction dropped several lines of the class (access
// specifiers, `else` branches, return statements, closing braces and the
// declaration of istm_). The comments below describe only the visible code.
56 class PeekableInputStream {
// Maximum number of bytes a single PeekRead may request.
58 const size_t MAX_PEEK_WINDOW = 1024;
// Stores `fi` in istm_ (ownership is not visible here) and starts with an
// empty buffer.
60 explicit PeekableInputStream(dmlc::Stream* fi)
61 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
// Reads `size` bytes into `ptr`, consuming any buffered (peeked) bytes first.
63 inline size_t Read(
void* ptr,
size_t size) {
64 const size_t bytes_buffered = BytesBuffered();
65 char* cptr =
static_cast<char*
>(ptr);
// Case 1: the request is served entirely from the ring buffer.
66 if (size <= bytes_buffered) {
// Contiguous copy when the requested range does not reach the end of buf_.
68 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
69 std::memcpy(cptr, &buf_[begin_ptr_], size);
// Wrapped copy: tail of buf_ first, then the rest from its start. Afterwards
// begin_ptr_ points just past the bytes taken from the front.
72 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
73 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
74 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
75 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
// Case 2: drain the whole buffer (contiguous or wrapped), then read the rest
// directly from istm_.
79 const size_t bytes_to_read = size - bytes_buffered;
80 if (begin_ptr_ <= end_ptr_) {
81 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
83 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
84 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
85 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
// The buffer is now empty.
87 begin_ptr_ = end_ptr_;
// NOTE(review): the start of this statement (original line 88) was dropped.
// Presumably it returns bytes_buffered plus the count read from istm_.
89 + istm_->Read(cptr + bytes_buffered, bytes_to_read);
// Copies the next `size` (<= MAX_PEEK_WINDOW) bytes into `ptr` without
// consuming them. If fewer than `size` bytes are buffered, it first tops up
// the ring buffer from istm_.
93 inline size_t PeekRead(
void* ptr,
size_t size) {
94 CHECK_LE(size, MAX_PEEK_WINDOW)
95 <<
"PeekableInputStream allows peeking up to " 96 << MAX_PEEK_WINDOW <<
" bytes";
97 char* cptr =
static_cast<char*
>(ptr);
98 const size_t bytes_buffered = BytesBuffered();
// Top up: read only the missing bytes, wrapping at the end of buf_ if needed.
100 if (size > bytes_buffered) {
101 const size_t bytes_to_read = size - bytes_buffered;
102 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
103 CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
104 <<
"Failed to peek " << size <<
" bytes";
105 end_ptr_ += bytes_to_read;
107 CHECK_EQ(istm_->Read(&buf_[end_ptr_],
108 MAX_PEEK_WINDOW + 1 - end_ptr_)
109 + istm_->Read(&buf_[0],
110 bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
112 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
113 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
// Copy out without advancing begin_ptr_.
// NOTE(review): this copies every buffered byte (end_ptr_ - begin_ptr_, or
// the wrapped equivalent), not just `size`. If more than `size` bytes were
// already buffered, it writes past the `size` bytes the caller asked for.
// That happens, for example, when a shorter PeekRead follows a longer one with
// no Read in between. Copying min(size, buffered) bytes would avoid the overrun.
117 if (begin_ptr_ <= end_ptr_) {
118 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
120 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
121 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
// Ring-buffer storage (MAX_PEEK_WINDOW + 1 bytes) and its read/write indices.
129 std::vector<char> buf_;
130 size_t begin_ptr_, end_ptr_;
// Number of unread bytes currently held in buf_, accounting for wrap-around.
132 inline size_t BytesBuffered() {
133 if (begin_ptr_ <= end_ptr_) {
134 return end_ptr_ - begin_ptr_;
136 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
/*!
 * \brief Read and discard `size` bytes from the stream `fi`.
 * \tparam T anything usable as `fi->Read(void*, size_t)` returning the number
 *           of bytes read. Callers pass a PeekableInputStream* or a
 *           std::unique_ptr<PeekableInputStream>.
 * \param fi   stream to advance
 * \param size number of bytes to skip; may come straight from the model file
 *
 * The bytes are pulled through a small fixed-size local buffer in chunks.
 * This keeps memory use bounded however large a (possibly corrupt) `size` is:
 * a bogus length surfaces as a short read and the "Ill-formed" error instead of
 * a huge allocation that would be kept alive. It also avoids any shared mutable
 * static state between calls.
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[512];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk = remaining < sizeof(scratch) ? remaining : sizeof(scratch);
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
// Fixed-size learner header of an XGBoost binary model. The parser fills it
// with a single raw read of sizeof(LearnerModelParam) bytes, so the in-memory
// layout must match the on-disk record. The static_assert pins it to 136 bytes.
// NOTE(review): original lines 153 and 158-159 are missing from this
// extraction. The visible members total 24 bytes, so the dropped lines must
// account for the remaining 112 bytes (presumably further fields and a
// reserved padding array, plus the closing `};`) -- confirm against upstream.
150 struct LearnerModelParam {
// Base score (global bias) as stored in the file.
151 bst_float base_score;
// Presumably the number of input features -- confirm.
152 unsigned num_feature;
154 int contain_extra_attrs;
155 int contain_eval_metrics;
// Version stamp of the model. The parser sets need_transform_to_margin when
// major_version >= 1.
156 uint32_t major_version;
157 uint32_t minor_version;
160 static_assert(
sizeof(LearnerModelParam) == 136,
"This is the size defined in XGBoost.");
// Fixed-size header of the gbtree booster, filled by one raw read of
// sizeof(GBTreeModelParam) bytes in the parser.
// NOTE(review): lines 163-167 and 170-171 were dropped in extraction. The
// parser also reads num_trees and num_roots from this struct, so those members
// are among the missing lines -- confirm the field order upstream.
162 struct GBTreeModelParam {
// Number of output groups (presumably one per class for multi-class models).
168 int num_output_group;
169 int size_leaf_vector;
// NOTE(review): isolated member; the surrounding struct header is missing.
// Presumably part of TreeParam, which XGBTree::Load reads raw and whose
// size_leaf_vector it checks before skipping a leaf vector.
179 int size_leaf_vector;
// NOTE(review): isolated member, presumably of NodeStat. The conversion code
// also reads NodeStat::loss_chg and NodeStat::sum_hess, which are on dropped
// lines.
186 bst_float base_weight;
// NOTE(review): the Node class header and access specifiers were dropped in
// extraction. These are the visible members of the node type the conversion
// code refers to as `XGBTree::Node`.
// Nodes are read from the file as a raw array (sizeof(Node) * count bytes in
// XGBTree::Load), so the layout is fixed. The static_assert requires exactly
// four ints plus the Info payload, with no padding.
194 Node() : sindex_(0) {
196 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
197 "Node: 64 bit align");
// Index of the left child (body on dropped lines 200-201).
199 inline int cleft()
const {
// Index of the right child.
202 inline int cright()
const {
203 return this->cright_;
// Child taken when the split feature is missing: the left one if
// default_left(), otherwise the right one.
205 inline int cdefault()
const {
206 return this->default_left() ? this->cleft() : this->cright();
// Feature index of the split: sindex_ with bit 31 masked off.
208 inline unsigned split_index()
const {
209 return sindex_ & ((1U << 31) - 1U);
// Bit 31 of sindex_ holds the default-left flag (set by set_split).
211 inline bool default_left()
const {
212 return (sindex_ >> 31) != 0;
// Leaf test (body on dropped lines 215-216).
214 inline bool is_leaf()
const {
// Leaf output; meaningful for leaf nodes.
217 inline bst_float leaf_value()
const {
218 return (this->info_).leaf_value;
// Split threshold; meaningful for split nodes.
220 inline bst_float split_cond()
const {
221 return (this->info_).split_cond;
// Parent index with the is-left-child flag (bit 31) masked off.
223 inline int parent()
const {
224 return parent_ & ((1U << 31) - 1);
// A root stores -1 as its parent.
226 inline bool is_root()
const {
227 return parent_ == -1;
// Stores `value` as the leaf output (the rest of this setter was dropped).
229 inline void set_leaf(bst_float value) {
230 (this->info_).leaf_value = value;
// Turns the node into a split on feature `split_index` with threshold
// `split_cond`. default_left is packed into bit 31 of sindex_.
234 inline void set_split(
unsigned split_index,
235 bst_float split_cond,
236 bool default_left =
false) {
237 if (default_left) split_index |= (1U << 31);
238 this->sindex_ = split_index;
239 (this->info_).split_cond = split_cond;
// Grants XGBTree access to private members (AddChilds writes cleft_/cright_
// directly).
243 friend class XGBTree;
// Members of the Info payload. Its declaration header (line 244) is missing,
// so whether these two share storage cannot be confirmed from here.
245 bst_float leaf_value;
246 bst_float split_cond;
// Reports a deleted node, marked by sindex_ equal to the maximum unsigned value.
253 inline bool is_deleted()
const {
254 return sindex_ == std::numeric_limits<unsigned>::max();
// Stores the parent index, setting bit 31 for a left child.
// NOTE(review): `pidx |= (1U << 31)` goes through unsigned and back to int.
// For non-negative pidx the result exceeds INT_MAX, and that conversion is
// implementation-defined before C++20. For pidx == -1 all bits are already
// set, so the stored value stays -1 and is_root() holds.
256 inline void set_parent(
int pidx,
bool is_left_child =
true) {
257 if (is_left_child) pidx |= (1U << 31);
258 this->parent_ = pidx;
// NOTE(review): XGBTree's class header, its `param` member and several method
// bodies were dropped in extraction; only these lines remain.
// Per-node structure and statistics, indexed by node id. Load sizes both to
// param.num_nodes.
264 std::vector<Node> nodes;
265 std::vector<NodeStat> stats;
// Appends a node. nd takes the current count (the new node's index), then the
// count is incremented, with a CHECK against reaching INT_MAX, and `nodes` is
// grown. The remaining lines (272-276, presumably resizing `stats` and
// returning nd) were dropped.
267 inline int AllocNode() {
268 int nd = param.num_nodes++;
269 CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
270 <<
"number of nodes in the tree exceed 2^31";
271 nodes.resize(param.num_nodes);
// Node and NodeStat accessors by id (bodies on dropped lines).
277 inline Node& operator[](
int nid) {
281 inline const Node& operator[](
int nid)
const {
285 inline NodeStat& Stat(
int nid) {
289 inline const NodeStat& Stat(
int nid)
const {
// Fragment of an initialisation routine whose header was dropped: node 0
// becomes a leaf with value 0 and is marked as the root via set_parent(-1).
295 nodes[0].set_leaf(0.0f);
296 nodes[0].set_parent(-1);
298 inline void AddChilds(
int nid) {
299 int pleft = this->AllocNode();
300 int pright = this->AllocNode();
301 nodes[nid].cleft_ = pleft;
302 nodes[nid].cright_ = pright;
303 nodes[nodes[nid].cleft() ].set_parent(nid,
true);
304 nodes[nodes[nid].cright()].set_parent(nid,
false);
// Reads one tree from `fi` in this order:
//   1. the raw TreeParam header;
//   2. num_nodes Node records and num_nodes NodeStat records, as raw arrays;
//   3. the optional leaf vector, which is skipped.
// It then requires the tree to have a single root.
// NOTE(review): line 320 (presumably the declaration of `len`) and lines 323,
// 325-326 and 330 were dropped in extraction.
306 inline void Load(PeekableInputStream* fi) {
// NOTE(review): `¶m` looks like an HTML-entity mangling of `&param` (the
// `&para` prefix rendered as a pilcrow) -- confirm against upstream.
307 CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
308 <<
"Ill-formed XGBoost model file: can't read TreeParam";
// NOTE(review): nodes/stats are resized from the file-supplied num_nodes
// before it is validated, and only == 0 is rejected (after the resize).
// num_nodes appears to be an int (see AllocNode), so a negative value would
// convert to a huge size_t here. Checking num_nodes > 0 first would give a
// clean error instead.
309 nodes.resize(param.num_nodes);
310 stats.resize(param.num_nodes);
311 CHECK_NE(param.num_nodes, 0)
312 <<
"Ill-formed XGBoost model file: a tree can't be empty";
313 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes),
sizeof(Node) * nodes.size()),
314 sizeof(Node) * nodes.size())
315 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// NOTE(review): this failure message says "nodes", but the read is the
// NodeStat array.
316 CHECK_EQ(fi->Read(dmlc::BeginPtr(stats),
sizeof(NodeStat) * stats.size()),
317 sizeof(NodeStat) * stats.size())
318 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// Skip the leaf vector: a length prefix, then `len` bst_float values.
319 if (param.size_leaf_vector != 0) {
321 CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
322 <<
"Ill-formed XGBoost model file";
324 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
327 CHECK_EQ(param.num_roots, 1)
328 <<
"Invalid XGBoost model file: treelite does not support trees " 329 <<
"with multiple roots";
// Parser state (the enclosing class header was dropped in extraction): the
// decoded trees, the two fixed-size headers, and the booster and objective
// names read from the model file.
334 std::vector<XGBTree> xgb_trees_;
335 LearnerModelParam mparam_;
336 GBTreeModelParam gbm_param_;
337 std::string name_gbm_;
338 std::string name_obj_;
// Wrap the input so the 4-byte format magic can be inspected without consuming
// it.
// NOTE(review): std::make_unique would replace the raw `new`, but it needs
// C++14; keep as is if the project targets C++11.
341 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
// NOTE(review): the declaration of `header` (lines 342-344) was dropped;
// presumably a 4-character std::string.
// "bs64" models are rejected, and a "binf" magic is consumed. Otherwise the
// peeked bytes stay buffered, so presumably they are read back as the start of
// the learner header below.
345 if (fp->PeekRead(&header[0], 4) == 4) {
346 CHECK_NE(header,
"bs64")
347 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
348 if (header ==
"binf") {
349 CONSUME_BYTES(fp, 4);
// Learner header, followed by the length-prefixed objective name.
353 CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
354 <<
"Ill-formed XGBoost model file: corrupted header";
// NOTE(review): `len` is declared on a dropped line and comes straight from
// the file. name_obj_.resize(len) and name_gbm_.resize(len) will attempt any
// size the file claims; a sanity bound would turn a corrupt length into a
// clean error.
357 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
358 <<
"Ill-formed XGBoost model file: corrupted header";
360 name_obj_.resize(len);
361 CHECK_EQ(fp->Read(&name_obj_[0], len), len)
362 <<
"Ill-formed XGBoost model file: corrupted header";
// Length-prefixed booster name; only "gbtree" is accepted.
368 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
369 <<
"Ill-formed XGBoost model file: corrupted header";
370 name_gbm_.resize(len);
372 CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
373 <<
"Ill-formed XGBoost model file: corrupted header";
378 CHECK_EQ(name_gbm_,
"gbtree")
379 <<
"Invalid XGBoost model file: " 380 <<
"Gradient booster must be gbtree type.";
// Booster header, then num_trees trees read one after another.
382 CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
383 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
384 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
385 xgb_trees_.emplace_back();
386 xgb_trees_.back().Load(fp.get());
388 CHECK_EQ(gbm_param_.num_roots, 1) <<
"multi-root trees not supported";
// need_transform_to_margin is set when the model's major_version is >= 1.
392 bool need_transform_to_margin = mparam_.major_version >= 1;
// Objectives treated as the exponential family in the branches below.
// NOTE(review): the closing `};` of this initializer (line 404) was dropped,
// and line 405 is fused onto the last element. The bodies of every branch
// below (lines 407, 410-414, 416, 418, 420-421 and after 423) are also
// missing, so what each branch sets is not visible here. Presumably they
// convert the base score (e.g. via Sigmoid/Exponential) and choose the
// prediction transform -- confirm upstream.
402 std::vector<std::string> exponential_family {
403 "count:poisson",
"reg:gamma",
"reg:tweedie" 405 if (need_transform_to_margin) {
// Logistic objectives first, then the exponential-family objectives.
406 if (name_obj_ ==
"reg:logistic" || name_obj_ ==
"binary:logistic") {
408 }
else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
409 != exponential_family.cend()) {
// Second dispatch on the objective name: multi:softmax, multi:softprob,
// logistic, then the exponential family.
415 if (name_obj_ ==
"multi:softmax") {
417 }
else if (name_obj_ ==
"multi:softprob") {
419 }
else if (name_obj_ ==
"reg:logistic" || name_obj_ ==
"binary:logistic") {
422 }
else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
423 != exponential_family.cend()) {
// Converts each XGBoost tree into a new treelite tree by breadth-first
// traversal. Q holds (XGBoost node id, treelite node id) pairs.
// - Leaves copy the leaf value.
// - Split nodes become numerical splits with operator kLT, carry the split
//   gain (loss_chg), and enqueue both children.
// - sum_hess appears to be copied for every node.
// NOTE(review): many lines of this loop were dropped in extraction (432-437,
// 439-441, 448, 450, 452-453, 458 and everything after 459). These include the
// declarations of `tree` and old_id/new_id, and the creation of child nodes.
430 for (
const auto& xgb_tree : xgb_trees_) {
431 model.
trees.emplace_back();
438 std::queue<std::pair<int, int>> Q;
442 std::tie(old_id, new_id) = Q.front(); Q.pop();
443 const XGBTree::Node& node = xgb_tree[old_id];
// NOTE(review): copies the NodeStat by value; a const reference would do.
444 const NodeStat stat = xgb_tree.Stat(old_id);
445 if (node.is_leaf()) {
446 const bst_float leaf_value = node.leaf_value();
447 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
449 const bst_float split_cond = node.split_cond();
451 tree[new_id].set_numerical_split(node.split_index(),
454 treelite::Operator::kLT);
455 tree[new_id].set_gain(stat.loss_chg);
456 Q.push({node.cleft(), tree[new_id].cleft()});
457 Q.push({node.cright(), tree[new_id].cright()});
459 tree[new_id].set_sum_hess(stat.sum_hess);
int num_output_group
number of output groups — for multi-class classification; set to 1 for everything else ...
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
std::vector< Tree > trees
member trees
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
ModelParam param
extra parameters
in-memory representation of a decision tree
float global_bias
global bias of the model
std::string pred_transform
name of prediction transform function
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
double tl_float
float type to be used internally
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...