treelite
xgboost.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <dmlc/memory_io.h>
10 #include <treelite/frontend.h>
11 #include <treelite/tree.h>
12 #include <memory>
13 #include <queue>
14 #include <cstring>
15 
16 namespace {
17 
18 treelite::Model ParseStream(dmlc::Stream* fi);
19 
20 } // anonymous namespace
21 
22 namespace treelite {
23 namespace frontend {
24 
25 DMLC_REGISTRY_FILE_TAG(xgboost);
26 
27 Model LoadXGBoostModel(const char* filename) {
28  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
29  return ParseStream(fi.get());
30 }
31 
32 Model LoadXGBoostModel(const void* buf, size_t len) {
33  dmlc::MemoryFixedSizeStream fs(const_cast<void*>(buf), len);
34  return ParseStream(&fs);
35 }
36 
37 } // namespace frontend
38 } // namespace treelite
39 
40 /* auxiliary data structures to interpret xgboost model file */
41 namespace {
42 
43 typedef float bst_float;
44 
45 /* peekable input stream implemented with a ring buffer */
46 class PeekableInputStream {
47  public:
48  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
49 
50  explicit PeekableInputStream(dmlc::Stream* fi)
51  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
52 
53  inline size_t Read(void* ptr, size_t size) {
54  const size_t bytes_buffered = BytesBuffered();
55  char* cptr = static_cast<char*>(ptr);
56  if (size <= bytes_buffered) {
57  // all content already buffered; consume buffer
58  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
59  std::memcpy(cptr, &buf_[begin_ptr_], size);
60  begin_ptr_ += size;
61  } else {
62  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
63  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
64  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
65  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
66  }
67  return size;
68  } else { // consume buffer entirely and read more bytes
69  const size_t bytes_to_read = size - bytes_buffered;
70  if (begin_ptr_ <= end_ptr_) {
71  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
72  } else {
73  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
74  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
75  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
76  }
77  begin_ptr_ = end_ptr_;
78  return bytes_buffered
79  + istm_->Read(cptr + bytes_buffered, bytes_to_read);
80  }
81  }
82 
83  inline size_t PeekRead(void* ptr, size_t size) {
84  CHECK_LE(size, MAX_PEEK_WINDOW)
85  << "PeekableInputStream allows peeking up to "
86  << MAX_PEEK_WINDOW << " bytes";
87  char* cptr = static_cast<char*>(ptr);
88  const size_t bytes_buffered = BytesBuffered();
89  /* fill buffer with additional bytes, up to size */
90  if (size > bytes_buffered) {
91  const size_t bytes_to_read = size - bytes_buffered;
92  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
93  CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
94  << "Failed to peek " << size << " bytes";
95  end_ptr_ += bytes_to_read;
96  } else {
97  CHECK_EQ(istm_->Read(&buf_[end_ptr_],
98  MAX_PEEK_WINDOW + 1 - end_ptr_)
99  + istm_->Read(&buf_[0],
100  bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
101  bytes_to_read)
102  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
103  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
104  }
105  }
106  /* copy buffer into ptr without emptying buffer */
107  if (begin_ptr_ <= end_ptr_) { // usual case
108  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
109  } else { // context wrapped around the end
110  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
111  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
112  }
113 
114  return size;
115  }
116 
117  private:
118  dmlc::Stream* istm_;
119  std::vector<char> buf_;
120  size_t begin_ptr_, end_ptr_;
121 
122  inline size_t BytesBuffered() {
123  if (begin_ptr_ <= end_ptr_) { // usual case
124  return end_ptr_ - begin_ptr_;
125  } else { // context wrapped around the end
126  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
127  }
128  }
129 };
130 
/*!
 * \brief read and discard a given number of bytes from a stream
 *
 * The bytes are read in chunks into a fixed-size stack buffer, so the function
 * keeps no shared state between calls and its memory use does not depend on
 * the (file-controlled) byte count.
 *
 * \param fi pointer-like handle whose Read(void*, size_t) returns the number
 *           of bytes read
 * \param size number of bytes to skip
 * \tparam T type of the stream handle
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[512];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk
      = (remaining < sizeof(scratch)) ? remaining : sizeof(scratch);
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
139 
140 struct LearnerModelParam {
141  bst_float base_score; // global bias
142  unsigned num_feature;
143  int num_class;
144  int contain_extra_attrs;
145  int contain_eval_metrics;
146  int pad2[29];
147 };
148 
/*!
 * \brief gradient-booster header record, read verbatim from the model file
 *
 * ParseStream() fills this struct with a single raw read of
 * sizeof(GBTreeModelParam) bytes, so the field order and the total size must
 * not change.
 */
