Treelite
xgboost.cc
Go to the documentation of this file.
1 
#include "xgboost/xgboost.h"
#include <treelite/frontend.h>
#include <treelite/tree.h>
#include <treelite/logging.h>
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <limits>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
19 
20 namespace {
21 
22 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
23 
24 } // anonymous namespace
25 
26 namespace treelite {
27 namespace frontend {
28 
29 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename) {
30  std::ifstream fi(filename, std::ios::in | std::ios::binary);
31  return ParseStream(fi);
32 }
33 
34 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len) {
35  std::istringstream fi(std::string(static_cast<const char*>(buf), len));
36  return ParseStream(fi);
37 }
38 
39 } // namespace frontend
40 } // namespace treelite
41 
42 /* auxiliary data structures to interpret xgboost model file */
43 namespace {
44 
// Floating-point type XGBoost uses in its binary model format (thresholds,
// leaf values, base score and node statistics in this file).
using bst_float = float;
46 
47 /* peekable input stream implemented with a ring buffer */
48 class PeekableInputStream {
49  public:
50  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
51 
52  explicit PeekableInputStream(std::istream& fi)
53  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
54 
55  inline size_t Read(void* ptr, size_t size) {
56  const size_t bytes_buffered = BytesBuffered();
57  char* cptr = static_cast<char*>(ptr);
58  if (size <= bytes_buffered) {
59  // all content already buffered; consume buffer
60  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
61  std::memcpy(cptr, &buf_[begin_ptr_], size);
62  begin_ptr_ += size;
63  } else {
64  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
65  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
66  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
67  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
68  }
69  return size;
70  } else { // consume buffer entirely and read more bytes
71  const size_t bytes_to_read = size - bytes_buffered;
72  if (begin_ptr_ <= end_ptr_) {
73  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
74  } else {
75  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
76  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
77  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
78  }
79  begin_ptr_ = end_ptr_;
80  istm_.read(cptr + bytes_buffered, bytes_to_read);
81  return bytes_buffered + istm_.gcount();
82  }
83  }
84 
85  inline size_t PeekRead(void* ptr, size_t size) {
86  TREELITE_CHECK_LE(size, MAX_PEEK_WINDOW)
87  << "PeekableInputStream allows peeking up to "
88  << MAX_PEEK_WINDOW << " bytes";
89  char* cptr = static_cast<char*>(ptr);
90  const size_t bytes_buffered = BytesBuffered();
91  /* fill buffer with additional bytes, up to size */
92  if (size > bytes_buffered) {
93  const size_t bytes_to_read = size - bytes_buffered;
94  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
95  istm_.read(&buf_[end_ptr_], bytes_to_read);
96  TREELITE_CHECK_EQ(istm_.gcount(), bytes_to_read)
97  << "Failed to peek " << size << " bytes";
98  end_ptr_ += bytes_to_read;
99  } else {
100  istm_.read(&buf_[end_ptr_], MAX_PEEK_WINDOW + 1 - end_ptr_);
101  size_t first_read = istm_.gcount();
102  istm_.read(&buf_[0], bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1);
103  size_t second_read = istm_.gcount();
104  TREELITE_CHECK_EQ(first_read + second_read, bytes_to_read)
105  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
106  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
107  }
108  }
109  /* copy buffer into ptr without emptying buffer */
110  if (begin_ptr_ <= end_ptr_) { // usual case
111  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
112  } else { // context wrapped around the end
113  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
114  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
115  }
116 
117  return size;
118  }
119 
120  private:
121  std::istream& istm_;
122  std::vector<char> buf_;
123  size_t begin_ptr_, end_ptr_;
124 
125  inline size_t BytesBuffered() {
126  if (begin_ptr_ <= end_ptr_) { // usual case
127  return end_ptr_ - begin_ptr_;
128  } else { // context wrapped around the end
129  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
130  }
131  }
132 };
133 
/*!
 * \brief Read and discard `size` bytes from `fi`.
 *
 * Bytes are skipped through a fixed-size local buffer, so memory use stays
 * bounded even when `size` comes from a length field in the model file, and
 * concurrent callers share no state (the previous version used a
 * function-static, ever-growing scratch vector).
 * \param fi pointer-like handle (raw pointer or std::unique_ptr) whose pointee
 *           provides Read(void*, size_t)
 * \param size number of bytes to skip
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[512];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk = std::min(remaining, sizeof(scratch));
    TREELITE_CHECK_EQ(fi->Read(scratch, chunk), chunk)
        << "Ill-formed XGBoost model format: cannot read " << size
        << " bytes from the file";
    remaining -= chunk;
  }
}
142 
143 struct LearnerModelParam {
144  bst_float base_score; // global bias
145  unsigned num_feature;
146  int num_class;
147  int contain_extra_attrs;
148  int contain_eval_metrics;
149  uint32_t major_version;
150  uint32_t minor_version;
151  int pad2[27];
152 };
153 static_assert(sizeof(LearnerModelParam) == 136, "This is the size defined in XGBoost.");
154 
/*!
 * \brief Parameters of the gradient-boosted tree ensemble, read verbatim
 *        from a binary XGBoost model.
 */
