treelite
xgboost.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <dmlc/memory_io.h>
10 #include <treelite/tree.h>
11 #include <memory>
12 #include <queue>
13 #include <cstring>
14 
15 namespace {
16 
17 treelite::Model ParseStream(dmlc::Stream* fi);
18 void SaveModelToStream(dmlc::Stream* fo, const treelite::Model& model,
19  const char* name_obj);
20 
21 } // namespace anonymous
22 
23 namespace treelite {
24 namespace frontend {
25 
26 DMLC_REGISTRY_FILE_TAG(xgboost);
27 
28 Model LoadXGBoostModel(const char* filename) {
29  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
30  return ParseStream(fi.get());
31 }
32 
33 void ExportXGBoostModel(const char* filename, const Model& model,
34  const char* name_obj) {
35  std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(filename, "w"));
36  SaveModelToStream(fo.get(), model, name_obj);
37 }
38 
39 Model LoadXGBoostModel(const void* buf, size_t len) {
40  dmlc::MemoryFixedSizeStream fs((void*)buf, len);
41  return ParseStream(&fs);
42 }
43 
44 } // namespace frontend
45 } // namespace treelite
46 
47 /* auxiliary data structures to interpret xgboost model file */
48 namespace {
49 
/* floating-point type used by the model structures in this file */
using bst_float = float;
51 
52 /* peekable input stream implemented with a ring buffer */
53 class PeekableInputStream {
54  public:
55  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
56 
57  PeekableInputStream(dmlc::Stream* fi)
58  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
59 
60  inline size_t Read(void* ptr, size_t size) {
61  const size_t bytes_buffered = BytesBuffered();
62  char* cptr = static_cast<char*>(ptr);
63  if (size <= bytes_buffered) {
64  // all content already buffered; consume buffer
65  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
66  std::memcpy(cptr, &buf_[begin_ptr_], size);
67  begin_ptr_ += size;
68  } else {
69  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
70  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
71  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
72  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
73  }
74  return size;
75  } else { // consume buffer entirely and read more bytes
76  const size_t bytes_to_read = size - bytes_buffered;
77  if (begin_ptr_ <= end_ptr_) {
78  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
79  } else {
80  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
81  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
82  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
83  }
84  begin_ptr_ = end_ptr_;
85  return bytes_buffered
86  + istm_->Read(cptr + bytes_buffered, bytes_to_read);
87  }
88  }
89 
90  inline size_t PeekRead(void* ptr, size_t size) {
91  CHECK_LE(size, MAX_PEEK_WINDOW)
92  << "PeekableInputStream allows peeking up to "
93  << MAX_PEEK_WINDOW << " bytes";
94  char* cptr = static_cast<char*>(ptr);
95  const size_t bytes_buffered = BytesBuffered();
96  /* fill buffer with additional bytes, up to size */
97  if (size > bytes_buffered) {
98  const size_t bytes_to_read = size - bytes_buffered;
99  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
100  CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
101  << "Failed to peek " << size << " bytes";
102  end_ptr_ += bytes_to_read;
103  } else {
104  CHECK_EQ( istm_->Read(&buf_[end_ptr_],
105  MAX_PEEK_WINDOW + 1 - end_ptr_)
106  + istm_->Read(&buf_[0],
107  bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
108  bytes_to_read)
109  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
110  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
111  }
112  }
113  /* copy buffer into ptr without emptying buffer */
114  if (begin_ptr_ <= end_ptr_) { // usual case
115  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
116  } else { // context wrapped around the end
117  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
118  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
119  }
120 
121  return size;
122  }
123 
124  private:
125  dmlc::Stream* istm_;
126  std::vector<char> buf_;
127  size_t begin_ptr_, end_ptr_;
128 
129  inline size_t BytesBuffered() {
130  if (begin_ptr_ <= end_ptr_) { // usual case
131  return end_ptr_ - begin_ptr_;
132  } else { // context wrapped around the end
133  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
134  }
135  }
136 };
137 
/*!
 * \brief Read and discard `size` bytes from a stream.
 *
 * The bytes are read in fixed-size chunks through a stack buffer, so memory use
 * does not depend on `size` (which typically comes from the file being parsed)
 * and no state is shared between calls.
 *
 * \param fi pointer-like handle to an object exposing
 *           `size_t Read(void*, size_t)`
 * \param size number of bytes to skip; a CHECK failure is raised if the stream
 *             ends before that many bytes were read
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[4096];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk
      = (remaining < sizeof(scratch)) ? remaining : sizeof(scratch);
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
146 
147 struct LearnerModelParam {
148  bst_float base_score; // global bias
149  unsigned num_feature;
150  int num_class;
151  int contain_extra_attrs;
152  int contain_eval_metrics;
153  int pad2[29];
154 };
155 
/*!
 * \brief Header record of the gbtree booster section of the model file.
 *
 * This struct is read from and written to the stream as raw bytes, so its
 * size and member order must not change.
 */
