Treelite
xgboost_json.cc
Go to the documentation of this file.
1 
9 #include "xgboost/xgboost_json.h"
10 
11 #include <dmlc/registry.h>
12 #include <dmlc/io.h>
13 #include <fmt/format.h>
14 #include <rapidjson/error/en.h>
15 #include <rapidjson/document.h>
16 #include <rapidjson/filereadstream.h>
17 #include <treelite/tree.h>
18 #include <treelite/frontend.h>
19 #include <treelite/math.h>
20 
21 #include <algorithm>
22 #include <cstdio>
23 #include <cstdlib>
24 #include <iostream>
25 #include <memory>
26 #include <queue>
27 #include <string>
28 #include <utility>
29 
30 #include "xgboost/xgboost.h"
31 
32 namespace {
33 
34 template <typename StreamType, typename ErrorHandlerFunc>
35 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
36  ErrorHandlerFunc error_handler);
37 
38 } // anonymous namespace
39 
40 namespace treelite {
41 namespace frontend {
42 
43 DMLC_REGISTRY_FILE_TAG(xgboost_json);
44 
45 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename) {
46  char read_buffer[65536];
47 
48 #ifdef _WIN32
49  FILE* fp = std::fopen(filename, "rb");
50 #else
51  FILE* fp = std::fopen(filename, "r");
52 #endif
53  if (!fp) {
54  LOG(FATAL) << "Failed to open file '" << filename << "': " << std::strerror(errno);
55  }
56 
57  auto input_stream = std::make_unique<rapidjson::FileReadStream>(
58  fp, read_buffer, sizeof(read_buffer));
59  auto error_handler = [fp](size_t offset) -> std::string {
60  size_t cur = (offset >= 50 ? (offset - 50) : 0);
61  std::fseek(fp, cur, SEEK_SET);
62  int c;
63  std::ostringstream oss, oss2;
64  for (int i = 0; i < 100; ++i) {
65  c = std::fgetc(fp);
66  if (c == EOF) {
67  break;
68  }
69  oss << static_cast<char>(c);
70  if (cur == offset) {
71  oss2 << "^";
72  } else {
73  oss2 << "~";
74  }
75  ++cur;
76  }
77  std::fclose(fp);
78  return oss.str() + "\n" + oss2.str();
79  };
80  auto parsed_model = ParseStream(std::move(input_stream), error_handler);
81  std::fclose(fp);
82  return parsed_model;
83 }
84 
85 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length) {
86  auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
87  auto error_handler = [json_str](size_t offset) -> std::string {
88  size_t cur = (offset >= 50 ? (offset - 50) : 0);
89  std::ostringstream oss, oss2;
90  for (int i = 0; i < 100; ++i) {
91  if (!json_str[cur]) {
92  break;
93  }
94  oss << json_str[cur];
95  if (cur == offset) {
96  oss2 << "^";
97  } else {
98  oss2 << "~";
99  }
100  ++cur;
101  }
102  return oss.str() + "\n" + oss2.str();
103  };
104  return ParseStream(std::move(input_stream), error_handler);
105 }
106 
107 } // namespace frontend
108 
109 namespace details {
110 
111 /******************************************************************************
112  * BaseHandler
113  * ***************************************************************************/
114 
115 bool BaseHandler::pop_handler() {
116  if (auto parent = delegator.lock()) {
117  parent->pop_delegate();
118  return true;
119  } else {
120  return false;
121  }
122 }
123 
124 void BaseHandler::set_cur_key(const char *str, std::size_t length) {
125  cur_key = std::string{str, length};
126 }
127 
128 const std::string &BaseHandler::get_cur_key() { return cur_key; }
129 
130 bool BaseHandler::check_cur_key(const std::string &query_key) {
131  return cur_key == query_key;
132 }
133 
134 template <typename ValueType>
135 bool BaseHandler::assign_value(const std::string &key,
136  ValueType &&value,
137  ValueType &output) {
138  if (check_cur_key(key)) {
139  output = value;
140  return true;
141  } else {
142  return false;
143  }
144 }
145 
146 template <typename ValueType>
147 bool BaseHandler::assign_value(const std::string &key,
148  const ValueType &value,
149  ValueType &output) {
150  if (check_cur_key(key)) {
151  output = value;
152  return true;
153  } else {
154  return false;
155  }
156 }
157 
158 /******************************************************************************
159  * IgnoreHandler
160  * ***************************************************************************/
161 bool IgnoreHandler::Null() { return true; }
162 bool IgnoreHandler::Bool(bool) { return true; }
163 bool IgnoreHandler::Int(int) { return true; }
164 bool IgnoreHandler::Uint(unsigned) { return true; }
165 bool IgnoreHandler::Int64(int64_t) { return true; }
166 bool IgnoreHandler::Uint64(uint64_t) { return true; }
167 bool IgnoreHandler::Double(double) { return true; }
168 bool IgnoreHandler::String(const char *, std::size_t, bool) {
169  return true; }
170 bool IgnoreHandler::StartObject() { return push_handler<IgnoreHandler>(); }
171 bool IgnoreHandler::Key(const char *, std::size_t, bool) {
172  return true; }
173 bool IgnoreHandler::StartArray() { return push_handler<IgnoreHandler>(); }
174 
175 /******************************************************************************
176  * TreeParamHandler
177  * ***************************************************************************/
178 bool TreeParamHandler::String(const char *str, std::size_t, bool) {
179  // Key "num_deleted" deprecated but still present in some xgboost output
180  return (check_cur_key("num_feature") ||
181  assign_value("num_nodes", std::atoi(str), output) ||
182  check_cur_key("size_leaf_vector") || check_cur_key("num_deleted"));
183 }
184 
185 /******************************************************************************
186  * RegTreeHandler
187  * ***************************************************************************/
188 bool RegTreeHandler::StartArray() {
189  /* Keys "categories" and "split_type" not currently documented in schema but
190  * will be used for upcoming categorical split feature */
191  return (
192  push_key_handler<ArrayHandler<double>>("loss_changes", loss_changes) ||
193  push_key_handler<ArrayHandler<double>>("sum_hessian", sum_hessian) ||
194  push_key_handler<ArrayHandler<double>>("base_weights", base_weights) ||
195  push_key_handler<ArrayHandler<int>>("categories_segments", categories_segments) ||
196  push_key_handler<ArrayHandler<int>>("categories_sizes", categories_sizes) ||
197  push_key_handler<ArrayHandler<int>>("categories_nodes", categories_nodes) ||
198  push_key_handler<ArrayHandler<int>>("categories", categories) ||
199  push_key_handler<IgnoreHandler>("leaf_child_counts") ||
200  push_key_handler<ArrayHandler<int>>("left_children", left_children) ||
201  push_key_handler<ArrayHandler<int>>("right_children", right_children) ||
202  push_key_handler<ArrayHandler<int>>("parents", parents) ||
203  push_key_handler<ArrayHandler<int>>("split_indices", split_indices) ||
204  push_key_handler<ArrayHandler<int>>("split_type", split_type) ||
205  push_key_handler<ArrayHandler<double>>("split_conditions", split_conditions) ||
206  push_key_handler<ArrayHandler<bool>>("default_left", default_left));
207 }
208 
209 bool RegTreeHandler::StartObject() {
210  return push_key_handler<TreeParamHandler, int>("tree_param", num_nodes);
211 }
212 
213 bool RegTreeHandler::Uint(unsigned) { return check_cur_key("id"); }
214 
215 bool RegTreeHandler::EndObject(std::size_t) {
216  output.Init();
217  if (split_type.empty()) {
218  split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
219  }
220  if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
221  LOG(ERROR) << "Field loss_changes has an incorrect dimension. Expected: " << num_nodes
222  << ", Actual: " << loss_changes.size();
223  return false;
224  }
225  if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
226  LOG(ERROR) << "Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
227  << ", Actual: " << sum_hessian.size();
228  return false;
229  }
230  if (static_cast<size_t>(num_nodes) != base_weights.size()) {
231  LOG(ERROR) << "Field base_weights has an incorrect dimension. Expected: " << num_nodes
232  << ", Actual: " << base_weights.size();
233  return false;
234  }
235  if (static_cast<size_t>(num_nodes) != left_children.size()) {
236  LOG(ERROR) << "Field left_children has an incorrect dimension. Expected: " << num_nodes
237  << ", Actual: " << left_children.size();
238  return false;
239  }
240  if (static_cast<size_t>(num_nodes) != right_children.size()) {
241  LOG(ERROR) << "Field right_children has an incorrect dimension. Expected: " << num_nodes
242  << ", Actual: " << right_children.size();
243  return false;
244  }
245  if (static_cast<size_t>(num_nodes) != parents.size()) {
246  LOG(ERROR) << "Field parents has an incorrect dimension. Expected: " << num_nodes
247  << ", Actual: " << parents.size();
248  return false;
249  }
250  if (static_cast<size_t>(num_nodes) != split_indices.size()) {
251  LOG(ERROR) << "Field split_indices has an incorrect dimension. Expected: " << num_nodes
252  << ", Actual: " << split_indices.size();
253  return false;
254  }
255  if (static_cast<size_t>(num_nodes) != split_type.size()) {
256  LOG(ERROR) << "Field split_type has an incorrect dimension. Expected: " << num_nodes
257  << ", Actual: " << split_type.size();
258  return false;
259  }
260  if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
261  LOG(ERROR) << "Field split_conditions has an incorrect dimension. Expected: " << num_nodes
262  << ", Actual: " << split_conditions.size();
263  return false;
264  }
265  if (static_cast<size_t>(num_nodes) != default_left.size()) {
266  LOG(ERROR) << "Field default_left has an incorrect dimension. Expected: " << num_nodes
267  << ", Actual: " << default_left.size();
268  return false;
269  }
270 
271  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
272  if (num_nodes > 0) {
273  Q.push({0, 0});
274  }
275  while (!Q.empty()) {
276  int old_id, new_id;
277  std::tie(old_id, new_id) = Q.front();
278  Q.pop();
279 
280  if (left_children[old_id] == -1) {
281  output.SetLeaf(new_id, split_conditions[old_id]);
282  } else {
283  output.AddChilds(new_id);
284  if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
285  auto categorical_split_loc
286  = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
287  CHECK(categorical_split_loc != categories_nodes.end())
288  << "Could not find record for the categorical split in node " << old_id;
289  auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
290  int offset = categories_segments[categorical_split_id];
291  int num_categories = categories_sizes[categorical_split_id];
292  std::vector<uint32_t> right_categories;
293  right_categories.reserve(num_categories);
294  for (int i = 0; i < num_categories; ++i) {
295  right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
296  }
297  output.SetCategoricalSplit(
298  new_id, split_indices[old_id], default_left[old_id], right_categories, true);
299  } else {
300  output.SetNumericalSplit(
301  new_id, split_indices[old_id], split_conditions[old_id],
302  default_left[old_id], treelite::Operator::kLT);
303  }
304  output.SetGain(new_id, loss_changes[old_id]);
305  Q.push({left_children[old_id], output.LeftChild(new_id)});
306  Q.push({right_children[old_id], output.RightChild(new_id)});
307  }
308  output.SetSumHess(new_id, sum_hessian[old_id]);
309  }
310  return pop_handler();
311 }
312 
313 /******************************************************************************
314  * GBTreeHandler
315  * ***************************************************************************/
316 bool GBTreeModelHandler::StartArray() {
317  return (push_key_handler<ArrayHandler<treelite::Tree<float, float>, RegTreeHandler>,
318  std::vector<treelite::Tree<float, float>>>(
319  "trees", output.trees) ||
320  push_key_handler<IgnoreHandler>("tree_info"));
321 }
322 
323 bool GBTreeModelHandler::StartObject() {
324  return push_key_handler<IgnoreHandler>("gbtree_model_param");
325 }
326 
327 /******************************************************************************
328  * GradientBoosterHandler
329  * ***************************************************************************/
330 bool GradientBoosterHandler::String(const char *str,
331  std::size_t length,
332  bool) {
333  if (assign_value("name", std::string{str, length}, name)) {
334  if (name == "gbtree" || name == "dart") {
335  return true;
336  } else {
337  LOG(ERROR) << "Only GBTree or DART boosters are currently supported.";
338  return false;
339  }
340  } else {
341  return false;
342  }
343 }
344 bool GradientBoosterHandler::StartObject() {
345  if (push_key_handler<GBTreeModelHandler, treelite::ModelImpl<float, float>>("model", output)) {
346  return true;
347  } else if (push_key_handler<GradientBoosterHandler, treelite::ModelImpl<float, float>>("gbtree",
348  output)) {
349  // "dart" booster contains a standard gbtree under ["gradient_booster"]["gbtree"]["model"].
350  return true;
351  } else {
352  LOG(ERROR) << "Key \"" << get_cur_key()
353  << "\" not recognized. Is this a GBTree-type booster?";
354  return false;
355  }
356 }
357 bool GradientBoosterHandler::StartArray() {
358  return push_key_handler<ArrayHandler<double>, std::vector<double>>("weight_drop", weight_drop);
359 }
360 bool GradientBoosterHandler::EndObject(std::size_t memberCount) {
361  if (name == "dart" && !weight_drop.empty()) {
362  // Fold weight drop into leaf value for dart models.
363  CHECK_EQ(output.trees.size(), weight_drop.size());
364  for (size_t i = 0; i < output.trees.size(); ++i) {
365  for (int nid = 0; nid < output.trees[i].num_nodes; ++nid) {
366  if (output.trees[i].IsLeaf(nid)) {
367  output.trees[i].SetLeaf(nid, weight_drop[i] * output.trees[i].LeafValue(nid));
368  }
369  }
370  }
371  }
372  return pop_handler();
373 }
374 
375 /******************************************************************************
376  * ObjectiveHandler
377  * ***************************************************************************/
378 bool ObjectiveHandler::StartObject() {
379  return (push_key_handler<IgnoreHandler>("reg_loss_param") ||
380  push_key_handler<IgnoreHandler>("poisson_regression_param") ||
381  push_key_handler<IgnoreHandler>("tweedie_regression_param") ||
382  push_key_handler<IgnoreHandler>("softmax_multiclass_param") ||
383  push_key_handler<IgnoreHandler>("lambda_rank_param") ||
384  push_key_handler<IgnoreHandler>("aft_loss_param"));
385 }
386 
387 bool ObjectiveHandler::String(const char *str, std::size_t length, bool) {
388  return assign_value("name", std::string{str, length}, output);
389 }
390 
391 /******************************************************************************
392  * LearnerParamHandler
393  * ***************************************************************************/
394 bool LearnerParamHandler::String(const char *str,
395  std::size_t,
396  bool) {
397  return (assign_value("base_score", strtof(str, nullptr),
398  output.param.global_bias) ||
399  assign_value("num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
400  output.task_param.num_class) ||
401  assign_value("num_feature", std::atoi(str), output.num_feature));
402 }
403 
404 /******************************************************************************
405  * LearnerHandler
406  * ***************************************************************************/
407 bool LearnerHandler::StartObject() {
408  // "attributes" key is not documented in schema
409  return (push_key_handler<LearnerParamHandler, treelite::ModelImpl<float, float>>(
410  "learner_model_param", *output.model) ||
411  push_key_handler<GradientBoosterHandler, treelite::ModelImpl<float, float>>(
412  "gradient_booster", *output.model) ||
413  push_key_handler<ObjectiveHandler, std::string>("objective", objective) ||
414  push_key_handler<IgnoreHandler>("attributes"));
415 }
416 
417 bool LearnerHandler::EndObject(std::size_t) {
418  xgboost::SetPredTransform(objective, &output.model->param);
419  output.objective_name = objective;
420  return pop_handler();
421 }
422 
423 bool LearnerHandler::StartArray() {
424  return (push_key_handler<IgnoreHandler>("feature_names") ||
425  push_key_handler<IgnoreHandler>("feature_types"));
426 }
427 
428 /******************************************************************************
429  * XGBoostModelHandler
430  * ***************************************************************************/
431 bool XGBoostModelHandler::StartArray() {
432  return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
433  "version", version);
434 }
435 
436 bool XGBoostModelHandler::StartObject() {
437  return push_key_handler<LearnerHandler, XGBoostModelHandle>("learner", output);
438 }
439 
440 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
441  if (memberCount != 2) {
442  LOG(ERROR) << "Expected two members in XGBoostModel";
443  return false;
444  }
445  output.model->average_tree_output = false;
446  output.model->task_param.output_type = TaskParameter::OutputType::kFloat;
447  output.model->task_param.leaf_vector_size = 1;
448  if (output.model->task_param.num_class > 1) {
449  // multi-class classifier
450  output.model->task_type = TaskType::kMultiClfGrovePerClass;
451  output.model->task_param.grove_per_class = true;
452  } else {
453  // binary classifier or regressor
454  output.model->task_type = TaskType::kBinaryClfRegr;
455  output.model->task_param.grove_per_class = false;
456  }
457  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
458  // 1.0 it's the original value provided by user.
459  const bool need_transform_to_margin = (version[0] >= 1);
460  if (need_transform_to_margin) {
461  treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
462  }
463  return pop_handler();
464 }
465 
466 /******************************************************************************
467  * RootHandler
468  * ***************************************************************************/
469 bool RootHandler::StartObject() {
470  handle = {dynamic_cast<treelite::ModelImpl<float, float>*>(output.get()), ""};
471  return push_handler<XGBoostModelHandler, XGBoostModelHandle>(handle);
472 }
473 
474 /******************************************************************************
475  * DelegatedHandler
476  * ***************************************************************************/
477 std::unique_ptr<treelite::Model> DelegatedHandler::get_result() { return std::move(result); }
478 bool DelegatedHandler::Null() { return delegates.top()->Null(); }
479 bool DelegatedHandler::Bool(bool b) { return delegates.top()->Bool(b); }
480 bool DelegatedHandler::Int(int i) { return delegates.top()->Int(i); }
481 bool DelegatedHandler::Uint(unsigned u) { return delegates.top()->Uint(u); }
482 bool DelegatedHandler::Int64(int64_t i) { return delegates.top()->Int64(i); }
483 bool DelegatedHandler::Uint64(uint64_t u) { return delegates.top()->Uint64(u); }
484 bool DelegatedHandler::Double(double d) { return delegates.top()->Double(d); }
485 bool DelegatedHandler::String(const char *str, std::size_t length, bool copy) {
486  return delegates.top()->String(str, length, copy);
487 }
488 bool DelegatedHandler::StartObject() { return delegates.top()->StartObject(); }
489 bool DelegatedHandler::Key(const char *str, std::size_t length, bool copy) {
490  return delegates.top()->Key(str, length, copy);
491 }
492 bool DelegatedHandler::EndObject(std::size_t memberCount) {
493  return delegates.top()->EndObject(memberCount);
494 }
495 bool DelegatedHandler::StartArray() { return delegates.top()->StartArray(); }
496 bool DelegatedHandler::EndArray(std::size_t elementCount) {
497  return delegates.top()->EndArray(elementCount);
498 }
499 
500 
501 } // namespace details
502 } // namespace treelite
503 
504 namespace {
505 template <typename StreamType, typename ErrorHandlerFunc>
506 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
507  ErrorHandlerFunc error_handler) {
508  std::shared_ptr<treelite::details::DelegatedHandler> handler =
510  rapidjson::Reader reader;
511 
512  rapidjson::ParseResult result
513  = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
514  if (!result) {
515  const auto error_code = result.Code();
516  const size_t offset = result.Offset();
517  std::string diagnostic = error_handler(offset);
518  LOG(FATAL) << "Provided JSON could not be parsed as XGBoost model. Parsing error at offset "
519  << offset << ": " << rapidjson::GetParseError_En(error_code) << "\n"
520  << diagnostic;
521  }
522  return handler->get_result();
523 }
524 } // anonymous namespace
Some useful math utilities.
Collection of front-end methods to load or construct ensemble model.
model structure for tree ensemble
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:349
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:85
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble in the JSON format.
Definition: xgboost_json.cc:45