treelite
xgboost.cc
Go to the documentation of this file.
1 
8 #include <algorithm>
9 #include <memory>
10 #include <queue>
11 #include <cstring>
12 #include <dmlc/data.h>
13 #include <dmlc/memory_io.h>
14 #include <treelite/frontend.h>
15 #include <treelite/tree.h>
16 
17 namespace {
18 
19 treelite::Model ParseStream(dmlc::Stream* fi);
20 
21 } // anonymous namespace
22 
23 namespace treelite {
24 namespace frontend {
25 
26 DMLC_REGISTRY_FILE_TAG(xgboost);
27 
28 Model LoadXGBoostModel(const char* filename) {
29  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
30  return ParseStream(fi.get());
31 }
32 
33 Model LoadXGBoostModel(const void* buf, size_t len) {
34  dmlc::MemoryFixedSizeStream fs(const_cast<void*>(buf), len);
35  return ParseStream(&fs);
36 }
37 
38 } // namespace frontend
39 } // namespace treelite
40 
41 /* auxiliary data structures to interpret xgboost model file */
42 namespace {
43 
44 typedef float bst_float;
45 
/*! \brief helpers that map a global bias to the margin scale of an objective */
struct ProbToMargin {
  /*! \brief compute -log(1 / global_bias - 1) in single precision */
  static float Sigmoid(float global_bias) {
    const float inverse_minus_one = 1.0f / global_bias - 1.0f;
    return -logf(inverse_minus_one);
  }
  /*! \brief compute log(global_bias) in single precision */
  static float Exponential(float global_bias) {
    const float margin = logf(global_bias);
    return margin;
  }
};
54 
55 /* peekable input stream implemented with a ring buffer */
56 class PeekableInputStream {
57  public:
58  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
59 
60  explicit PeekableInputStream(dmlc::Stream* fi)
61  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
62 
63  inline size_t Read(void* ptr, size_t size) {
64  const size_t bytes_buffered = BytesBuffered();
65  char* cptr = static_cast<char*>(ptr);
66  if (size <= bytes_buffered) {
67  // all content already buffered; consume buffer
68  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
69  std::memcpy(cptr, &buf_[begin_ptr_], size);
70  begin_ptr_ += size;
71  } else {
72  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
73  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
74  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
75  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
76  }
77  return size;
78  } else { // consume buffer entirely and read more bytes
79  const size_t bytes_to_read = size - bytes_buffered;
80  if (begin_ptr_ <= end_ptr_) {
81  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
82  } else {
83  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
84  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
85  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
86  }
87  begin_ptr_ = end_ptr_;
88  return bytes_buffered
89  + istm_->Read(cptr + bytes_buffered, bytes_to_read);
90  }
91  }
92 
93  inline size_t PeekRead(void* ptr, size_t size) {
94  CHECK_LE(size, MAX_PEEK_WINDOW)
95  << "PeekableInputStream allows peeking up to "
96  << MAX_PEEK_WINDOW << " bytes";
97  char* cptr = static_cast<char*>(ptr);
98  const size_t bytes_buffered = BytesBuffered();
99  /* fill buffer with additional bytes, up to size */
100  if (size > bytes_buffered) {
101  const size_t bytes_to_read = size - bytes_buffered;
102  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
103  CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
104  << "Failed to peek " << size << " bytes";
105  end_ptr_ += bytes_to_read;
106  } else {
107  CHECK_EQ(istm_->Read(&buf_[end_ptr_],
108  MAX_PEEK_WINDOW + 1 - end_ptr_)
109  + istm_->Read(&buf_[0],
110  bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
111  bytes_to_read)
112  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
113  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
114  }
115  }
116  /* copy buffer into ptr without emptying buffer */
117  if (begin_ptr_ <= end_ptr_) { // usual case
118  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
119  } else { // context wrapped around the end
120  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
121  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
122  }
123 
124  return size;
125  }
126 
127  private:
128  dmlc::Stream* istm_;
129  std::vector<char> buf_;
130  size_t begin_ptr_, end_ptr_;
131 
132  inline size_t BytesBuffered() {
133  if (begin_ptr_ <= end_ptr_) { // usual case
134  return end_ptr_ - begin_ptr_;
135  } else { // context wrapped around the end
136  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
137  }
138  }
139 };
140 
141 template <typename T>
142 inline void CONSUME_BYTES(const T& fi, size_t size) {
143  static std::vector<char> dummy(500);
144  if (size > dummy.size()) dummy.resize(size);
145  CHECK_EQ(fi->Read(&dummy[0], size), size)
146  << "Ill-formed XGBoost model format: cannot read " << size
147  << " bytes from the file";
148 }
149 
150 struct LearnerModelParam {
151  bst_float base_score; // global bias
152  unsigned num_feature;
153  int num_class;
154  int contain_extra_attrs;
155  int contain_eval_metrics;
156  uint32_t major_version;
157  uint32_t minor_version;
158  int pad2[27];
159 };
160 static_assert(sizeof(LearnerModelParam) == 136, "This is the size defined in XGBoost.");
161 
/*!
 * \brief gradient-booster header record; ParseStream() fills it with a single
 *        raw read of sizeof(GBTreeModelParam) bytes, so the field order and
 *        sizes must not change
 */
