Treelite
xgboost.cc
Go to the documentation of this file.
1 
#include "xgboost/xgboost.h"
#include <dmlc/data.h>
#include <dmlc/memory_io.h>
#include <treelite/frontend.h>
#include <treelite/tree.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
17 
18 namespace {
19 
20 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi);
21 
22 } // anonymous namespace
23 
24 namespace treelite {
25 namespace frontend {
26 
27 DMLC_REGISTRY_FILE_TAG(xgboost);
28 
29 std::unique_ptr<treelite::Model> LoadXGBoostModel(const char* filename) {
30  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
31  return ParseStream(fi.get());
32 }
33 
34 std::unique_ptr<treelite::Model> LoadXGBoostModel(const void* buf, size_t len) {
35  dmlc::MemoryFixedSizeStream fs(const_cast<void*>(buf), len);
36  return ParseStream(&fs);
37 }
38 
39 } // namespace frontend
40 } // namespace treelite
41 
42 /* auxiliary data structures to interpret xgboost model file */
43 namespace {
44 
// Floating-point type used for values stored in the model file (thresholds, leaf values,
// node statistics, base score).
using bst_float = float;
46 
47 /* peekable input stream implemented with a ring buffer */
48 class PeekableInputStream {
49  public:
50  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
51 
52  explicit PeekableInputStream(dmlc::Stream* fi)
53  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
54 
55  inline size_t Read(void* ptr, size_t size) {
56  const size_t bytes_buffered = BytesBuffered();
57  char* cptr = static_cast<char*>(ptr);
58  if (size <= bytes_buffered) {
59  // all content already buffered; consume buffer
60  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
61  std::memcpy(cptr, &buf_[begin_ptr_], size);
62  begin_ptr_ += size;
63  } else {
64  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
65  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
66  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
67  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
68  }
69  return size;
70  } else { // consume buffer entirely and read more bytes
71  const size_t bytes_to_read = size - bytes_buffered;
72  if (begin_ptr_ <= end_ptr_) {
73  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
74  } else {
75  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
76  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
77  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
78  }
79  begin_ptr_ = end_ptr_;
80  return bytes_buffered
81  + istm_->Read(cptr + bytes_buffered, bytes_to_read);
82  }
83  }
84 
85  inline size_t PeekRead(void* ptr, size_t size) {
86  CHECK_LE(size, MAX_PEEK_WINDOW)
87  << "PeekableInputStream allows peeking up to "
88  << MAX_PEEK_WINDOW << " bytes";
89  char* cptr = static_cast<char*>(ptr);
90  const size_t bytes_buffered = BytesBuffered();
91  /* fill buffer with additional bytes, up to size */
92  if (size > bytes_buffered) {
93  const size_t bytes_to_read = size - bytes_buffered;
94  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
95  CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
96  << "Failed to peek " << size << " bytes";
97  end_ptr_ += bytes_to_read;
98  } else {
99  CHECK_EQ(istm_->Read(&buf_[end_ptr_],
100  MAX_PEEK_WINDOW + 1 - end_ptr_)
101  + istm_->Read(&buf_[0],
102  bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
103  bytes_to_read)
104  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
105  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
106  }
107  }
108  /* copy buffer into ptr without emptying buffer */
109  if (begin_ptr_ <= end_ptr_) { // usual case
110  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
111  } else { // context wrapped around the end
112  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
113  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
114  }
115 
116  return size;
117  }
118 
119  private:
120  dmlc::Stream* istm_;
121  std::vector<char> buf_;
122  size_t begin_ptr_, end_ptr_;
123 
124  inline size_t BytesBuffered() {
125  if (begin_ptr_ <= end_ptr_) { // usual case
126  return end_ptr_ - begin_ptr_;
127  } else { // context wrapped around the end
128  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
129  }
130  }
131 };
132 
/*!
 * \brief Read and discard exactly `size` bytes from `fi`.
 * \tparam T pointer-like handle (raw or smart pointer) to an object exposing
 *           `size_t Read(void*, size_t)`
 * \param fi stream to skip bytes in
 * \param size number of bytes to skip
 *
 * Bytes are discarded through a fixed-size local scratch buffer in chunks. Unlike a shared
 * static buffer, this keeps no state across calls (so concurrent model loads cannot race on
 * it), and memory use does not grow with `size`.
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  std::array<char, 4096> scratch;
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk = std::min(remaining, scratch.size());
    CHECK_EQ(fi->Read(scratch.data(), chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
141 
142 struct LearnerModelParam {
143  bst_float base_score; // global bias
144  unsigned num_feature;
145  int num_class;
146  int contain_extra_attrs;
147  int contain_eval_metrics;
148  uint32_t major_version;
149  uint32_t minor_version;
150  int pad2[27];
151 };
152 static_assert(sizeof(LearnerModelParam) == 136, "This is the size defined in XGBoost.");
153 
/*!
 * \brief GBTree booster parameters, read verbatim (sizeof bytes) from the model file.
 *
 * Member order and types define the on-disk layout; do not reorder or retype them.
 */