struct GBTreeModelParam {
  int num_trees;         // number of trees that follow in the stream
  int num_roots;         // ParseStream requires exactly 1
  int num_feature;
  int pad1;              // padding in front of the 64-bit field
  int64_t pad2;          // not used by this parser
  int num_output_group;
  int size_leaf_vector;
  int pad3[32];          // reserved space
};
// Lock the on-disk layout (4 * 4 + 8 + 2 * 4 + 32 * 4 bytes), mirroring the
// check on LearnerModelParam; a mismatch would silently misparse every model.
static_assert(sizeof(GBTreeModelParam) == 160, "This is the size defined in XGBoost.");
165 
/*!
 * \brief Per-tree header read verbatim from a binary XGBoost model, ahead of
 *        the tree's node and node-statistic records.
 */
struct TreeParam {
  int num_roots;         // XGBTree::Load requires exactly 1
  int num_nodes;         // number of Node / NodeStat records that follow
  int num_deleted;
  int max_depth;
  int num_feature;
  int size_leaf_vector;  // when non-zero, extra leaf-vector data follows and is skipped
  int reserved[31];
};
// Lock the on-disk layout: 6 + 31 ints.
static_assert(sizeof(TreeParam) == (6 + 31) * sizeof(int), "This is the size defined in XGBoost.");
175 
176 struct NodeStat {
177  bst_float loss_chg;
178  bst_float sum_hess;
179  bst_float base_weight;
180  int leaf_child_cnt;
181 };
182 
183 class XGBTree {
184  public:
185  class Node {
186  public:
187  Node() : sindex_(0) {
188  // assert compact alignment
189  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
190  "Node: 64 bit align");
191  }
192  inline int cleft() const {
193  return this->cleft_;
194  }
195  inline int cright() const {
196  return this->cright_;
197  }
198  inline int cdefault() const {
199  return this->default_left() ? this->cleft() : this->cright();
200  }
201  inline unsigned split_index() const {
202  return sindex_ & ((1U << 31) - 1U);
203  }
204  inline bool default_left() const {
205  return (sindex_ >> 31) != 0;
206  }
207  inline bool is_leaf() const {
208  return cleft_ == -1;
209  }
210  inline bst_float leaf_value() const {
211  return (this->info_).leaf_value;
212  }
213  inline bst_float split_cond() const {
214  return (this->info_).split_cond;
215  }
216  inline int parent() const {
217  return parent_ & ((1U << 31) - 1);
218  }
219  inline bool is_root() const {
220  return parent_ == -1;
221  }
222  inline void set_leaf(bst_float value) {
223  (this->info_).leaf_value = value;
224  this->cleft_ = -1;
225  this->cright_ = -1;
226  }
227  inline void set_split(unsigned split_index,
228  bst_float split_cond,
229  bool default_left = false) {
230  if (default_left) split_index |= (1U << 31);
231  this->sindex_ = split_index;
232  (this->info_).split_cond = split_cond;
233  }
234 
235  private:
236  friend class XGBTree;
237  union Info {
238  bst_float leaf_value;
239  bst_float split_cond;
240  };
241  int parent_;
242  int cleft_, cright_;
243  unsigned sindex_;
244  Info info_;
245 
246  inline bool is_deleted() const {
247  return sindex_ == std::numeric_limits<unsigned>::max();
248  }
249  inline void set_parent(int pidx, bool is_left_child = true) {
250  if (is_left_child) pidx |= (1U << 31);
251  this->parent_ = pidx;
252  }
253  };
254 
255  private:
256  TreeParam param;
257  std::vector<Node> nodes;
258  std::vector<NodeStat> stats;
259 
260  inline int AllocNode() {
261  int nd = param.num_nodes++;
262  TREELITE_CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
263  << "number of nodes in the tree exceed 2^31";
264  nodes.resize(param.num_nodes);
265  return nd;
266  }
267 
268  public:
270  inline Node& operator[](int nid) {
271  return nodes[nid];
272  }
274  inline const Node& operator[](int nid) const {
275  return nodes[nid];
276  }
278  inline NodeStat& Stat(int nid) {
279  return stats[nid];
280  }
282  inline const NodeStat& Stat(int nid) const {
283  return stats[nid];
284  }
285  inline void Init() {
286  param.num_nodes = 1;
287  nodes.resize(1);
288  nodes[0].set_leaf(0.0f);
289  nodes[0].set_parent(-1);
290  }
291  inline void AddChilds(int nid) {
292  int pleft = this->AllocNode();
293  int pright = this->AllocNode();
294  nodes[nid].cleft_ = pleft;
295  nodes[nid].cright_ = pright;
296  nodes[nodes[nid].cleft() ].set_parent(nid, true);
297  nodes[nodes[nid].cright()].set_parent(nid, false);
298  }
299  inline void Load(PeekableInputStream* fi) {
300  TREELITE_CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
301  << "Ill-formed XGBoost model file: can't read TreeParam";
302  TREELITE_CHECK_GT(param.num_nodes, 0)
303  << "Ill-formed XGBoost model file: a tree can't be empty";
304  nodes.resize(param.num_nodes);
305  stats.resize(param.num_nodes);
306  TREELITE_CHECK_EQ(fi->Read(nodes.data(), sizeof(Node) * nodes.size()),
307  sizeof(Node) * nodes.size())
308  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
309  TREELITE_CHECK_EQ(fi->Read(stats.data(), sizeof(NodeStat) * stats.size()),
310  sizeof(NodeStat) * stats.size())
311  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
312  if (param.size_leaf_vector != 0) {
313  uint64_t len;
314  TREELITE_CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
315  << "Ill-formed XGBoost model file";
316  if (len > 0) {
317  CONSUME_BYTES(fi, sizeof(bst_float) * len);
318  }
319  }
320  TREELITE_CHECK_EQ(param.num_roots, 1)
321  << "Invalid XGBoost model file: treelite does not support trees "
322  << "with multiple roots";
323  }
324 };
325 
326 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
327  std::vector<XGBTree> xgb_trees_;
328  LearnerModelParam mparam_; // model parameter
329  GBTreeModelParam gbm_param_; // GBTree training parameter
330  std::string name_gbm_;
331  std::string name_obj_;
332 
333  /* 1. Parse input stream */
334  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
335  // backward compatible header check.
336  std::string header;
337  header.resize(4);
338  if (fp->PeekRead(&header[0], 4) == 4) {
339  TREELITE_CHECK_NE(header, "bs64")
340  << "Ill-formed XGBoost model file: Base64 format no longer supported";
341  if (header == "binf") {
342  CONSUME_BYTES(fp, 4);
343  }
344  }
345  // read parameter
346  TREELITE_CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
347  << "Ill-formed XGBoost model file: corrupted header";
348  {
349  uint64_t len;
350  TREELITE_CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
351  << "Ill-formed XGBoost model file: corrupted header";
352  if (len != 0) {
353  name_obj_.resize(len);
354  TREELITE_CHECK_EQ(fp->Read(&name_obj_[0], len), len)
355  << "Ill-formed XGBoost model file: corrupted header";
356  }
357  }
358 
359  {
360  uint64_t len;
361  TREELITE_CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
362  << "Ill-formed XGBoost model file: corrupted header";
363  name_gbm_.resize(len);
364  if (len > 0) {
365  TREELITE_CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
366  << "Ill-formed XGBoost model file: corrupted header";
367  }
368  }
369 
370  /* loading GBTree */
371  TREELITE_CHECK(name_gbm_ == "gbtree" || name_gbm_ == "dart")
372  << "Invalid XGBoost model file: "
373  << "Gradient booster must be gbtree or dart type.";
374 
375  TREELITE_CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
376  << "Invalid XGBoost model file: corrupted GBTree parameters";
377  TREELITE_CHECK_GE(gbm_param_.num_trees, 0)
378  << "Invalid XGBoost model file: num_trees must be 0 or greater";
379  for (int i = 0; i < gbm_param_.num_trees; ++i) {
380  xgb_trees_.emplace_back();
381  xgb_trees_.back().Load(fp.get());
382  }
383  TREELITE_CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
384  // tree_info is currently unused.
385  std::vector<int> tree_info;
386  tree_info.resize(gbm_param_.num_trees);
387  if (gbm_param_.num_trees > 0) {
388  TREELITE_CHECK_EQ(fp->Read(tree_info.data(), sizeof(int32_t) * tree_info.size()),
389  sizeof(int32_t) * tree_info.size());
390  }
391  // Load weight drop values (per tree) for dart models.
392  std::vector<bst_float> weight_drop;
393  if (name_gbm_ == "dart") {
394  weight_drop.resize(gbm_param_.num_trees);
395  uint64_t sz;
396  fi.read(reinterpret_cast<char*>(&sz), sizeof(uint64_t));
397  TREELITE_CHECK_EQ(sz, gbm_param_.num_trees);
398  if (gbm_param_.num_trees != 0) {
399  for (uint64_t i = 0; i < sz; ++i) {
400  fi.read(reinterpret_cast<char*>(&weight_drop[i]), sizeof(bst_float));
401  }
402  }
403  }
404 
405  /* 2. Export model */
406  std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<float, float>();
407  auto* model = dynamic_cast<treelite::ModelImpl<float, float>*>(model_ptr.get());
408  model->num_feature = static_cast<int>(mparam_.num_feature);
409  model->average_tree_output = false;
410  const int num_class = std::max(mparam_.num_class, 1);
411  if (num_class > 1) {
412  // multi-class classifier
413  model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
414  model->task_param.grove_per_class = true;
415  } else {
416  // binary classifier or regressor
417  model->task_type = treelite::TaskType::kBinaryClfRegr;
418  model->task_param.grove_per_class = false;
419  }
420  model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
421  model->task_param.num_class = num_class;
422  model->task_param.leaf_vector_size = 1;
423 
424  // set correct prediction transform function, depending on objective function
425  treelite::details::xgboost::SetPredTransform(name_obj_, &model->param);
426 
427  // set global bias
428  model->param.global_bias = static_cast<float>(mparam_.base_score);
429  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
430  // 1.0 it's the original value provided by user.
431  const bool need_transform_to_margin = mparam_.major_version >= 1;
432  if (need_transform_to_margin) {
433  treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
434  }
435 
436  // traverse trees
437  for (const auto& xgb_tree : xgb_trees_) {
438  model->trees.emplace_back();
439  treelite::Tree<float, float>& tree = model->trees.back();
440  tree.Init();
441 
442  // assign node ID's so that a breadth-wise traversal would yield
443  // the monotonic sequence 0, 1, 2, ...
444  // deleted nodes will be excluded
445  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
446  Q.push({0, 0});
447  while (!Q.empty()) {
448  int old_id, new_id;
449  std::tie(old_id, new_id) = Q.front(); Q.pop();
450  const XGBTree::Node& node = xgb_tree[old_id];
451  const NodeStat stat = xgb_tree.Stat(old_id);
452  if (node.is_leaf()) {
453  bst_float leaf_value = node.leaf_value();
454  // Fold weight drop into leaf value for dart models.
455  if (!weight_drop.empty()) {
456  leaf_value *= weight_drop[model->trees.size() - 1];
457  }
458  tree.SetLeaf(new_id, static_cast<float>(leaf_value));
459  } else {
460  const bst_float split_cond = node.split_cond();
461  tree.AddChilds(new_id);
462  tree.SetNumericalSplit(new_id, node.split_index(),
463  static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
464  tree.SetGain(new_id, stat.loss_chg);
465  Q.push({node.cleft(), tree.LeftChild(new_id)});
466  Q.push({node.cright(), tree.RightChild(new_id)});
467  }
468  tree.SetSumHess(new_id, stat.sum_hess);
469  }
470  }
471  return model_ptr;
472 }
473 
474 } // anonymous namespace
Collection of front-end methods to load or construct an ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:651
model structure for tree ensemble
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:619
logging facility for Treelite
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
Load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
Definition: xgboost.cc:29
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:639
int LeftChild(int nid) const
Getters.
Definition: tree.h:406
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:413
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:666
Helper functions for loading XGBoost models.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
Definition: tree.h:764
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:675
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:729