// Forward declaration: builds a treelite::Model from the input stream `fi`.
// The definition appears later in this listing.
22 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
// Fragment of a file-based loader (its signature is not part of this listing):
// opens `filename` for binary reading and hands the stream to ParseStream.
// NOTE(review): no check that the file opened successfully is visible here — confirm against the full file.
30 std::ifstream fi(filename, std::ios::in | std::ios::binary);
31 return ParseStream(fi);
// Fragment of an in-memory loader (its signature is not part of this listing):
// copies `len` bytes starting at `buf` into a string-backed stream and hands it to ParseStream.
35 std::istringstream fi(std::string(static_cast<const char*>(buf), len));
36 return ParseStream(fi);
// Floating-point type used by the model records below (base score, leaf values, split conditions).
45 typedef float bst_float;
// Wrapper around std::istream that supports look-ahead: PeekRead() fetches upcoming bytes
// into buf_ and copies them out without advancing begin_ptr_, while Read() drains buf_ first
// and then reads from the stream. buf_ is indexed as a ring of MAX_PEEK_WINDOW + 1 slots,
// with begin_ptr_ / end_ptr_ wrapping around its end.
// NOTE(review): several interior lines (closing braces, else branches, some returns, the
// istm_ member declaration) are missing from this listing.
48 class PeekableInputStream {
// Upper bound on the number of bytes a single PeekRead() call may request.
50 const size_t MAX_PEEK_WINDOW = 1024;
// Binds to `fi` and allocates MAX_PEEK_WINDOW + 1 buffer slots; both ring indices start at 0,
// which BytesBuffered() reports as an empty buffer.
52 explicit PeekableInputStream(std::istream& fi)
53 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
// Copies `size` bytes into `ptr`, taking previously peeked bytes from buf_ before touching
// the underlying stream. In the path that reads from the stream, the return value is the
// number of buffered bytes plus the number of bytes the stream actually delivered.
55 inline size_t Read(
void* ptr,
size_t size) {
56 const size_t bytes_buffered = BytesBuffered();
57 char* cptr =
static_cast<char*
>(ptr);
// Case 1: the request can be served entirely from the buffer.
58 if (size <= bytes_buffered) {
// The requested span ends before the end of buf_, so a single copy suffices.
60 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
61 std::memcpy(cptr, &buf_[begin_ptr_], size);
// Otherwise the span wraps: copy the tail of buf_, then the remainder from its start,
// and move begin_ptr_ to the wrapped position.
64 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
65 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
66 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
67 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
// Case 2: more bytes are requested than are buffered. Hand over everything buffered
// (contiguous or wrapped), then read the rest straight from the stream.
71 const size_t bytes_to_read = size - bytes_buffered;
72 if (begin_ptr_ <= end_ptr_) {
73 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
75 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
76 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
77 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
// The buffer is now empty.
79 begin_ptr_ = end_ptr_;
80 istm_.read(cptr + bytes_buffered, bytes_to_read);
81 return bytes_buffered + istm_.gcount();
// Copies upcoming bytes into `ptr` while keeping them in buf_ for a later Read().
// `size` must not exceed MAX_PEEK_WINDOW; the check below fails otherwise.
85 inline size_t PeekRead(
void* ptr,
size_t size) {
86 TREELITE_CHECK_LE(size, MAX_PEEK_WINDOW)
87 <<
"PeekableInputStream allows peeking up to " 88 << MAX_PEEK_WINDOW <<
" bytes";
89 char* cptr =
static_cast<char*
>(ptr);
90 const size_t bytes_buffered = BytesBuffered();
// Top up the buffer from the stream when fewer than `size` bytes are buffered.
92 if (size > bytes_buffered) {
93 const size_t bytes_to_read = size - bytes_buffered;
// The new bytes fit before the end of buf_: one stream read, which must deliver all of them.
94 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
95 istm_.read(&buf_[end_ptr_], bytes_to_read);
96 TREELITE_CHECK_EQ(istm_.gcount(), bytes_to_read)
97 <<
"Failed to peek " << size <<
" bytes";
98 end_ptr_ += bytes_to_read;
// Otherwise the new bytes wrap: fill up to the end of buf_, continue at its start, and
// require that both reads together delivered bytes_to_read bytes.
100 istm_.read(&buf_[end_ptr_], MAX_PEEK_WINDOW + 1 - end_ptr_);
101 size_t first_read = istm_.gcount();
102 istm_.read(&buf_[0], bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1);
103 size_t second_read = istm_.gcount();
104 TREELITE_CHECK_EQ(first_read + second_read, bytes_to_read)
105 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
106 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
// Copy the buffered bytes to the caller without moving begin_ptr_.
// NOTE(review): these copies transfer every buffered byte (end_ptr_ - begin_ptr_, or the
// wrapped equivalent), not `size` bytes. If more than `size` bytes are already buffered
// (e.g. a small peek after a larger one), this writes past the `size` bytes the caller
// asked for — confirm and cap the copy at `size`.
110 if (begin_ptr_ <= end_ptr_) {
111 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
113 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
114 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
// Ring storage for peeked bytes, and the indices of its first buffered byte / one past its last.
122 std::vector<char> buf_;
123 size_t begin_ptr_, end_ptr_;
// Number of bytes currently held in buf_, accounting for wrap-around of the indices.
125 inline size_t BytesBuffered() {
126 if (begin_ptr_ <= end_ptr_) {
127 return end_ptr_ - begin_ptr_;
129 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
// Discards `size` bytes from `fi` by reading them into a scratch buffer. `fi` is any
// pointer-like object exposing `Read(void*, size_t)`; the check fails when fewer than
// `size` bytes could be read.
// NOTE(review): the scratch buffer is a function-local static that grows in place, so calls
// sharing an instantiation also share the buffer — confirm callers are single-threaded.
134 template <
typename T>
135 inline void CONSUME_BYTES(
const T& fi,
size_t size) {
136 static std::vector<char> dummy(500);
// Grow the scratch buffer when the request is larger than anything seen so far.
137 if (size > dummy.size()) dummy.resize(size);
138 TREELITE_CHECK_EQ(fi->Read(&dummy[0], size), size)
139 <<
"Ill-formed XGBoost model format: cannot read " << size
140 <<
" bytes from the file";
// Learner-level header record. ParseStream fills it with a raw read of sizeof(mparam_)
// bytes, so its in-memory layout has to match the file layout byte for byte.
// NOTE(review): some fields (e.g. num_class, which ParseStream reads) and any padding
// members are missing from this listing.
143 struct LearnerModelParam {
144 bst_float base_score;
145 unsigned num_feature;
147 int contain_extra_attrs;
148 int contain_eval_metrics;
// Version fields; ParseStream uses major_version >= 1 to decide whether to transform the global bias.
149 uint32_t major_version;
150 uint32_t minor_version;
// Guard against accidental layout changes: the record must stay 136 bytes.
153 static_assert(
sizeof(LearnerModelParam) == 136,
"This is the size defined in XGBoost.");
// Booster-level parameter record. ParseStream fills it with a raw read of
// sizeof(gbm_param_) bytes and then uses its num_trees and num_roots fields.
// NOTE(review): only some of the fields appear in this listing.
155 struct GBTreeModelParam {
161 int num_output_group;
162 int size_leaf_vector;
// NOTE(review): the two fields below belong to later declarations whose headers are not part
// of this listing — presumably the per-tree parameter record and the per-node statistics
// record; confirm against the full file.
172 int size_leaf_vector;
179 bst_float base_weight;
// Tree node record. XGBTree::Load fills a vector of these with a raw read from the model
// file. The enclosing declarations, and several bodies, are missing from this listing.
// sindex_ packs the split feature index (low 31 bits) and the default direction (top bit);
// parent_ packs the parent index (low 31 bits) and a left-child flag (top bit).
// Default constructor: clears the packed split word.
187 Node() : sindex_(0) {
// Layout guard: a node must be exactly four ints plus the Info payload.
189 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
190 "Node: 64 bit align");
// Left child accessor (body not part of this listing).
192 inline int cleft()
const {
// Index of the right child.
195 inline int cright()
const {
196 return this->cright_;
// Child taken by default: the left child when default_left() is set, otherwise the right child.
198 inline int cdefault()
const {
199 return this->default_left() ? this->cleft() : this->cright();
// Feature index of the split: the low 31 bits of sindex_.
201 inline unsigned split_index()
const {
202 return sindex_ & ((1U << 31) - 1U);
// True when the top bit of sindex_ is set.
204 inline bool default_left()
const {
205 return (sindex_ >> 31) != 0;
// Leaf predicate (body not part of this listing).
207 inline bool is_leaf()
const {
// Value stored in the Info payload's leaf_value member.
210 inline bst_float leaf_value()
const {
211 return (this->info_).leaf_value;
// Value stored in the Info payload's split_cond member.
213 inline bst_float split_cond()
const {
214 return (this->info_).split_cond;
// Parent index: the low 31 bits of parent_ (the top bit carries the left-child flag).
216 inline int parent()
const {
217 return parent_ & ((1U << 31) - 1);
// A node is the root when its parent_ word is -1.
219 inline bool is_root()
const {
220 return parent_ == -1;
// Stores `value` as the leaf value (any further statements are not part of this listing).
222 inline void set_leaf(bst_float value) {
223 (this->info_).leaf_value = value;
// Records a split on feature `split_index` with condition `split_cond`; `default_left` is
// folded into the top bit of the stored index.
227 inline void set_split(
unsigned split_index,
228 bst_float split_cond,
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
// XGBTree writes the child and parent members directly (see AddChilds).
236 friend class XGBTree;
// Members of the Info payload. NOTE(review): presumably a union of the two values, since the
// static_assert above adds a single sizeof(Info) — confirm against the full file.
238 bst_float leaf_value;
239 bst_float split_cond;
// A node counts as deleted when sindex_ holds the maximum unsigned value.
246 inline bool is_deleted()
const {
247 return sindex_ == std::numeric_limits<unsigned>::max();
// Stores the parent index, setting the top bit when this node is the parent's left child.
249 inline void set_parent(
int pidx,
bool is_left_child =
true) {
250 if (is_left_child) pidx |= (1U << 31);
251 this->parent_ = pidx;
// Storage and operations of XGBTree (the class header and several bodies are missing from
// this listing). `nodes` and `stats` are sized to param.num_nodes when a tree is loaded.
257 std::vector<Node> nodes;
258 std::vector<NodeStat> stats;
// Reserves one more node: takes the current node count as the new id, increments the count,
// checks it is still below the int maximum, and grows `nodes` to match.
260 inline int AllocNode() {
261 int nd = param.num_nodes++;
262 TREELITE_CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
263 <<
"number of nodes in the tree exceed 2^31";
264 nodes.resize(param.num_nodes);
// Node accessors by id, mutable and const (bodies not part of this listing).
270 inline Node& operator[](
int nid) {
274 inline const Node& operator[](
int nid)
const {
// Per-node statistics accessors by id, mutable and const (bodies not part of this listing).
278 inline NodeStat& Stat(
int nid) {
282 inline const NodeStat& Stat(
int nid)
const {
// Fragment of an initialiser whose header is not part of this listing: node 0 becomes a
// leaf with value 0 and its parent is set to -1 (the value is_root() tests for).
288 nodes[0].set_leaf(0.0f);
289 nodes[0].set_parent(-1);
// Allocates two new nodes, links them as the left and right children of `nid`, and points
// each child back at `nid` with the matching left/right flag.
291 inline void AddChilds(
int nid) {
292 int pleft = this->AllocNode();
293 int pright = this->AllocNode();
294 nodes[nid].cleft_ = pleft;
295 nodes[nid].cright_ = pright;
296 nodes[nodes[nid].cleft() ].set_parent(nid,
true);
297 nodes[nodes[nid].cright()].set_parent(nid,
false);
// Reads one tree from `fi`: the tree parameter record, then num_nodes Node records, then
// num_nodes NodeStat records, each as a raw block. Every read is checked for the exact
// byte count, and the tree must have at least one node and exactly one root.
299 inline void Load(PeekableInputStream* fi) {
// NOTE(review): `¶m` below looks like an extraction artifact of `&param` (HTML entity
// decoding) — confirm against the full file.
300 TREELITE_CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
301 <<
"Ill-formed XGBoost model file: can't read TreeParam";
302 TREELITE_CHECK_GT(param.num_nodes, 0)
303 <<
"Ill-formed XGBoost model file: a tree can't be empty";
304 nodes.resize(param.num_nodes);
305 stats.resize(param.num_nodes);
306 TREELITE_CHECK_EQ(fi->Read(nodes.data(),
sizeof(Node) * nodes.size()),
307 sizeof(Node) * nodes.size())
308 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// NOTE(review): this check covers the statistics block but reuses the "nodes" message above.
309 TREELITE_CHECK_EQ(fi->Read(stats.data(),
sizeof(NodeStat) * stats.size()),
310 sizeof(NodeStat) * stats.size())
311 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// When a leaf vector is present, read its element count `len` and skip len * sizeof(bst_float) bytes.
312 if (param.size_leaf_vector != 0) {
314 TREELITE_CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
315 <<
"Ill-formed XGBoost model file";
317 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
320 TREELITE_CHECK_EQ(param.num_roots, 1)
321 <<
"Invalid XGBoost model file: treelite does not support trees " 322 <<
"with multiple roots";
// Parses an XGBoost binary model from `fi` and converts it into a treelite::Model.
// Phases visible in this listing: header detection, learner parameters, objective and
// booster names, booster parameters, per-tree loading, tree_info, optional dart weights,
// model metadata, and a breadth-first copy of each tree into the Treelite representation.
// NOTE(review): the declarations of `header`, `len`, `sz`, `model`, `tree`, `old_id` and
// `new_id`, plus various braces and conditions, are missing from this listing.
326 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
327 std::vector<XGBTree> xgb_trees_;
328 LearnerModelParam mparam_;
329 GBTreeModelParam gbm_param_;
330 std::string name_gbm_;
331 std::string name_obj_;
// All structured reads go through a peekable wrapper around `fi`.
334 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
// Peek at the first four bytes: "bs64" is rejected, and "binf" is consumed as a magic
// prefix. Any other value stays buffered and is read as part of the learner parameters.
338 if (fp->PeekRead(&header[0], 4) == 4) {
339 TREELITE_CHECK_NE(header,
"bs64")
340 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
341 if (header ==
"binf") {
342 CONSUME_BYTES(fp, 4);
// Learner parameters, read as one raw block.
346 TREELITE_CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
347 <<
"Ill-formed XGBoost model file: corrupted header";
// Objective name: a length field followed by that many characters.
// NOTE(review): `len` comes from the file and is passed to resize() with no bound check
// visible here — confirm whether a sanity limit exists in the missing lines.
350 TREELITE_CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
351 <<
"Ill-formed XGBoost model file: corrupted header";
353 name_obj_.resize(len);
354 TREELITE_CHECK_EQ(fp->Read(&name_obj_[0], len), len)
355 <<
"Ill-formed XGBoost model file: corrupted header";
// Booster name, stored the same way.
361 TREELITE_CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
362 <<
"Ill-formed XGBoost model file: corrupted header";
363 name_gbm_.resize(len);
365 TREELITE_CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
366 <<
"Ill-formed XGBoost model file: corrupted header";
// Only the gbtree and dart boosters are accepted.
371 TREELITE_CHECK(name_gbm_ ==
"gbtree" || name_gbm_ ==
"dart")
372 <<
"Invalid XGBoost model file: " 373 <<
"Gradient booster must be gbtree or dart type.";
// Booster parameters, read as one raw block; the tree count must be non-negative.
375 TREELITE_CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
376 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
377 TREELITE_CHECK_GE(gbm_param_.num_trees, 0)
378 <<
"Invalid XGBoost model file: num_trees must be 0 or greater";
// Load num_trees trees in file order.
379 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
380 xgb_trees_.emplace_back();
381 xgb_trees_.back().Load(fp.get());
383 TREELITE_CHECK_EQ(gbm_param_.num_roots, 1) <<
"multi-root trees not supported";
// tree_info: one int per tree, read as a raw block when there is at least one tree.
385 std::vector<int> tree_info;
386 tree_info.resize(gbm_param_.num_trees);
387 if (gbm_param_.num_trees > 0) {
388 TREELITE_CHECK_EQ(fp->Read(tree_info.data(),
sizeof(int32_t) * tree_info.size()),
389 sizeof(int32_t) * tree_info.size());
// dart only: a 64-bit count that must equal num_trees, followed by one weight per tree.
// NOTE(review): these reads go straight to `fi` instead of through `fp`, and no byte-count
// check is visible. Any bytes still buffered in `fp` would be skipped, and a short read
// would go unnoticed — confirm against the full file.
392 std::vector<bst_float> weight_drop;
393 if (name_gbm_ ==
"dart") {
394 weight_drop.resize(gbm_param_.num_trees);
396 fi.read(reinterpret_cast<char*>(&sz),
sizeof(uint64_t));
397 TREELITE_CHECK_EQ(sz, gbm_param_.num_trees);
398 if (gbm_param_.num_trees != 0) {
399 for (uint64_t i = 0; i < sz; ++i) {
400 fi.read(reinterpret_cast<char*>(&weight_drop[i]),
sizeof(bst_float));
// Create the output model with float thresholds and float leaf outputs, then fill in its metadata.
406 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<float, float>();
408 model->
num_feature =
static_cast<int>(mparam_.num_feature);
409 model->average_tree_output =
false;
// Use at least one class.
410 const int num_class = std::max(mparam_.num_class, 1);
// Two alternative task configurations follow; the condition selecting between them is
// missing from this listing (presumably based on num_class — confirm).
413 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
414 model->task_param.grove_per_class =
true;
417 model->task_type = treelite::TaskType::kBinaryClfRegr;
418 model->task_param.grove_per_class =
false;
420 model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
421 model->task_param.num_class = num_class;
422 model->task_param.leaf_vector_size = 1;
// Choose the prediction transform from the objective name.
425 treelite::details::xgboost::SetPredTransform(name_obj_, &model->param);
// The global bias starts as base_score; for models with major_version >= 1 it is passed
// through TransformGlobalBiasToMargin.
428 model->param.global_bias =
static_cast<float>(mparam_.base_score);
431 const bool need_transform_to_margin = mparam_.major_version >= 1;
432 if (need_transform_to_margin) {
433 treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
// Convert each loaded tree: append a new Treelite tree and copy nodes breadth-first,
// pairing every source node id with its id in the new tree.
437 for (
const auto& xgb_tree : xgb_trees_) {
438 model->trees.emplace_back();
445 std::queue<std::pair<int, int>> Q;
449 std::tie(old_id, new_id) = Q.front(); Q.pop();
450 const XGBTree::Node& node = xgb_tree[old_id];
451 const NodeStat stat = xgb_tree.Stat(old_id);
452 if (node.is_leaf()) {
453 bst_float leaf_value = node.leaf_value();
// When dart weights were read, scale the leaf by the weight of the tree being converted
// (the most recently appended tree).
455 if (!weight_drop.empty()) {
456 leaf_value *= weight_drop[model->trees.size() - 1];
458 tree.
SetLeaf(new_id, static_cast<float>(leaf_value));
// Internal node: the split is recorded with the `<` operator (the start of that call is
// missing from this listing), the gain comes from stat.loss_chg, and both children are queued.
460 const bst_float split_cond = node.split_cond();
463 static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
464 tree.
SetGain(new_id, stat.loss_chg);
465 Q.push({node.cleft(), tree.
LeftChild(new_id)});
466 Q.push({node.cright(), tree.
RightChild(new_id)});
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
logging facility for Treelite
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
Load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
void SetGain(int nid, double gain)
set the gain value of the node
int LeftChild(int nid) const
Getters: index of the node's left child
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
Helper functions for loading XGBoost models.
int num_feature
Number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters: create a numerical split at the node
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node