10 #include <dmlc/memory_io.h> 20 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi);
// Tags this source file in the DMLC registry under the name `xgboost`.
// NOTE(review): presumably so the linker keeps this translation unit — confirm.
27 DMLC_REGISTRY_FILE_TAG(xgboost);
// Body of the file-based loader (signature elided from this listing; a tooltip
// below names it LoadXGBoostModel(const char* filename)). It opens the file for
// reading and hands the stream to ParseStream(). The unique_ptr owns the
// stream and releases it when the function returns.
30 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
31 return ParseStream(fi.get());
// Body of the in-memory loader (signature elided). It wraps the caller's
// buffer in a fixed-size memory stream and parses it. The const_cast suggests
// MemoryFixedSizeStream takes a non-const void*.
// NOTE(review): this assumes the stream only reads from `buf` — confirm.
35 dmlc::MemoryFixedSizeStream fs(const_cast<void*>(buf), len);
36 return ParseStream(&fs);
/*! \brief Scalar type of XGBoost model values (base score, leaf values, split thresholds). */
using bst_float = float;
// PeekableInputStream wraps a dmlc::Stream so that callers can look at the
// next bytes (PeekRead) without consuming them, then obtain them again with
// Read().
//
// Peeked bytes live in buf_, a circular buffer of MAX_PEEK_WINDOW + 1 slots:
//   - begin_ptr_ is the index of the next unread byte;
//   - end_ptr_ is one past the last buffered byte.
// begin_ptr_ == end_ptr_ means "empty", so at most MAX_PEEK_WINDOW bytes are
// ever buffered.
// NOTE(review): the wrapped stream is passed in as a raw pointer and is
// presumably not owned by this class (the istm_ declaration is elided) — confirm.
48 class PeekableInputStream {
// Largest number of bytes a single PeekRead() may request.
50 const size_t MAX_PEEK_WINDOW = 1024;
52 explicit PeekableInputStream(dmlc::Stream* fi)
53 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
// Read(): copy `size` bytes into `ptr`, consuming them.
// Buffered (previously peeked) bytes are served first; only the shortfall is
// requested from the underlying stream. On that path the returned count adds
// whatever the underlying stream actually delivered, so a short read shows up
// in the return value rather than as a CHECK failure.
55 inline size_t Read(
void* ptr,
size_t size) {
56 const size_t bytes_buffered = BytesBuffered();
57 char* cptr =
static_cast<char*
>(ptr);
// Fast path: the request is satisfied entirely from the buffer.
58 if (size <= bytes_buffered) {
// The requested bytes are contiguous in buf_ (no wrap past the end).
60 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
61 std::memcpy(cptr, &buf_[begin_ptr_], size);
// The requested bytes wrap around: copy the tail of buf_, then its head, and
// move begin_ptr_ to the wrapped position.
64 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
65 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
66 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
67 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
// Slow path: drain the whole buffer, then read the remainder directly from
// the underlying stream into the destination.
71 const size_t bytes_to_read = size - bytes_buffered;
72 if (begin_ptr_ <= end_ptr_) {
73 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
// Buffered bytes wrap around the end of buf_.
75 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
76 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
77 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
// The buffer is now empty.
79 begin_ptr_ = end_ptr_;
81 + istm_->Read(cptr + bytes_buffered, bytes_to_read);
// PeekRead(): copy the next `size` bytes into `ptr` WITHOUT consuming them.
// begin_ptr_ is not advanced, so a later Read() returns the same bytes.
// `size` must not exceed MAX_PEEK_WINDOW (CHECK). If fewer than `size` bytes
// are buffered, the shortfall is read from the underlying stream into buf_;
// a short read there fails a CHECK.
// NOTE(review): if MORE than `size` bytes are already buffered, the fill step
// is skipped. The copy at the end then writes every buffered byte
// (end_ptr_ - begin_ptr_, or the wrapped equivalent) into `ptr`, which can
// overrun a `size`-byte destination. The only call visible in this file
// appears to peek 4 bytes right after construction, where this cannot occur.
85 inline size_t PeekRead(
void* ptr,
size_t size) {
86 CHECK_LE(size, MAX_PEEK_WINDOW)
87 <<
"PeekableInputStream allows peeking up to " 88 << MAX_PEEK_WINDOW <<
" bytes";
89 char* cptr =
static_cast<char*
>(ptr);
90 const size_t bytes_buffered = BytesBuffered();
// Top up the buffer with the missing bytes.
92 if (size > bytes_buffered) {
93 const size_t bytes_to_read = size - bytes_buffered;
// The new bytes fit before the end of buf_.
94 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
95 CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
96 <<
"Failed to peek " << size <<
" bytes";
97 end_ptr_ += bytes_to_read;
// The new bytes wrap: fill up to the end of buf_, then continue at index 0.
99 CHECK_EQ(istm_->Read(&buf_[end_ptr_],
100 MAX_PEEK_WINDOW + 1 - end_ptr_)
101 + istm_->Read(&buf_[0],
102 bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
104 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
105 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
// Copy the buffered bytes out without advancing begin_ptr_.
109 if (begin_ptr_ <= end_ptr_) {
110 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
// Buffered bytes wrap around the end of buf_.
112 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
113 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
// Circular storage for peeked bytes, and its read/write cursors.
121 std::vector<char> buf_;
122 size_t begin_ptr_, end_ptr_;
// Number of bytes currently buffered (peeked but not yet consumed).
124 inline size_t BytesBuffered() {
125 if (begin_ptr_ <= end_ptr_) {
126 return end_ptr_ - begin_ptr_;
// Wrapped case: end_ptr_ has cycled back past index 0.
128 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
/*!
 * \brief Read and discard `size` bytes from an input stream.
 *
 * The bytes are pulled through a small, call-local scratch buffer in
 * fixed-size chunks. The previous version used a function-level `static`
 * std::vector that was grown to `size` on demand. That buffer:
 *   - was shared by every caller, which is a data race when models are loaded
 *     from several threads;
 *   - was never shrunk;
 *   - had its size driven by a length field read from the (untrusted) model
 *     file.
 * Chunked reading avoids all three problems.
 *
 * \tparam T   pointer-like handle (raw or smart pointer) to a stream exposing
 *             `Read(void*, size_t)`, which returns the number of bytes read
 * \param fi   stream to consume bytes from
 * \param size number of bytes to skip
 * \note Fails a CHECK with the same message as before if the stream ends
 *       before `size` bytes have been consumed.
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  constexpr size_t kChunkSize = 512;
  char scratch[kChunkSize];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk = (remaining < kChunkSize) ? remaining : kChunkSize;
    // A short read means the stream ended early: the model is truncated.
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
// Fixed-layout header of an XGBoost binary model.
// ParseStream fills it with a single raw Read of sizeof(LearnerModelParam)
// bytes, so the layout must match XGBoost's byte-for-byte; the static_assert
// below pins the size. Some fields sit on lines elided from this listing
// (ParseStream also reads `num_class`), which is why the visible fields do not
// add up to 136 bytes.
142 struct LearnerModelParam {
// Copied into the treelite model as its global bias.
143 bst_float base_score;
// Copied into the treelite model's num_feature.
144 unsigned num_feature;
146 int contain_extra_attrs;
147 int contain_eval_metrics;
// Version of the writer. When major_version >= 1, ParseStream passes the
// global bias through TransformGlobalBiasToMargin.
148 uint32_t major_version;
149 uint32_t minor_version;
152 static_assert(
sizeof(LearnerModelParam) == 136,
"This is the size defined in XGBoost.");
// Fixed-layout gradient-booster parameters, read raw from the file by
// ParseStream. The fields it relies on (`num_trees`, `num_roots`) are on lines
// elided from this listing.
154 struct GBTreeModelParam {
160 int num_output_group;
161 int size_leaf_vector;
// NOTE(review): the next field belongs to a struct whose header is elided.
// Presumably this is TreeParam, the per-tree header read in XGBTree::Load
// (which checks param.size_leaf_vector) — confirm.
171 int size_leaf_vector;
// NOTE(review): presumably a field of NodeStat, the per-node statistics record
// read in XGBTree::Load — confirm.
178 bst_float base_weight;
// XGBoost tree node as stored on disk. Nodes are read in bulk with a raw Read
// (XGBTree::Load), so the layout is fixed: four int-sized fields plus an Info
// member. Info holds the leaf value / split threshold; it is presumably a
// union, but its declaration is elided. The static_assert checks the total size.
//
// Bit encoding:
//   - sindex_: bit 31 is the "default left" flag; the low 31 bits are the
//     split feature index.
//   - parent_: bit 31 records whether this node is its parent's left child.
186 Node() : sindex_(0) {
188 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
189 "Node: 64 bit align");
// Index of the left child (body elided).
191 inline int cleft()
const {
// Index of the right child.
194 inline int cright()
const {
195 return this->cright_;
// Child chosen by the default-direction flag (presumably the branch taken when
// the split feature is missing, as in XGBoost — confirm).
197 inline int cdefault()
const {
198 return this->default_left() ? this->cleft() : this->cright();
// Split feature index: sindex_ with the default-left flag masked off.
200 inline unsigned split_index()
const {
201 return sindex_ & ((1U << 31) - 1U);
// True when bit 31 of sindex_ is set.
203 inline bool default_left()
const {
204 return (sindex_ >> 31) != 0;
206 inline bool is_leaf()
const {
209 inline bst_float leaf_value()
const {
210 return (this->info_).leaf_value;
211
212 inline bst_float split_cond()
const {
213 return (this->info_).split_cond;
// Parent index with the is-left-child flag masked off.
215 inline int parent()
const {
216 return parent_ & ((1U << 31) - 1);
// A root stores parent_ == -1 (see set_parent(-1)).
218 inline bool is_root()
const {
219 return parent_ == -1;
221 inline void set_leaf(bst_float value) {
222 (this->info_).leaf_value = value;
// Store the split feature index, with default_left packed into bit 31, plus
// the threshold.
226 inline void set_split(
unsigned split_index,
227 bst_float split_cond,
228 bool default_left =
false) {
229 if (default_left) split_index |= (1U << 31);
230 this->sindex_ = split_index;
231 (this->info_).split_cond = split_cond;
// XGBTree writes cleft_/cright_ directly (see AddChilds).
235 friend class XGBTree;
237 bst_float leaf_value;
238 bst_float split_cond;
// An all-ones sindex_ marks a deleted node.
245 inline bool is_deleted()
const {
246 return sindex_ == std::numeric_limits<unsigned>::max();
// Set the parent index, packing is_left_child into bit 31. For pidx == -1 the
// OR leaves the value unchanged, so is_root() still recognizes it.
248 inline void set_parent(
int pidx,
bool is_left_child =
true) {
249 if (is_left_child) pidx |= (1U << 31);
250 this->parent_ = pidx;
// Tree storage: parallel arrays of node records and node statistics, indexed
// by node id.
256 std::vector<Node> nodes;
257 std::vector<NodeStat> stats;
// Append one node slot and grow `nodes` to match. The return statement is
// elided; presumably it returns the new id `nd` — confirm.
259 inline int AllocNode() {
260 int nd = param.num_nodes++;
261 CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
262 <<
"number of nodes in the tree exceed 2^31";
263 nodes.resize(param.num_nodes);
// Node and stats accessors by node id (bodies elided).
269 inline Node& operator[](
int nid) {
273 inline const Node& operator[](
int nid)
const {
277 inline NodeStat& Stat(
int nid) {
281 inline const NodeStat& Stat(
int nid)
const {
// Make node 0 a root leaf with value 0 (the enclosing method header is elided).
287 nodes[0].set_leaf(0.0f);
288 nodes[0].set_parent(-1);
// Allocate two new nodes and attach them as nid's left and right children.
290 inline void AddChilds(
int nid) {
291 int pleft = this->AllocNode();
292 int pright = this->AllocNode();
293 nodes[nid].cleft_ = pleft;
294 nodes[nid].cright_ = pright;
295 nodes[nodes[nid].cleft() ].set_parent(nid,
true);
296 nodes[nodes[nid].cright()].set_parent(nid,
false);
// Deserialize one tree, in this order:
//   1. the TreeParam header;
//   2. num_nodes Node records;
//   3. num_nodes NodeStat records;
//   4. if size_leaf_vector != 0, a length-prefixed leaf-vector payload, which
//      is skipped.
// Any short read fails a CHECK, and trees with more than one root are rejected.
// NOTE(review): `¶m` below looks like `&param` mis-encoded (HTML entity
// &para;) during extraction of this listing; as written it is not valid C++.
298 inline void Load(PeekableInputStream* fi) {
299 CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
300 <<
"Ill-formed XGBoost model file: can't read TreeParam";
// NOTE(review): num_nodes comes straight from the file and is checked only for
// != 0, and only after the resizes. If it is a signed int (AllocNode compares
// it with std::numeric_limits<int>::max()), a negative value would become a
// huge size_t in resize().
301 nodes.resize(param.num_nodes);
302 stats.resize(param.num_nodes);
303 CHECK_NE(param.num_nodes, 0)
304 <<
"Ill-formed XGBoost model file: a tree can't be empty";
305 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes),
sizeof(Node) * nodes.size()),
306 sizeof(Node) * nodes.size())
307 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// NOTE(review): this check reads NodeStat records but reuses the "nodes" error
// message.
308 CHECK_EQ(fi->Read(dmlc::BeginPtr(stats),
sizeof(NodeStat) * stats.size()),
309 sizeof(NodeStat) * stats.size())
310 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// `len` (declared on an elided line) counts bst_float values to skip.
311 if (param.size_leaf_vector != 0) {
313 CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
314 <<
"Ill-formed XGBoost model file";
316 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
319 CHECK_EQ(param.num_roots, 1)
320 <<
"Invalid XGBoost model file: treelite does not support trees " 321 <<
"with multiple roots";
// Parse an XGBoost binary model from `fi` and convert it into a treelite
// Model. Steps:
//   1. detect/skip the optional "binf" magic;
//   2. read the learner header, then the objective and booster names;
//   3. read the GBTree parameters and every tree;
//   4. read the per-tree info block (and DART weights, if any);
//   5. rebuild each tree in treelite form.
// Malformed input fails a CHECK.
325 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
326 std::vector<XGBTree> xgb_trees_;
327 LearnerModelParam mparam_;
328 GBTreeModelParam gbm_param_;
329 std::string name_gbm_;
330 std::string name_obj_;
// Wrap the raw stream so the first 4 bytes can be inspected without
// consuming them.
333 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
// `header` is declared on an elided line.
337 if (fp->PeekRead(&header[0], 4) == 4) {
// Base64-encoded models are rejected outright.
338 CHECK_NE(header,
"bs64")
339 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
// Skip the "binf" magic if present. Otherwise the peeked bytes stay buffered
// and become the start of the learner header.
340 if (header ==
"binf") {
341 CONSUME_BYTES(fp, 4);
// Learner header, read verbatim.
345 CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
346 <<
"Ill-formed XGBoost model file: corrupted header";
// Objective name: a length prefix (`len`, declared on an elided line), then
// the characters.
349 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
350 <<
"Ill-formed XGBoost model file: corrupted header";
352 name_obj_.resize(len);
353 CHECK_EQ(fp->Read(&name_obj_[0], len), len)
354 <<
"Ill-formed XGBoost model file: corrupted header";
// Booster name: a length prefix, then the characters.
360 CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
361 <<
"Ill-formed XGBoost model file: corrupted header";
362 name_gbm_.resize(len);
364 CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
365 <<
"Ill-formed XGBoost model file: corrupted header";
370 CHECK(name_gbm_ ==
"gbtree" || name_gbm_ ==
"dart")
371 <<
"Invalid XGBoost model file: " 372 <<
"Gradient booster must be gbtree or dart type.";
// Booster parameters, then each tree in sequence.
374 CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
375 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
376 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
377 xgb_trees_.emplace_back();
378 xgb_trees_.back().Load(fp.get());
380 CHECK_EQ(gbm_param_.num_roots, 1) <<
"multi-root trees not supported";
// Per-tree info block: one int32 per tree. It is not referenced in the
// visible lines after this read.
// NOTE(review): unlike its neighbours, this CHECK_EQ carries no error message.
382 std::vector<int> tree_info;
383 tree_info.resize(gbm_param_.num_trees);
384 if (gbm_param_.num_trees != 0) {
385 CHECK_EQ(fp->Read(dmlc::BeginPtr(tree_info),
sizeof(int32_t) * tree_info.size()),
386 sizeof(int32_t) * tree_info.size());
// DART models carry per-tree weights; leaf values are scaled by them below.
// NOTE(review): this read has several fragile points:
//   - It reads from the raw `fi`, bypassing `fp`. That works only because the
//     bytes buffered by the initial PeekRead were drained by the earlier
//     fp->Read calls.
//   - Its return value is ignored.
//   - If dmlc::Stream::Read(std::vector<T>*) reads its own length prefix, the
//     resize above is overwritten — confirm.
//   - Nothing visible checks weight_drop.size() against num_trees before it is
//     indexed per tree.
389 std::vector<bst_float> weight_drop;
390 if (name_gbm_ ==
"dart") {
391 weight_drop.resize(gbm_param_.num_trees);
392 if (gbm_param_.num_trees != 0) {
393 fi->Read(&weight_drop);
// Build a treelite model with float thresholds and float leaf outputs.
// `model` (declared on an elided line) presumably refers to the object owned
// by model_ptr — confirm.
398 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<float, float>();
400 model->
num_feature =
static_cast<int>(mparam_.num_feature);
401 model->average_tree_output =
false;
// num_class values below 1 are treated as 1.
402 const int num_class = std::max(mparam_.num_class, 1);
// Multi-class: one group of trees per class (branch condition elided).
405 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
406 model->task_param.grove_per_class =
true;
// Otherwise: a single output (binary classification or regression).
409 model->task_type = treelite::TaskType::kBinaryClfRegr;
410 model->task_param.grove_per_class =
false;
412 model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat;
413 model->task_param.num_class = num_class;
414 model->task_param.leaf_vector_size = 1;
// The prediction transform is chosen from the objective name.
417 treelite::details::xgboost::SetPredTransform(name_obj_, &model->param);
// The global bias comes from base_score. For files with major_version >= 1 it
// is additionally passed through TransformGlobalBiasToMargin.
420 model->param.global_bias =
static_cast<float>(mparam_.base_score);
423 const bool need_transform_to_margin = mparam_.major_version >= 1;
424 if (need_transform_to_margin) {
425 treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
// Convert each XGBoost tree breadth-first. Q holds pairs of (node id in the
// XGBoost tree, node id in the new treelite tree); the queue setup and the
// loop header are elided.
429 for (
const auto& xgb_tree : xgb_trees_) {
430 model->trees.emplace_back();
437 std::queue<std::pair<int, int>> Q;
441 std::tie(old_id, new_id) = Q.front(); Q.pop();
442 const XGBTree::Node& node = xgb_tree[old_id];
443 const NodeStat stat = xgb_tree.Stat(old_id);
// Leaf: copy the value. If DART weights were read, scale it by the weight of
// the tree being built (index model->trees.size() - 1).
444 if (node.is_leaf()) {
445 bst_float leaf_value = node.leaf_value();
447 if (!weight_drop.empty()) {
448 leaf_value *= weight_drop[model->trees.size() - 1];
450 tree.
SetLeaf(new_id, static_cast<float>(leaf_value));
// Internal node: a numerical split on split_cond with operator kLT (the first
// line of that call is elided). Record the gain from the node statistics,
// then enqueue both children paired with their new ids.
452 const bst_float split_cond = node.split_cond();
455 static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
456 tree.
SetGain(new_id, stat.loss_chg);
457 Q.push({node.cleft(), tree.
LeftChild(new_id)});
458 Q.push({node.cright(), tree.
RightChild(new_id)});
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
Load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
void SetGain(int nid, double gain)
set the gain value of the node
int LeftChild(int nid) const
Getters.
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
Helper functions for loading XGBoost models.
int num_feature
Number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node