struct GBTreeModelParam {
  int num_trees;         // number of trees that follow in the stream
  int num_roots;         // ParseStream() requires this to be 1
  int num_feature;       // not read by this file
  int pad1;              // padding
  int64_t pad2;          // padding
  int num_output_group;  // not read by this file
  int size_leaf_vector;  // not read by this file
  int pad3[32];          // padding up to the fixed record size
};
// Guard the binary layout the same way LearnerModelParam does: the struct is
// read verbatim from the model file, so an unexpected size would silently
// shift every field that follows it in the stream.
static_assert(sizeof(GBTreeModelParam) == 160,
              "GBTreeModelParam must match the 160-byte record in the model file.");
172 
/*!
 * \brief per-tree header record; XGBTree::Load() fills it with a single raw
 *        read of sizeof(TreeParam) bytes, so the field order and sizes must
 *        not change
 */
struct TreeParam {
  int num_roots;         // XGBTree::Load() requires this to be 1
  int num_nodes;         // number of Node / NodeStat records that follow
  int num_deleted;       // not read by this file
  int max_depth;         // not read by this file
  int num_feature;       // not read by this file
  int size_leaf_vector;  // non-zero means a leaf-vector section follows
  int reserved[31];      // padding up to the fixed record size
};
// Guard the binary layout the same way LearnerModelParam does: the struct is
// read verbatim from the model file.
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
              "TreeParam must match the record layout in the model file.");