struct GBTreeModelParam {
  int num_trees;         // number of serialized trees that follow the header
  int num_roots;         // ParseStream() requires this to be 1
  int num_feature;       // copied into Model::num_feature
  int pad1;              // padding; never read by this file
  int64_t pad2;          // padding; never read by this file
  int num_output_group;  // copied into Model::num_output_group
  int size_leaf_vector;
  int pad3[32];          // padding; never read by this file
};
// The struct is read as raw bytes; fail the build if the compiler pads it.
static_assert(sizeof(GBTreeModelParam)
                == 38 * sizeof(int) + sizeof(int64_t),
              "GBTreeModelParam must be tightly packed");
159 
/*!
 * \brief per-tree header record, read verbatim from the model file
 *
 * XGBTree::Load() fills this struct with a single raw read of
 * sizeof(TreeParam) bytes, so the field order and the total size must not
 * change.
 */
struct TreeParam {
  int num_roots;         // XGBTree::Load() requires this to be 1
  int num_nodes;         // number of Node / NodeStat records that follow
  int num_deleted;
  int max_depth;
  int num_feature;
  int size_leaf_vector;  // when nonzero, a leaf-vector section follows the stats
  int reserved[31];      // padding; never read by this file
};
// The struct is read as raw bytes; fail the build if the compiler pads it.
static_assert(sizeof(TreeParam) == 37 * sizeof(int),
              "TreeParam must be tightly packed");
