// Forward declaration of the stream parser; its definition appears later in this listing.
22 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
// NOTE(review): the enclosing function header is not visible in this listing —
// presumably the file-path overload of LoadXGBoostModel; confirm.
// Opens `filename` as a binary input stream and hands it straight to ParseStream.
// No open-failure check is made between the two statements.
30 std::ifstream fi(filename, std::ios::in | std::ios::binary);
31 return ParseStream(fi);
// NOTE(review): header not visible — presumably the in-memory (buf, len) overload; confirm.
// Copies `len` bytes starting at `buf` into a std::string and parses that copy
// through an istringstream.
35 std::istringstream fi(std::string(static_cast<const char*>(buf), len));
36 return ParseStream(fi);
// Floating-point type used by the structures below that are read from the model stream.
45 typedef float bst_float;
// Input-stream wrapper that lets the caller look ahead at upcoming bytes without
// consuming them. Peeked bytes are kept in a circular buffer of MAX_PEEK_WINDOW + 1
// chars; begin_ptr_ / end_ptr_ are indices into that buffer and wrap around its end.
// NOTE(review): several lines of this class (else-branches, returns, closing braces,
// the istm_ member declaration) are missing from this listing.
48 class PeekableInputStream {
// Largest number of bytes that PeekRead accepts in one call.
50 const size_t MAX_PEEK_WINDOW = 1024;
// Wraps `fi`; the buffer starts empty (begin_ptr_ == end_ptr_ == 0).
52 explicit PeekableInputStream(std::istream& fi)
53 : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
// Consume `size` bytes into `ptr`, draining previously peeked bytes first and
// reading any remainder from the underlying stream.
55 inline size_t Read(
void* ptr,
size_t size) {
56 const size_t bytes_buffered = BytesBuffered();
57 char* cptr =
static_cast<char*
>(ptr);
// Case 1: the request can be served entirely from the buffer.
58 if (size <= bytes_buffered) {
// The requested span does not run past the end of the buffer: single copy.
60 if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
61 std::memcpy(cptr, &buf_[begin_ptr_], size);
// The span wraps: copy the tail of the buffer, then the remainder from index 0,
// and move begin_ptr_ to the wrapped position.
64 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
65 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
66 size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
67 begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
// Case 2: more bytes requested than buffered. Copy out everything buffered
// (in one or two pieces depending on wrap-around), mark the buffer empty,
// then read the rest directly from the stream into the caller's memory.
71 const size_t bytes_to_read = size - bytes_buffered;
72 if (begin_ptr_ <= end_ptr_) {
73 std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
75 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
76 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
77 bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
79 begin_ptr_ = end_ptr_;
80 istm_.read(cptr + bytes_buffered, bytes_to_read);
// Returns the bytes actually delivered; may be less than `size` on a short read.
81 return bytes_buffered + istm_.gcount();
// Copy upcoming bytes into `ptr` without advancing begin_ptr_. Fails the check
// below if `size` exceeds MAX_PEEK_WINDOW.
85 inline size_t PeekRead(
void* ptr,
size_t size) {
86 TREELITE_CHECK_LE(size, MAX_PEEK_WINDOW)
87 <<
"PeekableInputStream allows peeking up to " 88 << MAX_PEEK_WINDOW <<
" bytes";
89 char* cptr =
static_cast<char*
>(ptr);
90 const size_t bytes_buffered = BytesBuffered();
// Top up the buffer from the stream when it holds fewer than `size` bytes.
92 if (size > bytes_buffered) {
93 const size_t bytes_to_read = size - bytes_buffered;
// New bytes fit before the end of the buffer: one stream read.
94 if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
95 istm_.read(&buf_[end_ptr_], bytes_to_read);
96 TREELITE_CHECK_EQ(istm_.gcount(), bytes_to_read)
97 <<
"Failed to peek " << size <<
" bytes";
98 end_ptr_ += bytes_to_read;
// New bytes wrap: fill up to the end of the buffer, continue from index 0,
// and require that the two reads together delivered everything.
100 istm_.read(&buf_[end_ptr_], MAX_PEEK_WINDOW + 1 - end_ptr_);
101 size_t first_read = istm_.gcount();
102 istm_.read(&buf_[0], bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1);
103 size_t second_read = istm_.gcount();
104 TREELITE_CHECK_EQ(first_read + second_read, bytes_to_read)
105 <<
"Ill-formed XGBoost model: Failed to peek " << size <<
" bytes";
106 end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
// Copy the buffered bytes to the caller, leaving begin_ptr_ untouched.
// NOTE(review): the copy length is everything currently buffered, not `size`;
// if an earlier peek buffered more than `size` bytes this writes past the
// `size` bytes the caller asked for — confirm callers never peek twice.
110 if (begin_ptr_ <= end_ptr_) {
111 std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
113 std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
114 std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
// Circular storage for peeked bytes, and the indices delimiting the live span.
122 std::vector<char> buf_;
123 size_t begin_ptr_, end_ptr_;
// Number of peeked-but-unconsumed bytes, accounting for wrap-around.
125 inline size_t BytesBuffered() {
126 if (begin_ptr_ <= end_ptr_) {
127 return end_ptr_ - begin_ptr_;
129 return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
/*!
 * \brief Read and discard `size` bytes from a stream-like object.
 *
 * The bytes are pulled through a fixed-size stack buffer in bounded chunks. This
 * avoids a function-local static scratch vector (shared, unsynchronised state that
 * only ever grows) and avoids allocating an amount of memory dictated by a length
 * field taken from a possibly corrupted model file.
 *
 * \tparam T pointer-like type; `fi->Read(void*, size_t)` must return the number of
 *           bytes actually read
 * \param fi   stream to consume from
 * \param size number of bytes to discard; 0 is a no-op
 *
 * Fails the TREELITE_CHECK_EQ below if the stream ends before `size` bytes have
 * been consumed.
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[1024];
  size_t remaining = size;
  while (remaining > 0) {
    // Never request more than the scratch buffer can hold.
    const size_t chunk = (remaining < sizeof(scratch)) ? remaining : sizeof(scratch);
    TREELITE_CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
// Fixed-layout header record; ParseStream fills it by reading sizeof(LearnerModelParam)
// raw bytes from the model stream, so the field order and total size must match the
// producer exactly (enforced by the static_assert below).
// NOTE(review): some fields (e.g. the num_class member used by ParseStream) and the
// closing brace are missing from this listing.
143 struct LearnerModelParam {
// Global bias; ParseStream copies it into model->param.global_bias.
144 bst_float base_score;
// Feature count; ParseStream copies it into model->num_feature.
145 unsigned num_feature;
147 int contain_extra_attrs;
148 int contain_eval_metrics;
// Version of the producer; ParseStream branches on these.
149 uint32_t major_version;
150 uint32_t minor_version;
// Guard against accidental layout changes: the raw read relies on this exact size.
153 static_assert(
sizeof(LearnerModelParam) == 136,
"This is the size defined in XGBoost.");
// Booster-level parameter record; ParseStream fills it with a raw read of
// sizeof(GBTreeModelParam) bytes.
// NOTE(review): most fields (including num_trees and num_roots, both used by
// ParseStream) and the closing brace are missing from this listing.
155 struct GBTreeModelParam {
161 int num_output_group;
162 int size_leaf_vector;
// NOTE(review): the next field belongs to a different struct whose header is not
// visible — presumably the per-tree TreeParam that XGBTree::Load reads; confirm.
172 int size_leaf_vector;
// NOTE(review): stray field of yet another struct whose header is not visible —
// presumably NodeStat; confirm.
179 bst_float base_weight;
// Members of the tree-node record (referred to as XGBTree::Node in ParseStream).
// NOTE(review): the enclosing class headers, several method bodies and all closing
// braces are missing from this listing; only what is visible is described.
// Default constructor: zero-initialises the packed split field.
187 Node() : sindex_(0) {
// Layout guard: four ints plus the Info payload, with no padding.
189 static_assert(
sizeof(Node) == 4 *
sizeof(
int) +
sizeof(Info),
190 "Node: 64 bit align");
// Index of the left child (body not visible here).
192 inline int cleft()
const {
// Index of the right child.
195 inline int cright()
const {
196 return this->cright_;
// Child to follow by default: left if the default-left flag is set, else right.
198 inline int cdefault()
const {
199 return this->default_left() ? this->cleft() : this->cright();
// Feature index used by the split: the low 31 bits of sindex_.
201 inline unsigned split_index()
const {
202 return sindex_ & ((1U << 31) - 1U);
// Default-left flag: the top bit of sindex_.
204 inline bool default_left()
const {
205 return (sindex_ >> 31) != 0;
// Whether this node is a leaf (body not visible here).
207 inline bool is_leaf()
const {
// Leaf output stored in the info_ payload.
210 inline bst_float leaf_value()
const {
211 return (this->info_).leaf_value;
// Split threshold stored in the info_ payload.
213 inline bst_float split_cond()
const {
214 return (this->info_).split_cond;
// Parent index: the low 31 bits of parent_ (bit 31 carries the is-left-child flag,
// see set_parent).
216 inline int parent()
const {
217 return parent_ & ((1U << 31) - 1);
// A node is the root when parent_ holds -1.
219 inline bool is_root()
const {
220 return parent_ == -1;
// Store a leaf output value (any further statements of this method are not visible).
222 inline void set_leaf(bst_float value) {
223 (this->info_).leaf_value = value;
// Record a split: packs the default-left flag into bit 31 of the feature index
// and stores the threshold.
227 inline void set_split(
unsigned split_index,
228 bst_float split_cond,
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
// XGBTree writes cleft_ / cright_ directly (see AddChilds).
236 friend class XGBTree;
// Payload fields accessed through info_ above.
// NOTE(review): presumably members of the Info type (likely a union, since a node
// uses one or the other) — the declaration itself is not visible; confirm.
238 bst_float leaf_value;
239 bst_float split_cond;
// A node counts as deleted when sindex_ holds the maximum unsigned value.
246 inline bool is_deleted()
const {
247 return sindex_ == std::numeric_limits<unsigned>::max();
// Store the parent index, setting bit 31 when this node is the left child.
249 inline void set_parent(
int pidx,
bool is_left_child =
true) {
250 if (is_left_child) pidx |= (1U << 31);
251 this->parent_ = pidx;
// Storage and structural helpers of XGBTree.
// NOTE(review): the class header, several method bodies and all closing braces are
// missing from this listing; only what is visible is described.
// Per-node records and their statistics.
257 std::vector<Node> nodes;
258 std::vector<NodeStat> stats;
// Allocate one more node: bumps param.num_nodes, checks it stays below INT_MAX,
// and grows `nodes` to match. `nd` holds the pre-increment count, i.e. the index
// of the new node (the return statement is not visible here).
260 inline int AllocNode() {
261 int nd = param.num_nodes++;
262 TREELITE_CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
263 <<
"number of nodes in the tree exceed 2^31";
264 nodes.resize(param.num_nodes);
// Node accessors by id (bodies not visible here).
270 inline Node& operator[](
int nid) {
274 inline const Node& operator[](
int nid)
const {
// Node-statistics accessors by id (bodies not visible here).
278 inline NodeStat& Stat(
int nid) {
282 inline const NodeStat& Stat(
int nid)
const {
// NOTE(review): the header of the enclosing method is not visible — presumably
// Init(). Makes node 0 a leaf with value 0 and gives it parent -1 (the root marker
// tested by Node::is_root).
288 nodes[0].set_leaf(0.0f);
289 nodes[0].set_parent(-1);
// Give node `nid` two freshly allocated children and link them back to it,
// tagging each child as left or right.
291 inline void AddChilds(
int nid) {
292 int pleft = this->AllocNode();
293 int pright = this->AllocNode();
294 nodes[nid].cleft_ = pleft;
295 nodes[nid].cright_ = pright;
296 nodes[nodes[nid].cleft() ].set_parent(nid,
true);
297 nodes[nodes[nid].cright()].set_parent(nid,
false);
// Deserialise one tree from `fi`: the TreeParam record, then param.num_nodes Node
// records, then the same number of NodeStat records, all as raw bytes. Every read is
// checked for the exact byte count; a tree with no nodes or with more than one root
// is rejected.
// NOTE(review): the declaration of `len` and the closing braces are missing from this
// listing.
299 inline void Load(PeekableInputStream* fi) {
// NOTE(review): `¶m` below is almost certainly a mis-decoded `&param` (the HTML
// entity `&para` swallowed the ampersand); as written this line will not compile.
300 TREELITE_CHECK_EQ(fi->Read(¶m,
sizeof(TreeParam)),
sizeof(TreeParam))
301 <<
"Ill-formed XGBoost model file: can't read TreeParam";
302 TREELITE_CHECK_GT(param.num_nodes, 0)
303 <<
"Ill-formed XGBoost model file: a tree can't be empty";
// Size both arrays from the header before the bulk reads.
304 nodes.resize(param.num_nodes);
305 stats.resize(param.num_nodes);
306 TREELITE_CHECK_EQ(fi->Read(nodes.data(),
sizeof(Node) * nodes.size()),
307 sizeof(Node) * nodes.size())
308 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
309 TREELITE_CHECK_EQ(fi->Read(stats.data(),
sizeof(NodeStat) * stats.size()),
310 sizeof(NodeStat) * stats.size())
311 <<
"Ill-formed XGBoost model file: cannot read specified number of nodes";
// When leaf vectors are present, a length-prefixed array of bst_float follows;
// it is read past and discarded.
312 if (param.size_leaf_vector != 0) {
314 TREELITE_CHECK_EQ(fi->Read(&len,
sizeof(len)),
sizeof(len))
315 <<
"Ill-formed XGBoost model file";
317 CONSUME_BYTES(fi,
sizeof(bst_float) * len);
// Only single-root trees are supported.
320 TREELITE_CHECK_EQ(param.num_roots, 1)
321 <<
"Invalid XGBoost model file: treelite does not support trees " 322 <<
"with multiple roots";
// Parse an XGBoost binary model from `fi` and convert it into a treelite::Model.
// Visible steps: (1) read the header and booster parameters, (2) load every tree
// into XGBTree objects, (3) create a float/float Treelite model and fill in its
// metadata, (4) copy each tree with a breadth-first traversal, (5) reorder the trees
// when tree_info indicates more than one parallel tree.
// NOTE(review): many lines (local declarations, else-branches, closing braces, the
// return statement) are missing from this listing.
326 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
327 std::vector<XGBTree> xgb_trees_;
328 LearnerModelParam mparam_;
329 GBTreeModelParam gbm_param_;
330 std::string name_gbm_;
331 std::string name_obj_;
// All reads below go through the peekable wrapper unless noted otherwise.
334 std::unique_ptr<PeekableInputStream> fp(
new PeekableInputStream(fi));
// Peek at the first four bytes: "bs64" is rejected outright, and a "binf" tag is
// consumed so that the learner parameters follow immediately.
338 if (fp->PeekRead(&header[0], 4) == 4) {
339 TREELITE_CHECK_NE(header,
"bs64")
340 <<
"Ill-formed XGBoost model file: Base64 format no longer supported";
341 if (header ==
"binf") {
342 CONSUME_BYTES(fp, 4);
// Learner parameters: raw read of the fixed-size record.
346 TREELITE_CHECK_EQ(fp->Read(&mparam_,
sizeof(mparam_)),
sizeof(mparam_))
347 <<
"Ill-formed XGBoost model file: corrupted header";
// Length-prefixed objective name.
// NOTE(review): `len` comes straight from the file and is used as a resize amount
// without an upper bound; a corrupted file can request a very large allocation.
350 TREELITE_CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
351 <<
"Ill-formed XGBoost model file: corrupted header";
353 name_obj_.resize(len);
354 TREELITE_CHECK_EQ(fp->Read(&name_obj_[0], len), len)
355 <<
"Ill-formed XGBoost model file: corrupted header";
// Length-prefixed booster name (same unbounded-length caveat as above).
361 TREELITE_CHECK_EQ(fp->Read(&len,
sizeof(len)),
sizeof(len))
362 <<
"Ill-formed XGBoost model file: corrupted header";
363 name_gbm_.resize(len);
365 TREELITE_CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
366 <<
"Ill-formed XGBoost model file: corrupted header";
// Only the two tree-based boosters are accepted.
371 TREELITE_CHECK(name_gbm_ ==
"gbtree" || name_gbm_ ==
"dart")
372 <<
"Invalid XGBoost model file: " 373 <<
"Gradient booster must be gbtree or dart type.";
// Booster parameters, then num_trees serialized trees back to back.
375 TREELITE_CHECK_EQ(fp->Read(&gbm_param_,
sizeof(gbm_param_)),
sizeof(gbm_param_))
376 <<
"Invalid XGBoost model file: corrupted GBTree parameters";
377 TREELITE_CHECK_GE(gbm_param_.num_trees, 0)
378 <<
"Invalid XGBoost model file: num_trees must be 0 or greater";
379 for (
int i = 0; i < gbm_param_.num_trees; ++i) {
380 xgb_trees_.emplace_back();
381 xgb_trees_.back().Load(fp.get());
// Version test for producers older than 1.6, followed by the booster-level
// single-root check.
// NOTE(review): a line is missing between the two, so whether the check sits
// inside this `if` cannot be confirmed from this listing.
383 if (mparam_.major_version < 1 || (mparam_.major_version == 1 && mparam_.minor_version < 6)) {
385 TREELITE_CHECK_EQ(gbm_param_.num_roots, 1) <<
"multi-root trees not supported";
// One int per tree follows the trees; it is consumed by the loop near the end
// that computes num_parallel_tree.
387 std::vector<int> tree_info;
388 tree_info.resize(gbm_param_.num_trees);
389 if (gbm_param_.num_trees > 0) {
390 TREELITE_CHECK_EQ(fp->Read(tree_info.data(),
sizeof(int32_t) * tree_info.size()),
391 sizeof(int32_t) * tree_info.size());
// DART only: a 64-bit count followed by one bst_float weight per tree.
// NOTE(review): these reads use `fi` directly rather than `fp`, and no check on the
// outcome of either read is visible; the count check below also compares an
// unsigned 64-bit value with a signed int. Reading from `fi` is only equivalent to
// reading through `fp` if the peek buffer is empty at this point — confirm.
394 std::vector<bst_float> weight_drop;
395 if (name_gbm_ ==
"dart") {
396 weight_drop.resize(gbm_param_.num_trees);
398 fi.read(reinterpret_cast<char*>(&sz),
sizeof(uint64_t));
399 TREELITE_CHECK_EQ(sz, gbm_param_.num_trees);
400 if (gbm_param_.num_trees != 0) {
401 for (uint64_t i = 0; i < sz; ++i) {
402 fi.read(reinterpret_cast<char*>(&weight_drop[i]),
sizeof(bst_float));
// Create the output model with float thresholds and float leaf outputs.
408 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<float, float>();
410 model->
num_feature =
static_cast<int>(mparam_.num_feature);
411 model->average_tree_output =
false;
// Treat a num_class below 1 as a single-output model.
412 const int num_class = std::max(mparam_.num_class, 1);
// Two task configurations are visible: grove-per-class multiclass and
// binary-classification/regression (the condition selecting between them is missing
// from this listing).
415 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
416 model->task_param.grove_per_class =
true;
419 model->task_type = treelite::TaskType::kBinaryClfRegr;
420 model->task_param.grove_per_class =
false;
422 model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
423 model->task_param.num_class = num_class;
424 model->task_param.leaf_vector_size = 1;
// Configure model->param from the objective name read above.
427 treelite::details::xgboost::SetPredTransform(name_obj_, &model->param);
// Global bias comes from base_score; for major version >= 1 it is additionally
// passed through TransformGlobalBiasToMargin.
430 model->param.global_bias =
static_cast<float>(mparam_.base_score);
433 const bool need_transform_to_margin = mparam_.major_version >= 1;
434 if (need_transform_to_margin) {
435 treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
// Convert each loaded tree. A queue of (source node id, destination node id) pairs
// drives a breadth-first copy.
439 for (
const auto& xgb_tree : xgb_trees_) {
440 model->trees.emplace_back();
447 std::queue<std::pair<int, int>> Q;
451 std::tie(old_id, new_id) = Q.front(); Q.pop();
452 const XGBTree::Node& node = xgb_tree[old_id];
453 const NodeStat stat = xgb_tree.Stat(old_id);
454 if (node.is_leaf()) {
455 bst_float leaf_value = node.leaf_value();
// weight_drop is non-empty only for DART; scale the leaf by the weight of the tree
// currently being built (the last one in model->trees).
457 if (!weight_drop.empty()) {
458 leaf_value *= weight_drop[model->trees.size() - 1];
460 tree.
SetLeaf(new_id, static_cast<float>(leaf_value));
// Internal node: a numerical split using Operator::kLT with the node's threshold
// and default direction, recording the gain and enqueueing both children.
462 const bst_float split_cond = node.split_cond();
465 static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
466 tree.
SetGain(new_id, stat.loss_chg);
467 Q.push({node.cleft(), tree.
LeftChild(new_id)});
468 Q.push({node.cright(), tree.
RightChild(new_id)});
// Derive num_parallel_tree from tree_info (the loop body is not visible here).
477 unsigned num_parallel_tree = 0;
478 for (
int e : tree_info) {
// With more than one parallel tree, regroup the trees by taking every
// num_parallel_tree-th tree starting at offset c, for c = 0, 1, ...; the total
// count must be unchanged.
484 if (num_parallel_tree > 1) {
492 std::vector<treelite::Tree<float, float>> new_trees;
493 std::size_t num_tree = model->trees.size();
494 for (std::size_t c = 0; c < num_parallel_tree; ++c) {
495 for (std::size_t tree_id = c; tree_id < num_tree; tree_id += num_parallel_tree) {
496 new_trees.push_back(std::move(model->trees[tree_id]));
499 TREELITE_CHECK_EQ(new_trees.size(), num_tree);
500 model->trees = std::move(new_trees);
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
logging facility for Treelite
int32_t num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
void SetGain(int nid, double gain)
set the gain value of the node
int LeftChild(int nid) const
Getters.
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
Helper functions for loading XGBoost models.
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node