treelite
xgboost.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <dmlc/memory_io.h>
10 #include <treelite/tree.h>
11 #include <memory>
12 #include <queue>
13 #include <cstring>
14 
15 namespace {
16 
17 treelite::Model ParseStream(dmlc::Stream* fi);
18 
19 } // namespace anonymous
20 
21 namespace treelite {
22 namespace frontend {
23 
24 DMLC_REGISTRY_FILE_TAG(xgboost);
25 
26 Model LoadXGBoostModel(const char* filename) {
27  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
28  return ParseStream(fi.get());
29 }
30 
31 Model LoadXGBoostModel(const void* buf, size_t len) {
32  dmlc::MemoryFixedSizeStream fs((void*)buf, len);
33  return ParseStream(&fs);
34 }
35 
36 } // namespace frontend
37 } // namespace treelite
38 
39 /* auxiliary data structures to interpret xgboost model file */
40 namespace {
41 
// floating-point type used for values read from the model file
using bst_float = float;
43 
44 /* peekable input stream implemented with a ring buffer */
45 class PeekableInputStream {
46  public:
47  const size_t MAX_PEEK_WINDOW = 1024; // peek up to 1024 bytes
48 
49  PeekableInputStream(dmlc::Stream* fi)
50  : istm_(fi), buf_(MAX_PEEK_WINDOW + 1), begin_ptr_(0), end_ptr_(0) {}
51 
52  inline size_t Read(void* ptr, size_t size) {
53  const size_t bytes_buffered = BytesBuffered();
54  char* cptr = static_cast<char*>(ptr);
55  if (size <= bytes_buffered) {
56  // all content already buffered; consume buffer
57  if (begin_ptr_ + size < MAX_PEEK_WINDOW + 1) {
58  std::memcpy(cptr, &buf_[begin_ptr_], size);
59  begin_ptr_ += size;
60  } else {
61  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
62  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
63  size + begin_ptr_ - MAX_PEEK_WINDOW - 1);
64  begin_ptr_ = size + begin_ptr_ - MAX_PEEK_WINDOW - 1;
65  }
66  return size;
67  } else { // consume buffer entirely and read more bytes
68  const size_t bytes_to_read = size - bytes_buffered;
69  if (begin_ptr_ <= end_ptr_) {
70  std::memcpy(cptr, &buf_[begin_ptr_], bytes_buffered);
71  } else {
72  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
73  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0],
74  bytes_buffered + begin_ptr_ - MAX_PEEK_WINDOW - 1);
75  }
76  begin_ptr_ = end_ptr_;
77  return bytes_buffered
78  + istm_->Read(cptr + bytes_buffered, bytes_to_read);
79  }
80  }
81 
82  inline size_t PeekRead(void* ptr, size_t size) {
83  CHECK_LE(size, MAX_PEEK_WINDOW)
84  << "PeekableInputStream allows peeking up to "
85  << MAX_PEEK_WINDOW << " bytes";
86  char* cptr = static_cast<char*>(ptr);
87  const size_t bytes_buffered = BytesBuffered();
88  /* fill buffer with additional bytes, up to size */
89  if (size > bytes_buffered) {
90  const size_t bytes_to_read = size - bytes_buffered;
91  if (end_ptr_ + bytes_to_read < MAX_PEEK_WINDOW + 1) {
92  CHECK_EQ(istm_->Read(&buf_[end_ptr_], bytes_to_read), bytes_to_read)
93  << "Failed to peek " << size << " bytes";
94  end_ptr_ += bytes_to_read;
95  } else {
96  CHECK_EQ( istm_->Read(&buf_[end_ptr_],
97  MAX_PEEK_WINDOW + 1 - end_ptr_)
98  + istm_->Read(&buf_[0],
99  bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1),
100  bytes_to_read)
101  << "Ill-formed XGBoost model: Failed to peek " << size << " bytes";
102  end_ptr_ = bytes_to_read + end_ptr_ - MAX_PEEK_WINDOW - 1;
103  }
104  }
105  /* copy buffer into ptr without emptying buffer */
106  if (begin_ptr_ <= end_ptr_) { // usual case
107  std::memcpy(cptr, &buf_[begin_ptr_], end_ptr_ - begin_ptr_);
108  } else { // context wrapped around the end
109  std::memcpy(cptr, &buf_[begin_ptr_], MAX_PEEK_WINDOW + 1 - begin_ptr_);
110  std::memcpy(cptr + MAX_PEEK_WINDOW + 1 - begin_ptr_, &buf_[0], end_ptr_);
111  }
112 
113  return size;
114  }
115 
116  private:
117  dmlc::Stream* istm_;
118  std::vector<char> buf_;
119  size_t begin_ptr_, end_ptr_;
120 
121  inline size_t BytesBuffered() {
122  if (begin_ptr_ <= end_ptr_) { // usual case
123  return end_ptr_ - begin_ptr_;
124  } else { // context wrapped around the end
125  return MAX_PEEK_WINDOW + 1 + end_ptr_ - begin_ptr_;
126  }
127  }
128 };
129 
/*!
 * \brief Read and discard a given number of bytes from a stream.
 *
 * The bytes are pulled through a small fixed-size local buffer, so the call
 * uses bounded memory no matter how large size is (size may come straight
 * from the model file) and keeps no state between calls.
 *
 * \param fi pointer-like handle to an object exposing Read(void*, size_t)
 * \param size number of bytes to skip
 * \tparam T type of the handle (raw or smart pointer)
 */