169 
170 struct NodeStat {
171  bst_float loss_chg;
172  bst_float sum_hess;
173  bst_float base_weight;
174  int leaf_child_cnt;
175 };
176 
177 class XGBTree {
178  public:
179  class Node {
180  public:
181  Node() : sindex_(0) {
182  // assert compact alignment
183  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
184  "Node: 64 bit align");
185  }
186  inline int cleft() const {
187  return this->cleft_;
188  }
189  inline int cright() const {
190  return this->cright_;
191  }
192  inline int cdefault() const {
193  return this->default_left() ? this->cleft() : this->cright();
194  }
195  inline unsigned split_index() const {
196  return sindex_ & ((1U << 31) - 1U);
197  }
198  inline bool default_left() const {
199  return (sindex_ >> 31) != 0;
200  }
201  inline bool is_leaf() const {
202  return cleft_ == -1;
203  }
204  inline bst_float leaf_value() const {
205  return (this->info_).leaf_value;
206  }
207  inline bst_float split_cond() const {
208  return (this->info_).split_cond;
209  }
210  inline int parent() const {
211  return parent_ & ((1U << 31) - 1);
212  }
213  inline bool is_root() const {
214  return parent_ == -1;
215  }
216  inline void set_leaf(bst_float value) {
217  (this->info_).leaf_value = value;
218  this->cleft_ = -1;
219  this->cright_ = -1;
220  }
221  inline void set_split(unsigned split_index,
222  bst_float split_cond,
223  bool default_left = false) {
224  if (default_left) split_index |= (1U << 31);
225  this->sindex_ = split_index;
226  (this->info_).split_cond = split_cond;
227  }
228 
229  private:
230  friend class XGBTree;
231  union Info {
232  bst_float leaf_value;
233  bst_float split_cond;
234  };
235  int parent_;
236  int cleft_, cright_;
237  unsigned sindex_;
238  Info info_;
239 
240  inline bool is_deleted() const {
241  return sindex_ == std::numeric_limits<unsigned>::max();
242  }
243  inline void set_parent(int pidx, bool is_left_child = true) {
244  if (is_left_child) pidx |= (1U << 31);
245  this->parent_ = pidx;
246  }
247  };
248 
249  private:
250  TreeParam param;
251  std::vector<Node> nodes;
252  std::vector<NodeStat> stats;
253 
254  inline int AllocNode() {
255  int nd = param.num_nodes++;
256  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
257  << "number of nodes in the tree exceed 2^31";
258  nodes.resize(param.num_nodes);
259  return nd;
260  }
261 
262  public:
264  inline Node& operator[](int nid) {
265  return nodes[nid];
266  }
268  inline const Node& operator[](int nid) const {
269  return nodes[nid];
270  }
272  inline NodeStat& Stat(int nid) {
273  return stats[nid];
274  }
276  inline const NodeStat& Stat(int nid) const {
277  return stats[nid];
278  }
279  inline void Init() {
280  param.num_nodes = 1;
281  nodes.resize(1);
282  nodes[0].set_leaf(0.0f);
283  nodes[0].set_parent(-1);
284  }
285  inline void AddChilds(int nid) {
286  int pleft = this->AllocNode();
287  int pright = this->AllocNode();
288  nodes[nid].cleft_ = pleft;
289  nodes[nid].cright_ = pright;
290  nodes[nodes[nid].cleft() ].set_parent(nid, true);
291  nodes[nodes[nid].cright()].set_parent(nid, false);
292  }
293  inline void Load(PeekableInputStream* fi) {
294  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
295  << "Ill-formed XGBoost model file: can't read TreeParam";
296  nodes.resize(param.num_nodes);
297  stats.resize(param.num_nodes);
298  CHECK_NE(param.num_nodes, 0)
299  << "Ill-formed XGBoost model file: a tree can't be empty";
300  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()),
301  sizeof(Node) * nodes.size())
302  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
303  CHECK_EQ(fi->Read(dmlc::BeginPtr(stats), sizeof(NodeStat) * stats.size()),
304  sizeof(NodeStat) * stats.size())
305  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
306  if (param.size_leaf_vector != 0) {
307  uint64_t len;
308  CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
309  << "Ill-formed XGBoost model file";
310  if (len > 0) {
311  CONSUME_BYTES(fi, sizeof(bst_float) * len);
312  }
313  }
314  CHECK_EQ(param.num_roots, 1)
315  << "Invalid XGBoost model file: treelite does not support trees "
316  << "with multiple roots";
317  }
318 };
319 
320 inline treelite::Model ParseStream(dmlc::Stream* fi) {
321  std::vector<XGBTree> xgb_trees_;
322  LearnerModelParam mparam_; // model parameter
323  GBTreeModelParam gbm_param_; // GBTree training parameter
324  std::string name_gbm_;
325  std::string name_obj_;
326 
327  /* 1. Parse input stream */
328  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
329  // backward compatible header check.
330  std::string header;
331  header.resize(4);
332  if (fp->PeekRead(&header[0], 4) == 4) {
333  CHECK_NE(header, "bs64")
334  << "Ill-formed XGBoost model file: Base64 format no longer supported";
335  if (header == "binf") {
336  CONSUME_BYTES(fp, 4);
337  }
338  }
339  // read parameter
340  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
341  << "Ill-formed XGBoost model file: corrupted header";
342  {
343  // backward compatibility code for compatible with old model type
344  // for new model, Read(&name_obj_) is suffice
345  uint64_t len;
346  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
347  << "Ill-formed XGBoost model file: corrupted header";
348  if (len >= std::numeric_limits<unsigned>::max()) {
349  int gap;
350  CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
351  << "Ill-formed XGBoost model file: corrupted header";
352  len = len >> static_cast<uint64_t>(32UL);
353  }
354  if (len != 0) {
355  name_obj_.resize(len);
356  CHECK_EQ(fp->Read(&name_obj_[0], len), len)
357  << "Ill-formed XGBoost model file: corrupted header";
358  }
359  }
360 
361  {
362  uint64_t len;
363  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
364  << "Ill-formed XGBoost model file: corrupted header";
365  name_gbm_.resize(len);
366  if (len > 0) {
367  CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
368  << "Ill-formed XGBoost model file: corrupted header";
369  }
370  }
371 
372  /* loading GBTree */
373  CHECK_EQ(name_gbm_, "gbtree")
374  << "Invalid XGBoost model file: "
375  << "Gradient booster must be gbtree type.";
376 
377  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
378  << "Invalid XGBoost model file: corrupted GBTree parameters";
379  for (int i = 0; i < gbm_param_.num_trees; ++i) {
380  xgb_trees_.emplace_back();
381  xgb_trees_.back().Load(fp.get());
382  }
383  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
384 
385  /* 2. Export model */
386  treelite::Model model;
387  model.num_feature = gbm_param_.num_feature;
388  model.num_output_group = gbm_param_.num_output_group;
389  model.random_forest_flag = false;
390 
391  // set global bias
392  model.param.global_bias = static_cast<float>(mparam_.base_score);
393 
394  // set correct prediction transform function, depending on objective function
395  if (name_obj_ == "multi:softmax") {
396  model.param.pred_transform = "max_index";
397  } else if (name_obj_ == "multi:softprob") {
398  model.param.pred_transform = "softmax";
399  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
400  model.param.pred_transform = "sigmoid";
401  model.param.sigmoid_alpha = 1.0f;
402  } else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
403  || name_obj_ == "reg:tweedie") {
404  model.param.pred_transform = "exponential";
405  } else {
406  model.param.pred_transform = "identity";
407  }
408 
409  // traverse trees
410  for (const auto& xgb_tree : xgb_trees_) {
411  model.trees.emplace_back();
412  treelite::Tree& tree = model.trees.back();
413  tree.Init();
414 
415  // assign node ID's so that a breadth-wise traversal would yield
416  // the monotonic sequence 0, 1, 2, ...
417  // deleted nodes will be excluded
418  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
419  Q.push({0, 0});
420  while (!Q.empty()) {
421  int old_id, new_id;
422  std::tie(old_id, new_id) = Q.front(); Q.pop();
423  const XGBTree::Node& node = xgb_tree[old_id];
424  const NodeStat stat = xgb_tree.Stat(old_id);
425  if (node.is_leaf()) {
426  const bst_float leaf_value = node.leaf_value();
427  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
428  } else {
429  const bst_float split_cond = node.split_cond();
430  tree.AddChilds(new_id);
431  tree[new_id].set_numerical_split(node.split_index(),
432  static_cast<treelite::tl_float>(split_cond),
433  node.default_left(),
434  treelite::Operator::kLT);
435  tree[new_id].set_gain(stat.loss_chg);
436  Q.push({node.cleft(), tree[new_id].cleft()});
437  Q.push({node.cright(), tree[new_id].cright()});
438  }
439  tree[new_id].set_sum_hess(stat.sum_hess);
440  }
441  }
442  return model;
443 }
444 
445 } // anonymous namespace
int num_output_group
number of output groups – used for multi-class classification; set to 1 for everything else ...
Definition: tree.h:438
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:312
thin wrapper for tree ensemble model
Definition: tree.h:428
std::vector< Tree > trees
member trees
Definition: tree.h:430
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:392
ModelParam param
extra parameters
Definition: tree.h:443
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:23
float global_bias
global bias of the model
Definition: tree.h:399
std::string pred_transform
name of prediction transform function
Definition: tree.h:384
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:27
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:441
double tl_float
float type to be used internally
Definition: base.h:17
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:322
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435