9 #include <dmlc/memory_io.h> 19 const char* name_obj);
26 DMLC_REGISTRY_FILE_TAG(xgboost);
29 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
30 return ParseStream(fi.get());
34 const char* name_obj) {
35 std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(filename,
"w"));
36 SaveModelToStream(fo.get(), model, name_obj);
40 dmlc::MemoryFixedSizeStream fs((
void*)buf, len);
41 return ParseStream(&fs);
50 typedef float bst_float;
53 class PeekableInputStream {
55 const size_t MAX_PEEK_WINDOW = 1024;
57 PeekableInputStream(dmlc::Stream* fi)
58 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
60 inline size_t Read(
void* ptr,
size_t size) {
61 const size_t bytes_buffered = BytesBuffered();
62 char* cptr =
static_cast<char*
>(ptr);
63 if (size <= bytes_buffered) {
65 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
66 std::memcpy(cptr, &buf_[begin_ptr_], size);
69 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
70 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
71 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
72 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
76 const size_t bytes_to_read = size - bytes_buffered;
77 if (begin_ptr_ <= end_ptr_) {
78 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
80 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
81 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
82 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
84 begin_ptr_ = end_ptr_;
86 + istm_->Read(cptr + bytes_buffered, bytes_to_read);
90 inline size_t PeekRead(
void* ptr,
size_t size) {
91 CHECK_LE(size, MAX_PEEK_WINDOW)
92 <<
"PeekableInputStream allows peeking up to " 93 << MAX_PEEK_WINDOW <<
" bytes";
94 char* cptr =
static_cast<char*
>(ptr);
95 const size_t bytes_buffered = BytesBuffered();
97 if (size > bytes_buffered) {
98 const size_t bytes_to_read = size - bytes_buffered;
99 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
100 CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
101 <<
"Failed to peek " << size <<
" bytes";
102 end_ptr_ += bytes_to_read;
104 CHECK_EQ( istm_->Read(&buf_[end_ptr_],
105 MAX_PEEK_WINDOW + 1 - end_ptr_)
106 + istm_->Read(&buf_[0],
107 bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
109 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
110 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
114 if (begin_ptr_ <= end_ptr_) {
115 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
117 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
118 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
126 std::vector<char> buf_;
127 size_t begin_ptr_, end_ptr_;
129 inline size_t BytesBuffered() {
130 if (begin_ptr_ <= end_ptr_) {
131 return end_ptr_ - begin_ptr_;
133 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
138 template <
typename T>
139 inline void CONSUME_BYTES(
const T& fi,
size_t size) {
140 static std::vector<char> dummy(500);
141 if (size > dummy.size()) dummy.resize(size);
142 CHECK_EQ(fi->Read(&dummy[0], size), size)
143 <<
"Ill-formed XGBoost model format: cannot read " << size
144 <<
" bytes from the file";
147 struct LearnerModelParam {
148 bst_float base_score;
149 unsigned num_feature;
151 int contain_extra_attrs;
152 int contain_eval_metrics;
156 struct GBTreeModelParam {
162 int num_output_group;
163 int size_leaf_vector;
173 int size_leaf_vector;
180 bst_float base_weight;
188 Node() : sindex_(0) {
190 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
191 "Node: 64 bit align");
193 inline int cleft()
const {
196 inline int cright()
const {
197 return this->cright_;
199 inline int cdefault()
const {
200 return this->default_left() ? this->cleft() : this->cright();
202 inline unsigned split_index()
const {
203 return sindex_ & ((1U << 31) - 1U);
205 inline bool default_left()
const {
206 return (sindex_ >> 31) != 0;
208 inline bool is_leaf()
const {
211 inline bst_float leaf_value()
const {
212 return (this->info_).leaf_value;
214 inline bst_float split_cond()
const {
215 return (this->info_).split_cond;
217 inline int parent()
const {
218 return parent_ & ((1U << 31) - 1);
220 inline bool is_root()
const {
221 return parent_ == -1;
223 inline void set_leaf(bst_float value) {
224 (this->info_).leaf_value = value;
228 inline void set_split(
unsigned split_index,
229 bst_float split_cond,
230 bool default_left =
false) {
231 if (default_left) split_index |= (1U << 31);
232 this->sindex_ = split_index;
233 (this->info_).split_cond = split_cond;
237 friend class XGBTree;
239 bst_float leaf_value;
240 bst_float split_cond;
247 inline bool is_deleted()
const {
248 return sindex_ == std::numeric_limits<unsigned>::max();
250 inline void set_parent(
int pidx,
bool is_left_child =
true) {
251 if (is_left_child) pidx |= (1U << 31);
252 this->parent_ = pidx;
258 std::vector<Node> nodes;
260 inline int AllocNode() {
261 int nd = param.num_nodes++;
262 CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
263 <<
"number of nodes in the tree exceed 2^31";
264 nodes.resize(param.num_nodes);
269 inline Node& operator[](
int nid) {
272 inline const Node& operator[](
int nid)
const {
278 nodes[0].set_leaf(0.0f);
279 nodes[0].set_parent(-1);
281 inline void AddChilds(
int nid) {
282 int pleft = this->AllocNode();
283 int pright = this->AllocNode();
284 nodes[nid].cleft_ = pleft;
285 nodes[nid].cright_ = pright;
286 nodes[nodes[nid].cleft() ].set_parent(nid,
true);
287 nodes[nodes[nid].cright()].set_parent(nid,
false);
289 inline void Load(PeekableInputStream* fi) {
290 CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
291 <<
"Ill-formed XGBoost model file: can't read TreeParam";
292 nodes.resize(param.num_nodes);
293 CHECK_NE(param.num_nodes, 0)
294 <<
"Ill-formed XGBoost model file: a tree can't be empty";
295 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes),
sizeof(Node) * nodes.size()),
296 sizeof(Node) * nodes.size())
297 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
298 CONSUME_BYTES(fi, (3 *
sizeof(bst_float) +
sizeof(
int)) * param.num_nodes);
299 if (param.size_leaf_vector != 0) {
301 CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
302 <<
"Ill-formed XGBoost model file";
304 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
307 CHECK_EQ(param.num_roots, 1)
308 <<
"Invalid XGBoost model file: treelite does not support trees " 309 <<
"with multiple roots";
311 inline void Save(dmlc::Stream* fo,
int num_feature)
const {
313 const bst_float nan = std::numeric_limits<bst_float>::quiet_NaN();
314 std::vector<NodeStat> stats_(nodes.size(), NodeStat{nan, nan, nan, -1});
315 param_.num_roots = 1;
316 param_.num_nodes =
static_cast<int>(nodes.size());
317 param_.num_deleted = 0;
318 std::function<int(int)> max_depth_func;
319 max_depth_func = [&max_depth_func,
this](
int nid) ->
int {
320 if (nodes[nid].is_leaf()) {
323 return 1 + std::max(max_depth_func(nodes[nid].cleft()),
324 max_depth_func(nodes[nid].cright()));
327 param_.max_depth = max_depth_func(0);
328 param_.num_feature = num_feature;
329 param_.size_leaf_vector = 0;
330 fo->Write(¶m_,
sizeof(TreeParam));
331 fo->Write(dmlc::BeginPtr(nodes),
sizeof(Node) * nodes.size());
333 fo->Write(dmlc::BeginPtr(stats_),
sizeof(NodeStat) * nodes.size());
338 std::vector<XGBTree> xgb_trees_;
339 LearnerModelParam mparam_;
340 GBTreeModelParam gbm_param_;
341 std::string name_gbm_;
342 std::string name_obj_;
345 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
349 if (fp->PeekRead(&header[0], 4) == 4) {
350 CHECK_NE(header,
"bs64")
351 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
352 if (header ==
"binf") {
353 CONSUME_BYTES(fp, 4);
357 CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
358 <<
"Ill-formed XGBoost model file: corrupted header";
359 LOG(INFO) <<
"Global bias of the model: " << mparam_.base_score;
364 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
365 <<
"Ill-formed XGBoost model file: corrupted header";
366 if (len >= std::numeric_limits<unsigned>::max()) {
368 CHECK_EQ(fp->Read(&gap,
sizeof(gap)),
sizeof(gap))
369 <<
"Ill-formed XGBoost model file: corrupted header";
370 len = len >>
static_cast<uint64_t
>(32UL);
373 name_obj_.resize(len);
374 CHECK_EQ(fp->Read(&name_obj_[0], len), len)
375 <<
"Ill-formed XGBoost model file: corrupted header";
381 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
382 <<
"Ill-formed XGBoost model file: corrupted header";
383 name_gbm_.resize(len);
385 CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
386 <<
"Ill-formed XGBoost model file: corrupted header";
391 CHECK_EQ(name_gbm_,
"gbtree")
392 <<
"Invalid XGBoost model file: " 393 <<
"Gradient booster must be gbtree type.";
395 CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
396 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
397 LOG(INFO) <<
"gbm_param_.num_feature = " << gbm_param_.num_feature;
398 LOG(INFO) <<
"gbm_param_.num_output_group = " << gbm_param_.num_output_group;
399 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
400 xgb_trees_.emplace_back();
401 xgb_trees_.back().Load(fp.get());
403 CHECK_EQ(gbm_param_.num_roots, 1) <<
"multi-root trees not supported";
415 if (name_obj_ ==
"multi:softmax") {
417 }
else if (name_obj_ ==
"multi:softprob") {
419 }
else if (name_obj_ ==
"reg:logistic" || name_obj_ ==
"binary:logistic") {
422 }
else if (name_obj_ ==
"count:poisson" || name_obj_ ==
"reg:gamma" 423 || name_obj_ ==
"reg:tweedie") {
430 for (
const auto& xgb_tree : xgb_trees_) {
431 model.
trees.emplace_back();
438 std::queue<std::pair<int, int>> Q;
442 std::tie(old_id, new_id) = Q.front(); Q.pop();
443 const XGBTree::Node& node = xgb_tree[old_id];
444 if (node.is_leaf()) {
445 const bst_float leaf_value = node.leaf_value();
446 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
448 const bst_float split_cond = node.split_cond();
450 tree[new_id].set_numerical_split(node.split_index(),
453 treelite::Operator::kLT);
454 Q.push({node.cleft(), tree[new_id].cleft()});
455 Q.push({node.cright(), tree[new_id].cright()});
462 inline void SaveModelToStream(dmlc::Stream* fo,
const treelite::Model& model,
463 const char* name_obj) {
464 LearnerModelParam mparam_;
465 GBTreeModelParam gbm_param_;
470 mparam_.contain_extra_attrs = 0;
471 mparam_.contain_eval_metrics = 0;
472 fo->Write(&mparam_,
sizeof(LearnerModelParam));
474 const std::string name_gbm_ =
"gbtree";
475 fo->Write(std::string(name_obj));
476 fo->Write(name_gbm_);
478 gbm_param_.num_trees = model.
trees.size();
479 gbm_param_.num_roots = 1;
482 gbm_param_.size_leaf_vector = 0;
483 fo->Write(&gbm_param_,
sizeof(gbm_param_));
488 std::queue<std::pair<int, int>> Q;
492 std::tie(old_id, new_id) = Q.front(); Q.pop();
496 xgb_tree_[new_id].set_leaf(static_cast<bst_float>(leaf_value));
499 xgb_tree_.AddChilds(new_id);
501 <<
"Comparison operator must be `<`";
503 static_cast<bst_float
>(split_cond),
505 Q.push({node.
cleft(), xgb_tree_[new_id].cleft()});
506 Q.push({node.
cright(), xgb_tree_[new_id].cright()});
512 std::vector<int> tree_info_(model.
trees.size(), 0);
514 for (
size_t i = 0; i < model.
trees.size(); ++i) {
518 fo->Write(dmlc::BeginPtr(tree_info_),
sizeof(
int) * tree_info_.size());
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
unsigned split_index() const
feature index of split condition
ModelParam param
extra parameters
Operator comparison_op() const
get comparison operator
in-memory representation of a decision tree
float global_bias
global bias of the model
tl_float threshold() const
int cright() const
index of right child
std::string pred_transform
name of prediction transform function
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
tl_float leaf_value() const
bool default_left() const
when feature is unknown, whether goes to left child
void AddChilds(int nid)
add child nodes to node
int cleft() const
index of left child
void ExportXGBoostModel(const char *filename, const Model &model, const char *name_obj)
export a model in XGBoost format. The exported model can be read by XGBoost (dmlc/xgboost).
bool is_leaf() const
whether current node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...