template <typename T>
inline void CONSUME_BYTES(const T& fi, size_t size) {
  char scratch[1024];
  size_t remaining = size;
  while (remaining > 0) {
    const size_t chunk
      = (remaining < sizeof(scratch)) ? remaining : sizeof(scratch);
    CHECK_EQ(fi->Read(scratch, chunk), chunk)
      << "Ill-formed XGBoost model format: cannot read " << size
      << " bytes from the file";
    remaining -= chunk;
  }
}
138 
139 struct LearnerModelParam {
140  bst_float base_score; // global bias
141  unsigned pad1;
142  int pad2[32];
143 };
144 
/*!
 * \brief Gradient-booster header record, read from the model file as one raw
 *        block of sizeof(GBTreeModelParam) bytes.
 *
 * This loader uses num_trees, num_feature and num_output_group; the pad
 * fields exist to keep the record the right size and are never inspected.
 */
struct GBTreeModelParam {
  int num_trees;         // number of trees stored after this record
  int pad0;              // not interpreted
  int num_feature;       // number of features
  int pad1;              // not interpreted
  int64_t pad2;          // not interpreted
  int num_output_group;  // number of output groups
  int pad3[33];          // not interpreted
};
// The record is read byte-for-byte, so compiler-inserted padding would shift
// every field that follows it in the file.
static_assert(sizeof(GBTreeModelParam)
                == 38 * sizeof(int) + sizeof(int64_t),
              "GBTreeModelParam: unexpected padding");
154 
/*!
 * \brief Per-tree header record, read from the model file as one raw block
 *        of sizeof(TreeParam) bytes.
 *
 * This loader uses num_roots, num_nodes and size_leaf_vector; the remaining
 * fields exist to keep the record the right size and are never inspected.
 */
struct TreeParam {
  int num_roots;         // number of roots; the loader requires exactly 1
  int num_nodes;         // number of node records that follow
  int num_deleted;       // not interpreted
  int pad0[2];           // not interpreted
  int size_leaf_vector;  // nonzero if a leaf vector block follows the nodes
  int reserved[31];      // not interpreted
};
// The record is read byte-for-byte, so compiler-inserted padding would shift
// every field that follows it in the file.
static_assert(sizeof(TreeParam) == 37 * sizeof(int),
              "TreeParam: unexpected padding");
