9 #include <dmlc/memory_io.h> 24 DMLC_REGISTRY_FILE_TAG(xgboost);
// Open `filename` for reading through dmlc::Stream::Create; the unique_ptr
// owns the resulting stream while it is being parsed.
27 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
// Hand the raw stream pointer to ParseStream and return its result.
28 return ParseStream(fi.get());
// Wrap the caller-supplied buffer (`buf`, `len` bytes) in a fixed-size
// in-memory stream.
// NOTE(review): the C-style cast to void* would silently drop constness if
// `buf` is declared const -- the declaration is not visible here; confirm.
32 dmlc::MemoryFixedSizeStream fs((
void*)buf, len);
// Parse the in-memory stream through the same ParseStream entry point.
33 return ParseStream(&fs);
/*! \brief floating-point type of the values held in the XGBoost structures
 *         below (base_score, leaf_value, split_cond) */
using bst_float = float;
45 class PeekableInputStream {
47 const size_t MAX_PEEK_WINDOW = 1024;
49 PeekableInputStream(dmlc::Stream* fi)
50 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
52 inline size_t Read(
void* ptr,
size_t size) {
53 const size_t bytes_buffered = BytesBuffered();
54 char* cptr =
static_cast<char*
>(ptr);
55 if (size <= bytes_buffered) {
57 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
58 std::memcpy(cptr, &buf_[begin_ptr_], size);
61 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
62 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
63 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
64 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
68 const size_t bytes_to_read = size - bytes_buffered;
69 if (begin_ptr_ <= end_ptr_) {
70 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
72 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
73 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
74 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
76 begin_ptr_ = end_ptr_;
78 + istm_->Read(cptr + bytes_buffered, bytes_to_read);
82 inline size_t PeekRead(
void* ptr,
size_t size) {
83 CHECK_LE(size, MAX_PEEK_WINDOW)
84 <<
"PeekableInputStream allows peeking up to " 85 << MAX_PEEK_WINDOW <<
" bytes";
86 char* cptr =
static_cast<char*
>(ptr);
87 const size_t bytes_buffered = BytesBuffered();
89 if (size > bytes_buffered) {
90 const size_t bytes_to_read = size - bytes_buffered;
91 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
92 CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
93 <<
"Failed to peek " << size <<
" bytes";
94 end_ptr_ += bytes_to_read;
96 CHECK_EQ( istm_->Read(&buf_[end_ptr_],
97 MAX_PEEK_WINDOW + 1 - end_ptr_)
98 + istm_->Read(&buf_[0],
99 bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
101 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
102 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
106 if (begin_ptr_ <= end_ptr_) {
107 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
109 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
110 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
118 std::vector<char> buf_;
119 size_t begin_ptr_, end_ptr_;
121 inline size_t BytesBuffered() {
122 if (begin_ptr_ <= end_ptr_) {
123 return end_ptr_ - begin_ptr_;
125 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
/*!
 * \brief read `size` bytes from a stream and throw them away
 *
 * The bytes are pulled through a small fixed-size stack buffer, so memory use
 * stays bounded however large `size` is and no state is shared between calls.
 *
 * \param fi pointer-like handle to an object providing Read(void*, size_t)
 * \param size number of bytes to skip
 * \tparam T type of the handle (raw or smart pointer)
 *
 * A CHECK fails if the stream cannot supply all `size` bytes.
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[4096];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk
      = (remaining < sizeof(scratch)) ? remaining : sizeof(scratch);
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
// Learner-level parameter record. The parser below fills it with one raw
// Read(&mparam_, sizeof(mparam_)), so its layout is taken directly from the
// bytes in the model file.
139 struct LearnerModelParam {
// Logged by the parser as the "Global bias of the model".
140 bst_float base_score;
// Gradient-booster parameter record; like the learner record it is read from
// the stream as one raw block (see the Read(&gbm_param_, ...) call below).
145 struct GBTreeModelParam {
// Number of output groups; logged by the parser after the record is read.
151 int num_output_group;
// NOTE(review): presumably the length of a per-leaf vector; this field is not
// referenced in the visible code -- confirm.
160 int size_leaf_vector;
// Default constructor: zero-initializes the packed split word sindex_.
168 Node() : sindex_(0) {
// Compile-time size guard: a Node must occupy exactly four ints plus one Info.
// Load() reads the node array from the file as raw bytes, so the size matters.
170 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
171 "Node: 64 bit align");
// Left child accessor (body not visible in this listing).
173 inline int cleft()
const {
// Right child accessor: returns cright_.
176 inline int cright()
const {
177 return this->cright_;
// Child selected by the default direction: the left child when default_left()
// is true, the right child otherwise.
179 inline int cdefault()
const {
180 return this->default_left() ? this->cleft() : this->cright();
// Split feature index: the low 31 bits of sindex_.
182 inline unsigned split_index()
const {
183 return sindex_ & ((1U << 31) - 1U);
// Default direction flag: true when the top bit (bit 31) of sindex_ is set.
185 inline bool default_left()
const {
186 return (sindex_ >> 31) != 0;
// Leaf predicate (body not visible in this listing).
188 inline bool is_leaf()
const {
// Leaf output, read from info_.
191 inline bst_float leaf_value()
const {
192 return (this->info_).leaf_value;
// Split threshold, read from info_.
194 inline bst_float split_cond()
const {
195 return (this->info_).split_cond;
// Parent index with the top bit of parent_ masked off.
197 inline int parent()
const {
198 return parent_ & ((1U << 31) - 1);
// True when parent_ equals -1.
200 inline bool is_root()
const {
201 return parent_ == -1;
// XGBTree is granted access to the non-public members.
205 friend class XGBTree;
// Members of Info, returned by leaf_value() and split_cond() above.
// NOTE(review): presumably these two share storage in a union; the enclosing
// declaration is not visible here -- confirm.
207 bst_float leaf_value;
208 bst_float split_cond;
// True when sindex_ holds the largest unsigned value.
215 inline bool is_deleted()
const {
216 return sindex_ == std::numeric_limits<unsigned>::max();
// Flat storage for the nodes of the tree; Load() resizes it to
// param.num_nodes and fills it from the stream.
222 std::vector<Node> nodes;
// Mutable node access by node id (body not visible in this listing).
225 inline Node& operator[](
int nid) {
// Read-only node access by node id (body not visible in this listing).
228 inline const Node& operator[](
int nid)
const {
231 inline void Load(PeekableInputStream* fi) {
232 CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
233 <<
"Ill-formed XGBoost model file: can't read TreeParam";
234 nodes.resize(param.num_nodes);
235 CHECK_NE(param.num_nodes, 0)
236 <<
"Ill-formed XGBoost model file: a tree can't be empty";
237 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes),
sizeof(Node) * nodes.size()),
238 sizeof(Node) * nodes.size())
239 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
240 CONSUME_BYTES(fi, (3 *
sizeof(bst_float) +
sizeof(
int)) * param.num_nodes);
241 if (param.size_leaf_vector != 0) {
243 CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
244 <<
"Ill-formed XGBoost model file";
246 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
249 CHECK_EQ(param.num_roots, 1)
250 <<
"Invalid XGBoost model file: treelite does not support trees " 251 <<
"with multiple roots";
// State gathered while reading the stream: the raw trees, the learner and
// booster parameter records, and the booster / objective names.
256 std::vector<XGBTree> xgb_trees_;
257 LearnerModelParam mparam_;
258 GBTreeModelParam gbm_param_;
259 std::string name_gbm_;
260 std::string name_obj_;
// Wrap the input so that its first bytes can be inspected without being
// consumed.
263 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
// Peek at the first four bytes to look for a format tag.
267 if (fp->PeekRead(&header[0], 4) == 4) {
// A "bs64" tag is rejected outright.
268 CHECK_NE(header,
"bs64")
269 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
// A "binf" tag is consumed so that reading continues right after it.
270 if (header ==
"binf") {
271 CONSUME_BYTES(fp, 4);
// The learner parameters are stored as one raw record.
275 CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
276 <<
"Ill-formed XGBoost model file: corrupted header";
277 LOG(INFO) <<
"Global bias of the model: " << mparam_.base_score;
// Length prefix of the objective name.
282 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
283 <<
"Ill-formed XGBoost model file: corrupted header";
// When the value read does not fit in 32 bits, a further word (`gap`) is read
// and the upper 32 bits of the value are used as the length instead.
// NOTE(review): the declarations of `len` and `gap` are not visible in this
// listing -- confirm their widths.
284 if (len >= std::numeric_limits<unsigned>::max()) {
286 CHECK_EQ(fp->Read(&gap,
sizeof(gap)),
sizeof(gap))
287 <<
"Ill-formed XGBoost model file: corrupted header";
288 len = len >>
static_cast<uint64_t
>(32UL);
// Objective name; it is matched against the known objectives further down.
291 name_obj_.resize(len);
292 CHECK_EQ(fp->Read(&name_obj_[0], len), len)
293 <<
"Ill-formed XGBoost model file: corrupted header";
// Length prefix, then the name of the gradient booster.
299 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
300 <<
"Ill-formed XGBoost model file: corrupted header";
301 name_gbm_.resize(len);
303 CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
304 <<
"Ill-formed XGBoost model file: corrupted header";
// Only the "gbtree" booster is accepted.
309 CHECK_EQ(name_gbm_,
"gbtree")
310 <<
"Invalid XGBoost model file: " 311 <<
"Gradient booster must be gbtree type.";
// The booster parameters are again one raw record.
313 CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
314 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
315 LOG(INFO) <<
"gbm_param_.num_feature = " << gbm_param_.num_feature;
316 LOG(INFO) <<
"gbm_param_.num_output_group = " << gbm_param_.num_output_group;
// Load gbm_param_.num_trees trees, one after another, from the stream.
317 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
318 xgb_trees_.emplace_back();
319 xgb_trees_.back().Load(fp.get());
// Dispatch on the objective name (the branch bodies are not visible in this
// listing).
332 if (name_obj_ ==
"multi:softmax") {
334 }
else if (name_obj_ ==
"multi:softprob") {
336 }
else if (name_obj_ ==
"reg:logistic" || name_obj_ ==
"binary:logistic") {
339 }
else if (name_obj_ ==
"count:poisson" || name_obj_ ==
"reg:gamma" 340 || name_obj_ ==
"reg:tweedie") {
// Convert every parsed tree into a tree appended to model.trees.
347 for (
const auto& xgb_tree : xgb_trees_) {
348 model.
trees.emplace_back();
// Work queue of (id in the XGBoost tree, id in the new tree) pairs.
355 std::queue<std::pair<int, int>> Q;
359 std::tie(old_id, new_id) = Q.front(); Q.pop();
360 const XGBTree::Node& node = xgb_tree[old_id];
// A leaf carries its value over to the new tree.
361 if (node.is_leaf()) {
362 const bst_float leaf_value = node.leaf_value();
363 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
// An internal node becomes a numerical split using operator kLT; both of its
// children are then queued together with their ids in the new tree.
365 const bst_float split_cond = node.split_cond();
367 tree[new_id].set_numerical_split(node.split_index(),
370 treelite::Operator::kLT);
371 Q.push({node.cleft(), tree[new_id].cleft()});
372 Q.push({node.cright(), tree[new_id].cright()});
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::vector< Tree > trees
member trees
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
ModelParam param
extra parameters
in-memory representation of a decision tree
float global_bias
global bias of the model
std::string pred_transform
name of prediction transform function
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees.
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.