182 
183 struct NodeStat {
184  bst_float loss_chg;
185  bst_float sum_hess;
186  bst_float base_weight;
187  int leaf_child_cnt;
188 };
189 
190 class XGBTree {
191  public:
192  class Node {
193  public:
194  Node() : sindex_(0) {
195  // assert compact alignment
196  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
197  "Node: 64 bit align");
198  }
199  inline int cleft() const {
200  return this->cleft_;
201  }
202  inline int cright() const {
203  return this->cright_;
204  }
205  inline int cdefault() const {
206  return this->default_left() ? this->cleft() : this->cright();
207  }
208  inline unsigned split_index() const {
209  return sindex_ & ((1U << 31) - 1U);
210  }
211  inline bool default_left() const {
212  return (sindex_ >> 31) != 0;
213  }
214  inline bool is_leaf() const {
215  return cleft_ == -1;
216  }
217  inline bst_float leaf_value() const {
218  return (this->info_).leaf_value;
219  }
220  inline bst_float split_cond() const {
221  return (this->info_).split_cond;
222  }
223  inline int parent() const {
224  return parent_ & ((1U << 31) - 1);
225  }
226  inline bool is_root() const {
227  return parent_ == -1;
228  }
229  inline void set_leaf(bst_float value) {
230  (this->info_).leaf_value = value;
231  this->cleft_ = -1;
232  this->cright_ = -1;
233  }
234  inline void set_split(unsigned split_index,
235  bst_float split_cond,
236  bool default_left = false) {
237  if (default_left) split_index |= (1U << 31);
238  this->sindex_ = split_index;
239  (this->info_).split_cond = split_cond;
240  }
241 
242  private:
243  friend class XGBTree;
244  union Info {
245  bst_float leaf_value;
246  bst_float split_cond;
247  };
248  int parent_;
249  int cleft_, cright_;
250  unsigned sindex_;
251  Info info_;
252 
253  inline bool is_deleted() const {
254  return sindex_ == std::numeric_limits<unsigned>::max();
255  }
256  inline void set_parent(int pidx, bool is_left_child = true) {
257  if (is_left_child) pidx |= (1U << 31);
258  this->parent_ = pidx;
259  }
260  };
261 
262  private:
263  TreeParam param;
264  std::vector<Node> nodes;
265  std::vector<NodeStat> stats;
266 
267  inline int AllocNode() {
268  int nd = param.num_nodes++;
269  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
270  << "number of nodes in the tree exceed 2^31";
271  nodes.resize(param.num_nodes);
272  return nd;
273  }
274 
275  public:
277  inline Node& operator[](int nid) {
278  return nodes[nid];
279  }
281  inline const Node& operator[](int nid) const {
282  return nodes[nid];
283  }
285  inline NodeStat& Stat(int nid) {
286  return stats[nid];
287  }
289  inline const NodeStat& Stat(int nid) const {
290  return stats[nid];
291  }
292  inline void Init() {
293  param.num_nodes = 1;
294  nodes.resize(1);
295  nodes[0].set_leaf(0.0f);
296  nodes[0].set_parent(-1);
297  }
298  inline void AddChilds(int nid) {
299  int pleft = this->AllocNode();
300  int pright = this->AllocNode();
301  nodes[nid].cleft_ = pleft;
302  nodes[nid].cright_ = pright;
303  nodes[nodes[nid].cleft() ].set_parent(nid, true);
304  nodes[nodes[nid].cright()].set_parent(nid, false);
305  }
306  inline void Load(PeekableInputStream* fi) {
307  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
308  << "Ill-formed XGBoost model file: can't read TreeParam";
309  nodes.resize(param.num_nodes);
310  stats.resize(param.num_nodes);
311  CHECK_NE(param.num_nodes, 0)
312  << "Ill-formed XGBoost model file: a tree can't be empty";
313  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()),
314  sizeof(Node) * nodes.size())
315  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
316  CHECK_EQ(fi->Read(dmlc::BeginPtr(stats), sizeof(NodeStat) * stats.size()),
317  sizeof(NodeStat) * stats.size())
318  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
319  if (param.size_leaf_vector != 0) {
320  uint64_t len;
321  CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
322  << "Ill-formed XGBoost model file";
323  if (len > 0) {
324  CONSUME_BYTES(fi, sizeof(bst_float) * len);
325  }
326  }
327  CHECK_EQ(param.num_roots, 1)
328  << "Invalid XGBoost model file: treelite does not support trees "
329  << "with multiple roots";
330  }
331 };
332 
333 inline treelite::Model ParseStream(dmlc::Stream* fi) {
334  std::vector<XGBTree> xgb_trees_;
335  LearnerModelParam mparam_; // model parameter
336  GBTreeModelParam gbm_param_; // GBTree training parameter
337  std::string name_gbm_;
338  std::string name_obj_;
339 
340  /* 1. Parse input stream */
341  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
342  // backward compatible header check.
343  std::string header;
344  header.resize(4);
345  if (fp->PeekRead(&header[0], 4) == 4) {
346  CHECK_NE(header, "bs64")
347  << "Ill-formed XGBoost model file: Base64 format no longer supported";
348  if (header == "binf") {
349  CONSUME_BYTES(fp, 4);
350  }
351  }
352  // read parameter
353  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
354  << "Ill-formed XGBoost model file: corrupted header";
355  {
356  uint64_t len;
357  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
358  << "Ill-formed XGBoost model file: corrupted header";
359  if (len != 0) {
360  name_obj_.resize(len);
361  CHECK_EQ(fp->Read(&name_obj_[0], len), len)
362  << "Ill-formed XGBoost model file: corrupted header";
363  }
364  }
365 
366  {
367  uint64_t len;
368  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
369  << "Ill-formed XGBoost model file: corrupted header";
370  name_gbm_.resize(len);
371  if (len > 0) {
372  CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
373  << "Ill-formed XGBoost model file: corrupted header";
374  }
375  }
376 
377  /* loading GBTree */
378  CHECK_EQ(name_gbm_, "gbtree")
379  << "Invalid XGBoost model file: "
380  << "Gradient booster must be gbtree type.";
381 
382  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
383  << "Invalid XGBoost model file: corrupted GBTree parameters";
384  for (int i = 0; i < gbm_param_.num_trees; ++i) {
385  xgb_trees_.emplace_back();
386  xgb_trees_.back().Load(fp.get());
387  }
388  CHECK_EQ(gbm_param_.num_roots, 1) << "multi-root trees not supported";
389 
390  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
391  // 1.0 it's the original value provided by user.
392  bool need_transform_to_margin = mparam_.major_version >= 1;
393 
394  /* 2. Export model */
395  treelite::Model model;
396  model.num_feature = mparam_.num_feature;
397  model.num_output_group = std::max(mparam_.num_class, 1);
398  model.random_forest_flag = false;
399 
400  // set global bias
401  model.param.global_bias = static_cast<float>(mparam_.base_score);
402  std::vector<std::string> exponential_family {
403  "count:poisson", "reg:gamma", "reg:tweedie"
404  };
405  if (need_transform_to_margin) {
406  if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
407  model.param.global_bias = ProbToMargin::Sigmoid(model.param.global_bias);
408  } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
409  != exponential_family.cend()) {
410  model.param.global_bias = ProbToMargin::Exponential(model.param.global_bias);
411  }
412  }
413 
414  // set correct prediction transform function, depending on objective function
415  if (name_obj_ == "multi:softmax") {
416  model.param.pred_transform = "max_index";
417  } else if (name_obj_ == "multi:softprob") {
418  model.param.pred_transform = "softmax";
419  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
420  model.param.pred_transform = "sigmoid";
421  model.param.sigmoid_alpha = 1.0f;
422  } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_)
423  != exponential_family.cend()) {
424  model.param.pred_transform = "exponential";
425  } else {
426  model.param.pred_transform = "identity";
427  }
428 
429  // traverse trees
430  for (const auto& xgb_tree : xgb_trees_) {
431  model.trees.emplace_back();
432  treelite::Tree& tree = model.trees.back();
433  tree.Init();
434 
435  // assign node ID's so that a breadth-wise traversal would yield
436  // the monotonic sequence 0, 1, 2, ...
437  // deleted nodes will be excluded
438  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
439  Q.push({0, 0});
440  while (!Q.empty()) {
441  int old_id, new_id;
442  std::tie(old_id, new_id) = Q.front(); Q.pop();
443  const XGBTree::Node& node = xgb_tree[old_id];
444  const NodeStat stat = xgb_tree.Stat(old_id);
445  if (node.is_leaf()) {
446  const bst_float leaf_value = node.leaf_value();
447  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
448  } else {
449  const bst_float split_cond = node.split_cond();
450  tree.AddChilds(new_id);
451  tree[new_id].set_numerical_split(node.split_index(),
452  static_cast<treelite::tl_float>(split_cond),
453  node.default_left(),
454  treelite::Operator::kLT);
455  tree[new_id].set_gain(stat.loss_chg);
456  Q.push({node.cleft(), tree[new_id].cleft()});
457  Q.push({node.cright(), tree[new_id].cright()});
458  }
459  tree[new_id].set_sum_hess(stat.sum_hess);
460  }
461  }
462  return model;
463 }
464 
465 } // anonymous namespace
int num_output_group
number of output groups – for multi-class classification Set to 1 for everything else ...
Definition: tree.h:438
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:312
thin wrapper for tree ensemble model
Definition: tree.h:428
std::vector< Tree > trees
member trees
Definition: tree.h:430
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:392
ModelParam param
extra parameters
Definition: tree.h:443
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:23
float global_bias
global bias of the model
Definition: tree.h:399
std::string pred_transform
name of prediction transform function
Definition: tree.h:384
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:28
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:441
double tl_float
float type to be used internally
Definition: base.h:17
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:322
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:435