struct GBTreeModelParam {
  int num_trees;           // number of trees stored after this header
  int num_roots;           // roots per tree; only 1 is supported by the parser
  int num_feature;
  int pad1;                // unused padding
  int64_t pad2;            // unused padding
  int num_output_group;
  int size_leaf_vector;
  int pad3[32];            // unused padding
};
// The struct is read with a raw sizeof-byte read, so its size is part of the file format.
static_assert(sizeof(GBTreeModelParam) == 160, "This is the size defined in XGBoost.");
164 
/*!
 * \brief Per-tree parameters, read verbatim (sizeof bytes) at the start of every tree.
 *
 * Member order and types define the on-disk layout; do not reorder or retype them.
 */
struct TreeParam {
  int num_roots;         // only 1 is supported
  int num_nodes;         // number of Node / NodeStat records that follow
  int num_deleted;
  int max_depth;
  int num_feature;
  int size_leaf_vector;  // non-zero means a leaf-vector block follows the node statistics
  int reserved[31];
};
// The struct is read with a raw sizeof-byte read, so its size is part of the file format.
static_assert(sizeof(TreeParam) == 148, "This is the size defined in XGBoost.");
174 
175 struct NodeStat {
176  bst_float loss_chg;
177  bst_float sum_hess;
178  bst_float base_weight;
179  int leaf_child_cnt;
180 };
181 
182 class XGBTree {
183  public:
184  class Node {
185  public:
186  Node() : sindex_(0) {
187  // assert compact alignment
188  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
189  "Node: 64 bit align");
190  }
191  inline int cleft() const {
192  return this->cleft_;
193  }
194  inline int cright() const {
195  return this->cright_;
196  }
197  inline int cdefault() const {
198  return this->default_left() ? this->cleft() : this->cright();
199  }
200  inline unsigned split_index() const {
201  return sindex_ & ((1U << 31) - 1U);
202  }
203  inline bool default_left() const {
204  return (sindex_ >> 31) != 0;
205  }
206  inline bool is_leaf() const {
207  return cleft_ == -1;
208  }
209  inline bst_float leaf_value() const {
210  return (this->info_).leaf_value;
211  }
212  inline bst_float split_cond() const {
213  return (this->info_).split_cond;
214  }
215  inline int parent() const {
216  return parent_ & ((1U << 31) - 1);
217  }
218  inline bool is_root() const {
219  return parent_ == -1;
220  }
221  inline void set_leaf(bst_float value) {
222  (this->info_).leaf_value = value;
223  this->cleft_ = -1;
224  this->cright_ = -1;
225  }
226  inline void set_split(unsigned split_index,
227  bst_float split_cond,
228  bool default_left = false) {
229  if (default_left) split_index |= (1U << 31);
230  this->sindex_ = split_index;
231  (this->info_).split_cond = split_cond;
232  }
233 
234  private:
235  friend class XGBTree;
236  union Info {
237  bst_float leaf_value;
238  bst_float split_cond;
239  };
240  int parent_;
241  int cleft_, cright_;
242  unsigned sindex_;
243  Info info_;
244 
245  inline bool is_deleted() const {
246  return sindex_ == std::numeric_limits<unsigned>::max();
247  }
248  inline void set_parent(int pidx, bool is_left_child = true) {
249  if (is_left_child) pidx |= (1U << 31);
250  this->parent_ = pidx;
251  }
252  };
253 
254  private:
255  TreeParam param;
256  std::vector<Node> nodes;
257  std::vector<NodeStat> stats;
258 
259  inline int AllocNode() {
260  int nd = param.num_nodes++;
261  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
262  << "number of nodes in the tree exceed 2^31";
263  nodes.resize(param.num_nodes);
264  return nd;
265  }
266 
267  public:
269  inline Node& operator[](int nid) {
270  return nodes[nid];
271  }
273  inline const Node& operator[](int nid) const {
274  return nodes[nid];
275  }
277  inline NodeStat& Stat(int nid) {
278  return stats[nid];
279  }
281  inline const NodeStat& Stat(int nid) const {
282  return stats[nid];
283  }
284  inline void Init() {
285  param.num_nodes = 1;
286  nodes.resize(1);
287  nodes[0].set_leaf(0.0f);
288  nodes[0].set_parent(-1);
289  }
290  inline void AddChilds(int nid) {
291  int pleft = this->AllocNode();
292  int pright = this->AllocNode();
293  nodes[nid].cleft_ = pleft;
294  nodes[nid].cright_ = pright;
295  nodes[nodes[nid].cleft() ].set_parent(nid, true);
296  nodes[nodes[nid].cright()].set_parent(nid, false);
297  }
298  inline void Load(PeekableInputStream* fi) {
299  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
300  << "Ill-formed XGBoost model file: can't read TreeParam";
301  nodes.resize(param.num_nodes);
302  stats.resize(param.num_nodes);
303  CHECK_NE(param.num_nodes, 0)
304  << "Ill-formed XGBoost model file: a tree can't be empty";
305  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()),
306  sizeof(Node) * nodes.size())
307  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
308  CHECK_EQ(fi->Read(dmlc::BeginPtr(stats), sizeof(NodeStat) * stats.size()),
309  sizeof(NodeStat) * stats.size())
310  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
311  if (param.size_leaf_vector != 0) {
312  uint64_t len;
313  CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
314  << "Ill-formed XGBoost model file";
315  if (len > 0) {
316  CONSUME_BYTES(fi, sizeof(bst_float) * len);
317  }
318  }
319  CHECK_EQ(param.num_roots, 1)
320  << "Invalid XGBoost model file: treelite does not support trees "
321  << "with multiple roots";
322  }
323 };
324 
325 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
326  std::vector<XGBTree> xgb_trees_;
327  LearnerModelParam mparam_; // model parameter
328  GBTreeModelParam gbm_param_; // GBTree training parameter
329  std::string name_gbm_;
330  std::string name_obj_;
331 
332  /* 1. Parse input stream */
333  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
334  // backward compatible header check.
335  std::string header;
336  header.resize(4);
337  if (fp->PeekRead(&header[0], 4) == 4) {
338  CHECK_NE(header, "bs64")
339  << "Ill-formed XGBoost model file: Base64 format no longer supported";
340  if (header == "binf") {
341  CONSUME_BYTES(fp, 4);
342  }
343  }
344  // read parameter
345  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
346  << "Ill-formed XGBoost model file: corrupted header";
347  {
348  uint64_t len;
349  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
350  << "Ill-formed XGBoost model file: corrupted header";
351  if (len != 0) {
352  name_obj_.resize(len);
353  CHECK_EQ(fp->Read(&name_obj_[0], len), len)
354  << "Ill-formed XGBoost model file: corrupted header";
355  }
356  }
357 
358  {
359  uint64_t len;
360  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
361  << "Ill-formed XGBoost model file: corrupted header";
362  name_gbm_.resize(len);
363  if (len > 0) {
364  CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
365  << "Ill-formed XGBoost model file: corrupted header";
366  }
367  }
368 
369  /* loading GBTree */
370  CHECK(name_gbm_ == "gbtree" || name_gbm_ == "dart")
371  << "Invalid XGBoost model file: "
372  << "Gradient booster must be gbtree or dart type.";
373 
374  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
375  << "Invalid XGBoost model file: corrupted GBTree parameters";
376  for (int i = 0; i < gbm_param_.num_trees; ++i) {
377  xgb_trees_.emplace_back();
378  xgb_trees_.back().Load(fp.get());
379  }
380  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
381  // tree_info is currently unused.
382  std::vector<int> tree_info;
383  tree_info.resize(gbm_param_.num_trees);
384  if (gbm_param_.num_trees != 0) {
385  CHECK_EQ(fp->Read(dmlc::BeginPtr(tree_info), sizeof(int32_t) * tree_info.size()),
386  sizeof(int32_t) * tree_info.size());
387  }
388  // Load weight drop values (per tree) for dart models.
389  std::vector<bst_float> weight_drop;
390  if (name_gbm_ == "dart") {
391  weight_drop.resize(gbm_param_.num_trees);
392  if (gbm_param_.num_trees != 0) {
393  fi->Read(&weight_drop);
394  }
395  }
396 
397  /* 2. Export model */
398  std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<float, float>();
399  auto* model = dynamic_cast<treelite::ModelImpl<float, float>*>(model_ptr.get());
400  model->num_feature = static_cast<int>(mparam_.num_feature);
401  model->average_tree_output = false;
402  const int num_class = std::max(mparam_.num_class, 1);
403  if (num_class > 1) {
404  // multi-class classifier
405  model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
406  model->task_param.grove_per_class = true;
407  } else {
408  // binary classifier or regressor
409  model->task_type = treelite::TaskType::kBinaryClfRegr;
410  model->task_param.grove_per_class = false;
411  }
412  model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat;
413  model->task_param.num_class = num_class;
414  model->task_param.leaf_vector_size = 1;
415 
416  // set correct prediction transform function, depending on objective function
417  treelite::details::xgboost::SetPredTransform(name_obj_, &model->param);
418 
419  // set global bias
420  model->param.global_bias = static_cast<float>(mparam_.base_score);
421  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
422  // 1.0 it's the original value provided by user.
423  const bool need_transform_to_margin = mparam_.major_version >= 1;
424  if (need_transform_to_margin) {
425  treelite::details::xgboost::TransformGlobalBiasToMargin(&model->param);
426  }
427 
428  // traverse trees
429  for (const auto& xgb_tree : xgb_trees_) {
430  model->trees.emplace_back();
431  treelite::Tree<float, float>& tree = model->trees.back();
432  tree.Init();
433 
434  // assign node ID's so that a breadth-wise traversal would yield
435  // the monotonic sequence 0, 1, 2, ...
436  // deleted nodes will be excluded
437  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
438  Q.push({0, 0});
439  while (!Q.empty()) {
440  int old_id, new_id;
441  std::tie(old_id, new_id) = Q.front(); Q.pop();
442  const XGBTree::Node& node = xgb_tree[old_id];
443  const NodeStat stat = xgb_tree.Stat(old_id);
444  if (node.is_leaf()) {
445  bst_float leaf_value = node.leaf_value();
446  // Fold weight drop into leaf value for dart models.
447  if (!weight_drop.empty()) {
448  leaf_value *= weight_drop[model->trees.size() - 1];
449  }
450  tree.SetLeaf(new_id, static_cast<float>(leaf_value));
451  } else {
452  const bst_float split_cond = node.split_cond();
453  tree.AddChilds(new_id);
454  tree.SetNumericalSplit(new_id, node.split_index(),
455  static_cast<float>(split_cond), node.default_left(), treelite::Operator::kLT);
456  tree.SetGain(new_id, stat.loss_chg);
457  Q.push({node.cleft(), tree.LeftChild(new_id)});
458  Q.push({node.cright(), tree.RightChild(new_id)});
459  }
460  tree.SetSumHess(new_id, stat.sum_hess);
461  }
462  }
463  return model_ptr;
464 }
465 
466 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:627
model structure for tree ensemble
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:540
std::unique_ptr< treelite::Model > LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble.
Definition: xgboost.cc:29
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:560
int LeftChild(int nid) const
Getters.
Definition: tree.h:326
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:333
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:640
Helper functions for loading XGBoost models.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
Definition: tree.h:673
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:678
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:728