163 
164 class XGBTree {
165  public:
166  class Node {
167  public:
168  Node() : sindex_(0) {
169  // assert compact alignment
170  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
171  "Node: 64 bit align");
172  }
173  inline int cleft() const {
174  return this->cleft_;
175  }
176  inline int cright() const {
177  return this->cright_;
178  }
179  inline int cdefault() const {
180  return this->default_left() ? this->cleft() : this->cright();
181  }
182  inline unsigned split_index() const {
183  return sindex_ & ((1U << 31) - 1U);
184  }
185  inline bool default_left() const {
186  return (sindex_ >> 31) != 0;
187  }
188  inline bool is_leaf() const {
189  return cleft_ == -1;
190  }
191  inline bst_float leaf_value() const {
192  return (this->info_).leaf_value;
193  }
194  inline bst_float split_cond() const {
195  return (this->info_).split_cond;
196  }
197  inline int parent() const {
198  return parent_ & ((1U << 31) - 1);
199  }
200  inline bool is_root() const {
201  return parent_ == -1;
202  }
203 
204  private:
205  friend class XGBTree;
206  union Info {
207  bst_float leaf_value;
208  bst_float split_cond;
209  };
210  int parent_;
211  int cleft_, cright_;
212  unsigned sindex_;
213  Info info_;
214 
215  inline bool is_deleted() const {
216  return sindex_ == std::numeric_limits<unsigned>::max();
217  }
218  };
219 
220  private:
221  TreeParam param;
222  std::vector<Node> nodes;
223 
224  public:
225  inline Node& operator[](int nid) {
226  return nodes[nid];
227  }
228  inline const Node& operator[](int nid) const {
229  return nodes[nid];
230  }
231  inline void Load(PeekableInputStream* fi) {
232  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
233  << "Ill-formed XGBoost model file: can't read TreeParam";
234  nodes.resize(param.num_nodes);
235  CHECK_NE(param.num_nodes, 0)
236  << "Ill-formed XGBoost model file: a tree can't be empty";
237  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes), sizeof(Node) * nodes.size()),
238  sizeof(Node) * nodes.size())
239  << "Ill-formed XGBoost model file: cannot read specified number of nodes";
240  CONSUME_BYTES(fi, (3 * sizeof(bst_float) + sizeof(int)) * param.num_nodes);
241  if (param.size_leaf_vector != 0) {
242  uint64_t len;
243  CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
244  << "Ill-formed XGBoost model file";
245  if (len > 0) {
246  CONSUME_BYTES(fi, sizeof(bst_float) * len);
247  }
248  }
249  CHECK_EQ(param.num_roots, 1)
250  << "Invalid XGBoost model file: treelite does not support trees "
251  << "with multiple roots";
252  }
253 };
254 
255 inline treelite::Model ParseStream(dmlc::Stream* fi) {
256  std::vector<XGBTree> xgb_trees_;
257  LearnerModelParam mparam_; // model parameter
258  GBTreeModelParam gbm_param_; // GBTree training parameter
259  std::string name_gbm_;
260  std::string name_obj_;
261 
262  /* 1. Parse input stream */
263  std::unique_ptr<PeekableInputStream> fp(new PeekableInputStream(fi));
264  // backward compatible header check.
265  std::string header;
266  header.resize(4);
267  if (fp->PeekRead(&header[0], 4) == 4) {
268  CHECK_NE(header, "bs64")
269  << "Ill-formed XGBoost model file: Base64 format no longer supported";
270  if (header == "binf") {
271  CONSUME_BYTES(fp, 4);
272  }
273  }
274  // read parameter
275  CHECK_EQ(fp->Read(&mparam_, sizeof(mparam_)), sizeof(mparam_))
276  << "Ill-formed XGBoost model file: corrupted header";
277  LOG(INFO) << "Global bias of the model: " << mparam_.base_score;
278  {
279  // backward compatibility code for compatible with old model type
280  // for new model, Read(&name_obj_) is suffice
281  uint64_t len;
282  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
283  << "Ill-formed XGBoost model file: corrupted header";
284  if (len >= std::numeric_limits<unsigned>::max()) {
285  int gap;
286  CHECK_EQ(fp->Read(&gap, sizeof(gap)), sizeof(gap))
287  << "Ill-formed XGBoost model file: corrupted header";
288  len = len >> static_cast<uint64_t>(32UL);
289  }
290  if (len != 0) {
291  name_obj_.resize(len);
292  CHECK_EQ(fp->Read(&name_obj_[0], len), len)
293  << "Ill-formed XGBoost model file: corrupted header";
294  }
295  }
296 
297  {
298  uint64_t len;
299  CHECK_EQ(fp->Read(&len, sizeof(len)), sizeof(len))
300  << "Ill-formed XGBoost model file: corrupted header";
301  name_gbm_.resize(len);
302  if (len > 0) {
303  CHECK_EQ(fp->Read(&name_gbm_[0], len), len)
304  << "Ill-formed XGBoost model file: corrupted header";
305  }
306  }
307 
308  /* loading GBTree */
309  CHECK_EQ(name_gbm_, "gbtree")
310  << "Invalid XGBoost model file: "
311  << "Gradient booster must be gbtree type.";
312 
313  CHECK_EQ(fp->Read(&gbm_param_, sizeof(gbm_param_)), sizeof(gbm_param_))
314  << "Invalid XGBoost model file: corrupted GBTree parameters";
315  LOG(INFO) << "gbm_param_.num_feature = " << gbm_param_.num_feature;
316  LOG(INFO) << "gbm_param_.num_output_group = " << gbm_param_.num_output_group;
317  for (int i = 0; i < gbm_param_.num_trees; ++i) {
318  xgb_trees_.emplace_back();
319  xgb_trees_.back().Load(fp.get());
320  }
321 
322  /* 2. Export model */
323  treelite::Model model;
324  model.num_feature = gbm_param_.num_feature;
325  model.num_output_group = gbm_param_.num_output_group;
326  model.random_forest_flag = false;
327 
328  // set global bias
329  model.param.global_bias = static_cast<float>(mparam_.base_score);
330 
331  // set correct prediction transform function, depending on objective function
332  if (name_obj_ == "multi:softmax") {
333  model.param.pred_transform = "max_index";
334  } else if (name_obj_ == "multi:softprob") {
335  model.param.pred_transform = "softmax";
336  } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") {
337  model.param.pred_transform = "sigmoid";
338  model.param.sigmoid_alpha = 1.0f;
339  } else if (name_obj_ == "count:poisson" || name_obj_ == "reg:gamma"
340  || name_obj_ == "reg:tweedie") {
341  model.param.pred_transform = "exponential";
342  } else {
343  model.param.pred_transform = "identity";
344  }
345 
346  // traverse trees
347  for (const auto& xgb_tree : xgb_trees_) {
348  model.trees.emplace_back();
349  treelite::Tree& tree = model.trees.back();
350  tree.Init();
351 
352  // assign node ID's so that a breadth-wise traversal would yield
353  // the monotonic sequence 0, 1, 2, ...
354  // deleted nodes will be excluded
355  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
356  Q.push({0, 0});
357  while (!Q.empty()) {
358  int old_id, new_id;
359  std::tie(old_id, new_id) = Q.front(); Q.pop();
360  const XGBTree::Node& node = xgb_tree[old_id];
361  if (node.is_leaf()) {
362  const bst_float leaf_value = node.leaf_value();
363  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
364  } else {
365  const bst_float split_cond = node.split_cond();
366  tree.AddChilds(new_id);
367  tree[new_id].set_numerical_split(node.split_index(),
368  static_cast<treelite::tl_float>(split_cond),
369  node.default_left(),
370  treelite::Operator::kLT);
371  Q.push({node.cleft(), tree[new_id].cleft()});
372  Q.push({node.cright(), tree[new_id].cright()});
373  }
374  }
375  }
376  return model;
377 }
378 
379 } // namespace anonymous
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
Definition: tree.h:360
void Init()
initialize the model with a single root node
Definition: tree.h:234
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
std::vector< Tree > trees
member trees
Definition: tree.h:352
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:314
ModelParam param
extra parameters
Definition: tree.h:365
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:19
float global_bias
global bias of the model
Definition: tree.h:321
std::string pred_transform
name of prediction transform function
Definition: tree.h:306
Model LoadXGBoostModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost.cc:26
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
Definition: tree.h:363
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:244
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:357