struct GBTreeModelParam {
  int num_trees;         // number of tree records that follow
  int num_roots;         // only the value 1 is accepted by the parser
  int num_feature;
  int pad1;              // unused by this file
  int64_t pad2;          // unused by this file
  int num_output_group;
  int size_leaf_vector;
  int pad3[32];          // unused by this file
};

// Guard the raw-byte layout: 38 four-byte integers plus one 64-bit integer,
// with no compiler-inserted padding.
static_assert(sizeof(GBTreeModelParam) == 38 * sizeof(int) + sizeof(int64_t),
              "GBTreeModelParam must be tightly packed");
166 
/*!
 * \brief Header record that precedes the nodes of each tree in the model file.
 *
 * This struct is read from and written to the stream as raw bytes, so its
 * size and member order must not change.
 */
struct TreeParam {
  int num_roots;         // only the value 1 is accepted by the parser
  int num_nodes;         // number of node records that follow
  int num_deleted;
  int max_depth;
  int num_feature;
  int size_leaf_vector;  // non-zero means a leaf-vector section follows
  int reserved[31];      // unused by this file
};

// Guard the raw-byte layout: 37 four-byte integers.
static_assert(sizeof(TreeParam) == 37 * sizeof(int),
              "TreeParam must be tightly packed");
176 
177 struct NodeStat {
178  bst_float loss_chg;
179  bst_float sum_hess;
180  bst_float base_weight;
181  int leaf_child_cnt;
182 };
183 
184 class XGBTree {
185  public:
186  class Node {
187  public:
188  Node() : sindex_(0) {
189  // assert compact alignment
190  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
191  "Node: 64 bit align");
192  }
193  inline int cleft() const {
194  return this->cleft_;
195  }
196  inline int cright() const {
197  return this->cright_;
198  }
199  inline int cdefault() const {
200  return this->default_left() ? this->cleft() : this->cright();
201  }
202  inline unsigned split_index() const {
203  return sindex_ & ((1U << 31) - 1U);
204  }
205  inline bool default_left() const {
206  return (sindex_ >> 31) != 0;
207  }
208  inline bool is_leaf() const {
209  return cleft_ == -1;
210  }
211  inline bst_float leaf_value() const {
212  return (this->info_).leaf_value;
213  }
214  inline bst_float split_cond() const {
215  return (this->info_).split_cond;
216  }
217  inline int parent() const {
218  return parent_ & ((1U << 31) - 1);
219  }
220  inline bool is_root() const {
221  return parent_ == -1;
222  }
223  inline void set_leaf(bst_float value) {
224  (this->info_).leaf_value = value;
225  this->cleft_ = -1;
226  this->cright_ = -1;
227  }
228  inline void set_split(unsigned split_index,
229  bst_float split_cond,
230  bool default_left = false) {
231  if (default_left) split_index |= (1U << 31);
232  this->sindex_ = split_index;
233  (this->info_).split_cond = split_cond;
234  }
235 
236  private:
237  friend class XGBTree;
238  union Info {
239  bst_float leaf_value;
240  bst_float split_cond;
241  };
242  int parent_;
243  int cleft_, cright_;
244  unsigned sindex_;
245  Info info_;
246 
247  inline bool is_deleted() const {
248  return sindex_ == std::numeric_limits<unsigned>::max();
249  }
250  inline void set_parent(int pidx, bool is_left_child = true) {
251  if (is_left_child) pidx |= (1U << 31);
252  this->parent_ = pidx;
253  }
254  };
255 
256  private:
257  TreeParam param;
258  std::vector<Node> nodes;
259 
260  inline int AllocNode() {
261  int nd = param.num_nodes++;
262  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
263  << "number of nodes in the tree exceed 2^31";
264  nodes.resize(param.num_nodes);
265  return nd;
266  }
267 
268  public:
269  inline Node& operator[](int nid) {
270  return nodes[nid];
271  }
272  inline const Node& operator[](int nid) const {
273  return nodes[nid];
274  }
275  inline void Init() {
276  param.num_nodes = 1;
277  nodes.resize(1);
278  nodes[0].set_leaf(0.0f);
279  nodes[0].set_parent(-1);
280  }
281  inline void AddChilds(int nid) {
282  int pleft = this->AllocNode();
283  int pright = this->AllocNode();
284  nodes[nid].cleft_ = pleft;
285  nodes[nid].cright_ = pright;
286  nodes[nodes[nid].cleft() ].set_parent(nid, true);
287  nodes[nodes[nid].cright()].set_parent(nid, false);
288  }
289  inline void Load(PeekableInputStream* fi) {
290  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
291  << "Ill-formed XGBoost model file: can't read TreeParam";
292  nodes.resize(param.num_nodes);
293  CHECK_NE(param.num_nodes, 0)
294  << "Ill-formed XGBoost model file: a tree can't be empty";
295  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()),
296  sizeof(Node) * nodes.size())
297  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
298  CONSUME_BYTES(fi, (3 * sizeof(bst_float) + sizeof(int)) * param.num_nodes);
299  if (param.size_leaf_vector != 0) {
300  uint64_t len;
301  CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
302  << "Ill-formed XGBoost model file";
303  if (len > 0) {
304  CONSUME_BYTES(fi, sizeof(bst_float) * len);
305  }
306  }
307  CHECK_EQ(param.num_roots, 1)
308  << "Invalid XGBoost model file: treelite does not support trees "
309  << "with multiple roots";
310  }
311  inline void Save(dmlc::Stream* fo, int num_feature) const {
312  TreeParam param_;
313  const bst_float nan = std::numeric_limits<bst_float>::quiet_NaN();
314  std::vector<NodeStat> stats_(nodes.size(), NodeStat{nan, nan, nan, -1});
315  param_.num_roots = 1;
316  param_.num_nodes = static_cast<int>(nodes.size());
317  param_.num_deleted = 0;
318  std::function<int(int)> max_depth_func;
319  max_depth_func = [&max_depth_func, this](int nid) -> int {
320  if (nodes[nid].is_leaf()) {
321  return 0;
322  } else {
323  return 1 + std::max(max_depth_func(nodes[nid].cleft()),
324  max_depth_func(nodes[nid].cright()));
325  }
326  };
327  param_.max_depth = max_depth_func(0);
328  param_.num_feature = num_feature;
329  param_.size_leaf_vector = 0;
330  fo->Write(&param_, sizeof(TreeParam));
331  fo->Write(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size());
332  // write dummy stats
333  fo->Write(dmlc::BeginPtr(stats_), sizeof(NodeStat) * nodes.size());
334  }
335 };
336 
337 inline treelite::Model ParseStream(dmlc::Stream* fi) {
338  std::vector<XGBTree> xgb_trees_;
339  LearnerModelParam mparam_; // model parameter
340  GBTreeModelParam gbm_param_; // GBTree training parameter
341  std::string name_gbm_;
342  std::string name_obj_;
343 
344  /* 1. Parse input stream */
345  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
346  // backward compatible header check.
347  std::string header;
348  header.resize(4);
349  if (fp->PeekRead(&header[0], 4) == 4) {
350  CHECK_NE(header, "bs64")
351  << "Ill-formed XGBoost model file: Base64 format no longer supported";
352  if (header == "binf") {
353  CONSUME_BYTES(fp, 4);
354  }
355  }
356  // read parameter
357  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
358  << "Ill-formed XGBoost model file: corrupted header";
359  LOG(INFO) << "Global bias of the model: " << mparam_.base_score;
360  {
361  // backward compatibility code for compatible with old model type
362  // for new model, Read(&name_obj_) is suffice
363  uint64_t len;
364  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
365  << "Ill-formed XGBoost model file: corrupted header";
366  if (len >= std::numeric_limits<unsigned>::max()) {
367  int gap;
368  CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
369  << "Ill-formed XGBoost model file: corrupted header";
370  len = len >> static_cast<uint64_t>(32UL);
371  }
372  if (len != 0) {
373  name_obj_.resize(len);
374  CHECK_EQ(fp->Read(&name_obj_[0], len), len)
375  << "Ill-formed XGBoost model file: corrupted header";
376  }
377  }
378 
379  {
380  uint64_t len;
381  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
382  << "Ill-formed XGBoost model file: corrupted header";
383  name_gbm_.resize(len);
384  if (len > 0) {
385  CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
386  << "Ill-formed XGBoost model file: corrupted header";
387  }
388  }
389 
390  /* loading GBTree */
391  CHECK_EQ(name_gbm_, "gbtree")
392  << "Invalid XGBoost model file: "
393  << "Gradient booster must be gbtree type.";
394 
395  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
396  << "Invalid XGBoost model file: corrupted GBTree parameters";
397  LOG(INFO) << "gbm_param_.num_feature = " << gbm_param_.num_feature;
398  LOG(INFO) << "gbm_param_.num_output_group = " << gbm_param_.num_output_group;
399  for (int i = 0; i < gbm_param_.num_trees; ++i) {
400  xgb_trees_.emplace_back();
401  xgb_trees_.back().Load(fp.get());
402  }
403  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
404 
405  /* 2. Export model */
406  treelite::Model model;
407  model.num_feature = gbm_param_.num_feature;
408  model.num_output_group = gbm_param_.num_output_group;
409  model.random_forest_flag = false;
410 
411  // set global bias
412  model.param.global_bias = static_cast<float>(mparam_.base_score);
413 
414  // set correct prediction transform function, depending on objective function
415  if (name_obj_ == "multi:softmax") {
416  model.param.pred_transform = "max_index";
417  } else if (name_obj_ == "multi:softprob") {
418  model.param.pred_transform = "softmax";
419  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
420  model.param.pred_transform = "sigmoid";
421  model.param.sigmoid_alpha = 1.0f;
422  } else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
423  || name_obj_ == "reg:tweedie") {
424  model.param.pred_transform = "exponential";
425  } else {
426  model.param.pred_transform = "identity";
427  }
428 
429  // traverse trees
430  for (const auto& xgb_tree : xgb_trees_) {
431  model.trees.emplace_back();
432  treelite::Tree& tree = model.trees.back();
433  tree.Init();
434 
435  // assign node ID's so that a breadth-wise traversal would yield
436  // the monotonic sequence 0, 1, 2, ...
437  // deleted nodes will be excluded
438  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
439  Q.push({0, 0});
440  while (!Q.empty()) {
441  int old_id, new_id;
442  std::tie(old_id, new_id) = Q.front(); Q.pop();
443  const XGBTree::Node& node = xgb_tree[old_id];
444  if (node.is_leaf()) {
445  const bst_float leaf_value = node.leaf_value();
446  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
447  } else {
448  const bst_float split_cond = node.split_cond();
449  tree.AddChilds(new_id);
450  tree[new_id].set_numerical_split(node.split_index(),
451  static_cast<treelite::tl_float>(split_cond),
452  node.default_left(),
453  treelite::Operator::kLT);
454  Q.push({node.cleft(), tree[new_id].cleft()});
455  Q.push({node.cright(), tree[new_id].cright()});
456  }
457  }
458  }
459  return model;
460 }
461 
462 inline void SaveModelToStream(dmlc::Stream* fo, const treelite::Model& model,
463  const char* name_obj) {
464  LearnerModelParam mparam_;
465  GBTreeModelParam gbm_param_;
466  /* Learner parameters */
467  mparam_.base_score = model.param.global_bias;
468  mparam_.num_feature = model.num_feature;
469  mparam_.num_class = model.num_output_group;
470  mparam_.contain_extra_attrs = 0;
471  mparam_.contain_eval_metrics = 0;
472  fo->Write(&mparam_, sizeof(LearnerModelParam));
473  /* name of objective and gbm class */
474  const std::string name_gbm_ = "gbtree";
475  fo->Write(std::string(name_obj));
476  fo->Write(name_gbm_);
477  /* GBTree parameters */
478  gbm_param_.num_trees = model.trees.size();
479  gbm_param_.num_roots = 1;
480  gbm_param_.num_feature = model.num_feature;
481  gbm_param_.num_output_group = model.num_output_group;
482  gbm_param_.size_leaf_vector = 0;
483  fo->Write(&gbm_param_, sizeof(gbm_param_));
484  /* Individual decision trees */
485  for (const treelite::Tree& tree : model.trees) {
486  XGBTree xgb_tree_;
487  xgb_tree_.Init();
488  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
489  Q.push({0, 0});
490  while (!Q.empty()) {
491  int old_id, new_id;
492  std::tie(old_id, new_id) = Q.front(); Q.pop();
493  const treelite::Tree::Node& node = tree[old_id];
494  if (node.is_leaf()) {
495  const treelite::tl_float leaf_value = node.leaf_value();
496  xgb_tree_[new_id].set_leaf(static_cast<bst_float>(leaf_value));
497  } else {
498  const treelite::tl_float split_cond = node.threshold();
499  xgb_tree_.AddChilds(new_id);
500  CHECK(node.comparison_op() == treelite::Operator::kLT)
501  << "Comparison operator must be `<`";
502  xgb_tree_[new_id].set_split(node.split_index(),
503  static_cast<bst_float>(split_cond),
504  node.default_left());
505  Q.push({node.cleft(), xgb_tree_[new_id].cleft()});
506  Q.push({node.cright(), xgb_tree_[new_id].cright()});
507  }
508  }
509  xgb_tree_.Save(fo, model.num_feature);
510  }
511  // write dummy tree_info
512  std::vector<int> tree_info_(model.trees.size(), 0);
513  if (model.num_output_group > 1) {
514  for (size_t i = 0; i < model.trees.size(); ++i) {
515  tree_info_[i] = i % model.num_output_group;
516  }
517  }
518  fo->Write(dmlc::BeginPtr(tree_info_), sizeof(int) * tree_info_.size());
519 }
520 
521 } // namespace anonymous
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else.
Definition: tree.h:361
void Init()
initialize the model with a single root node
Definition: tree.h:235
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
tree node
Definition: tree.h:22
std::vector< Tree > trees
member trees
Definition: tree.h:353
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:315
unsigned split_index() const
feature index of split condition
Definition: tree.h:38
ModelParam param
extra parameters
Definition: tree.h:366
model structure for tree
Operator comparison_op() const
get comparison operator
Definition: tree.h:83
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:322
tl_float threshold() const
Definition: tree.h:67
int cright() const
index of right child
Definition: tree.h:30
std::string pred_transform
name of prediction transform function
Definition: tree.h:307
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
Definition: xgboost.cc:28
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees
Definition: tree.h:364
tl_float leaf_value() const
Definition: tree.h:50
bool default_left() const
when feature is unknown, whether goes to left child
Definition: tree.h:42
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:245
int cleft() const
index of left child
Definition: tree.h:26
void ExportXGBoostModel(const char *filename, const Model &model, const char *name_obj)
export a model in XGBoost format. The exported model can be read by XGBoost (dmlc/xgboost).
Definition: xgboost.cc:33
bool is_leaf() const
whether current node is leaf node
Definition: tree.h:46
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
Definition: tree.h:358