Treelite
xgboost_json.cc
Go to the documentation of this file.
1 
#include "xgboost/xgboost_json.h"

#include <fmt/format.h>
#include <rapidjson/error/en.h>
#include <rapidjson/document.h>
#include <rapidjson/filereadstream.h>
#include <treelite/tree.h>
#include <treelite/frontend.h>
#include <treelite/math.h>
#include <treelite/logging.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "xgboost/xgboost.h"
30 
31 namespace {
32 
33 template <typename StreamType, typename ErrorHandlerFunc>
34 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
35  ErrorHandlerFunc error_handler);
36 
37 } // anonymous namespace
38 
39 namespace treelite {
40 namespace frontend {
41 
42 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename) {
43  char read_buffer[65536];
44 
45 #ifdef _WIN32
46  FILE* fp = std::fopen(filename, "rb");
47 #else
48  FILE* fp = std::fopen(filename, "r");
49 #endif
50  if (!fp) {
51  TREELITE_LOG(FATAL) << "Failed to open file '" << filename << "': " << std::strerror(errno);
52  }
53 
54  auto input_stream = std::make_unique<rapidjson::FileReadStream>(
55  fp, read_buffer, sizeof(read_buffer));
56  auto error_handler = [fp](size_t offset) -> std::string {
57  size_t cur = (offset >= 50 ? (offset - 50) : 0);
58  std::fseek(fp, cur, SEEK_SET);
59  int c;
60  std::ostringstream oss, oss2;
61  for (int i = 0; i < 100; ++i) {
62  c = std::fgetc(fp);
63  if (c == EOF) {
64  break;
65  }
66  oss << static_cast<char>(c);
67  if (cur == offset) {
68  oss2 << "^";
69  } else {
70  oss2 << "~";
71  }
72  ++cur;
73  }
74  std::fclose(fp);
75  return oss.str() + "\n" + oss2.str();
76  };
77  auto parsed_model = ParseStream(std::move(input_stream), error_handler);
78  std::fclose(fp);
79  return parsed_model;
80 }
81 
82 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length) {
83  auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
84  auto error_handler = [json_str](size_t offset) -> std::string {
85  size_t cur = (offset >= 50 ? (offset - 50) : 0);
86  std::ostringstream oss, oss2;
87  for (int i = 0; i < 100; ++i) {
88  if (!json_str[cur]) {
89  break;
90  }
91  oss << json_str[cur];
92  if (cur == offset) {
93  oss2 << "^";
94  } else {
95  oss2 << "~";
96  }
97  ++cur;
98  }
99  return oss.str() + "\n" + oss2.str();
100  };
101  return ParseStream(std::move(input_stream), error_handler);
102 }
103 
104 } // namespace frontend
105 
106 namespace details {
107 
108 /******************************************************************************
109  * BaseHandler
110  * ***************************************************************************/
111 
112 bool BaseHandler::pop_handler() {
113  if (auto parent = delegator.lock()) {
114  parent->pop_delegate();
115  return true;
116  } else {
117  return false;
118  }
119 }
120 
121 void BaseHandler::set_cur_key(const char *str, std::size_t length) {
122  cur_key = std::string{str, length};
123 }
124 
125 const std::string &BaseHandler::get_cur_key() { return cur_key; }
126 
127 bool BaseHandler::check_cur_key(const std::string &query_key) {
128  return cur_key == query_key;
129 }
130 
131 template <typename ValueType>
132 bool BaseHandler::assign_value(const std::string &key,
133  ValueType &&value,
134  ValueType &output) {
135  if (check_cur_key(key)) {
136  output = value;
137  return true;
138  } else {
139  return false;
140  }
141 }
142 
143 template <typename ValueType>
144 bool BaseHandler::assign_value(const std::string &key,
145  const ValueType &value,
146  ValueType &output) {
147  if (check_cur_key(key)) {
148  output = value;
149  return true;
150  } else {
151  return false;
152  }
153 }
154 
155 /******************************************************************************
156  * IgnoreHandler
157  * ***************************************************************************/
158 bool IgnoreHandler::Null() { return true; }
159 bool IgnoreHandler::Bool(bool) { return true; }
160 bool IgnoreHandler::Int(int) { return true; }
161 bool IgnoreHandler::Uint(unsigned) { return true; }
162 bool IgnoreHandler::Int64(int64_t) { return true; }
163 bool IgnoreHandler::Uint64(uint64_t) { return true; }
164 bool IgnoreHandler::Double(double) { return true; }
165 bool IgnoreHandler::String(const char *, std::size_t, bool) {
166  return true; }
167 bool IgnoreHandler::StartObject() { return push_handler<IgnoreHandler>(); }
168 bool IgnoreHandler::Key(const char *, std::size_t, bool) {
169  return true; }
170 bool IgnoreHandler::StartArray() { return push_handler<IgnoreHandler>(); }
171 
172 /******************************************************************************
173  * TreeParamHandler
174  * ***************************************************************************/
175 bool TreeParamHandler::String(const char *str, std::size_t, bool) {
176  // Key "num_deleted" deprecated but still present in some xgboost output
177  return (check_cur_key("num_feature") ||
178  assign_value("num_nodes", std::atoi(str), output) ||
179  check_cur_key("size_leaf_vector") || check_cur_key("num_deleted"));
180 }
181 
182 /******************************************************************************
183  * RegTreeHandler
184  * ***************************************************************************/
185 bool RegTreeHandler::StartArray() {
186  /* Keys "categories" and "split_type" not currently documented in schema but
187  * will be used for upcoming categorical split feature */
188  return (
189  push_key_handler<ArrayHandler<double>>("loss_changes", loss_changes) ||
190  push_key_handler<ArrayHandler<double>>("sum_hessian", sum_hessian) ||
191  push_key_handler<ArrayHandler<double>>("base_weights", base_weights) ||
192  push_key_handler<ArrayHandler<int>>("categories_segments", categories_segments) ||
193  push_key_handler<ArrayHandler<int>>("categories_sizes", categories_sizes) ||
194  push_key_handler<ArrayHandler<int>>("categories_nodes", categories_nodes) ||
195  push_key_handler<ArrayHandler<int>>("categories", categories) ||
196  push_key_handler<IgnoreHandler>("leaf_child_counts") ||
197  push_key_handler<ArrayHandler<int>>("left_children", left_children) ||
198  push_key_handler<ArrayHandler<int>>("right_children", right_children) ||
199  push_key_handler<ArrayHandler<int>>("parents", parents) ||
200  push_key_handler<ArrayHandler<int>>("split_indices", split_indices) ||
201  push_key_handler<ArrayHandler<int>>("split_type", split_type) ||
202  push_key_handler<ArrayHandler<double>>("split_conditions", split_conditions) ||
203  push_key_handler<ArrayHandler<bool>>("default_left", default_left));
204 }
205 
206 bool RegTreeHandler::StartObject() {
207  return push_key_handler<TreeParamHandler, int>("tree_param", num_nodes);
208 }
209 
210 bool RegTreeHandler::Uint(unsigned) { return check_cur_key("id"); }
211 
212 bool RegTreeHandler::EndObject(std::size_t) {
213  output.Init();
214  if (split_type.empty()) {
215  split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
216  }
217  if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
218  TREELITE_LOG(ERROR) << "Field loss_changes has an incorrect dimension. Expected: " << num_nodes
219  << ", Actual: " << loss_changes.size();
220  return false;
221  }
222  if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
223  TREELITE_LOG(ERROR) << "Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
224  << ", Actual: " << sum_hessian.size();
225  return false;
226  }
227  if (static_cast<size_t>(num_nodes) != base_weights.size()) {
228  TREELITE_LOG(ERROR) << "Field base_weights has an incorrect dimension. Expected: " << num_nodes
229  << ", Actual: " << base_weights.size();
230  return false;
231  }
232  if (static_cast<size_t>(num_nodes) != left_children.size()) {
233  TREELITE_LOG(ERROR) << "Field left_children has an incorrect dimension. Expected: " << num_nodes
234  << ", Actual: " << left_children.size();
235  return false;
236  }
237  if (static_cast<size_t>(num_nodes) != right_children.size()) {
238  TREELITE_LOG(ERROR) << "Field right_children has an incorrect dimension. Expected: "
239  << num_nodes << ", Actual: " << right_children.size();
240  return false;
241  }
242  if (static_cast<size_t>(num_nodes) != parents.size()) {
243  TREELITE_LOG(ERROR) << "Field parents has an incorrect dimension. Expected: " << num_nodes
244  << ", Actual: " << parents.size();
245  return false;
246  }
247  if (static_cast<size_t>(num_nodes) != split_indices.size()) {
248  TREELITE_LOG(ERROR) << "Field split_indices has an incorrect dimension. Expected: " << num_nodes
249  << ", Actual: " << split_indices.size();
250  return false;
251  }
252  if (static_cast<size_t>(num_nodes) != split_type.size()) {
253  TREELITE_LOG(ERROR) << "Field split_type has an incorrect dimension. Expected: " << num_nodes
254  << ", Actual: " << split_type.size();
255  return false;
256  }
257  if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
258  TREELITE_LOG(ERROR) << "Field split_conditions has an incorrect dimension. Expected: "
259  << num_nodes << ", Actual: " << split_conditions.size();
260  return false;
261  }
262  if (static_cast<size_t>(num_nodes) != default_left.size()) {
263  TREELITE_LOG(ERROR) << "Field default_left has an incorrect dimension. Expected: " << num_nodes
264  << ", Actual: " << default_left.size();
265  return false;
266  }
267 
268  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
269  if (num_nodes > 0) {
270  Q.push({0, 0});
271  }
272  while (!Q.empty()) {
273  int old_id, new_id;
274  std::tie(old_id, new_id) = Q.front();
275  Q.pop();
276 
277  if (left_children[old_id] == -1) {
278  output.SetLeaf(new_id, split_conditions[old_id]);
279  } else {
280  output.AddChilds(new_id);
281  if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
282  auto categorical_split_loc
283  = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
284  TREELITE_CHECK(categorical_split_loc != categories_nodes.end())
285  << "Could not find record for the categorical split in node " << old_id;
286  auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
287  int offset = categories_segments[categorical_split_id];
288  int num_categories = categories_sizes[categorical_split_id];
289  std::vector<uint32_t> right_categories;
290  right_categories.reserve(num_categories);
291  for (int i = 0; i < num_categories; ++i) {
292  right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
293  }
294  output.SetCategoricalSplit(
295  new_id, split_indices[old_id], default_left[old_id], right_categories, true);
296  } else {
297  output.SetNumericalSplit(
298  new_id, split_indices[old_id], split_conditions[old_id],
299  default_left[old_id], treelite::Operator::kLT);
300  }
301  output.SetGain(new_id, loss_changes[old_id]);
302  Q.push({left_children[old_id], output.LeftChild(new_id)});
303  Q.push({right_children[old_id], output.RightChild(new_id)});
304  }
305  output.SetSumHess(new_id, sum_hessian[old_id]);
306  }
307  return pop_handler();
308 }
309 
310 /******************************************************************************
311  * GBTreeHandler
312  * ***************************************************************************/
313 bool GBTreeModelHandler::StartArray() {
314  return (push_key_handler<ArrayHandler<treelite::Tree<float, float>, RegTreeHandler>,
315  std::vector<treelite::Tree<float, float>>>(
316  "trees", output.trees) ||
317  push_key_handler<IgnoreHandler>("tree_info"));
318 }
319 
320 bool GBTreeModelHandler::StartObject() {
321  return push_key_handler<IgnoreHandler>("gbtree_model_param");
322 }
323 
324 /******************************************************************************
325  * GradientBoosterHandler
326  * ***************************************************************************/
327 bool GradientBoosterHandler::String(const char *str,
328  std::size_t length,
329  bool) {
330  if (assign_value("name", std::string{str, length}, name)) {
331  if (name == "gbtree" || name == "dart") {
332  return true;
333  } else {
334  TREELITE_LOG(ERROR) << "Only GBTree or DART boosters are currently supported.";
335  return false;
336  }
337  } else {
338  return false;
339  }
340 }
341 bool GradientBoosterHandler::StartObject() {
342  if (push_key_handler<GBTreeModelHandler, treelite::ModelImpl<float, float>>("model", output)) {
343  return true;
344  } else if (push_key_handler<GradientBoosterHandler, treelite::ModelImpl<float, float>>("gbtree",
345  output)) {
346  // "dart" booster contains a standard gbtree under ["gradient_booster"]["gbtree"]["model"].
347  return true;
348  } else {
349  TREELITE_LOG(ERROR) << "Key \"" << get_cur_key()
350  << "\" not recognized. Is this a GBTree-type booster?";
351  return false;
352  }
353 }
354 bool GradientBoosterHandler::StartArray() {
355  return push_key_handler<ArrayHandler<double>, std::vector<double>>("weight_drop", weight_drop);
356 }
357 bool GradientBoosterHandler::EndObject(std::size_t memberCount) {
358  if (name == "dart" && !weight_drop.empty()) {
359  // Fold weight drop into leaf value for dart models.
360  TREELITE_CHECK_EQ(output.trees.size(), weight_drop.size());
361  for (size_t i = 0; i < output.trees.size(); ++i) {
362  for (int nid = 0; nid < output.trees[i].num_nodes; ++nid) {
363  if (output.trees[i].IsLeaf(nid)) {
364  output.trees[i].SetLeaf(nid, weight_drop[i] * output.trees[i].LeafValue(nid));
365  }
366  }
367  }
368  }
369  return pop_handler();
370 }
371 
372 /******************************************************************************
373  * ObjectiveHandler
374  * ***************************************************************************/
375 bool ObjectiveHandler::StartObject() {
376  return (push_key_handler<IgnoreHandler>("reg_loss_param") ||
377  push_key_handler<IgnoreHandler>("poisson_regression_param") ||
378  push_key_handler<IgnoreHandler>("tweedie_regression_param") ||
379  push_key_handler<IgnoreHandler>("softmax_multiclass_param") ||
380  push_key_handler<IgnoreHandler>("lambda_rank_param") ||
381  push_key_handler<IgnoreHandler>("aft_loss_param"));
382 }
383 
384 bool ObjectiveHandler::String(const char *str, std::size_t length, bool) {
385  return assign_value("name", std::string{str, length}, output);
386 }
387 
388 /******************************************************************************
389  * LearnerParamHandler
390  * ***************************************************************************/
391 bool LearnerParamHandler::String(const char *str,
392  std::size_t,
393  bool) {
394  int num_target = 1;
395  if (assign_value("num_target", std::atoi(str), num_target)) {
396  if (num_target != 1) {
397  TREELITE_LOG(ERROR)
398  << "num_target must be 1; Treelite doesn't support multi-target regressor yet";
399  return false;
400  }
401  return true;
402  }
403  return (assign_value("base_score", strtof(str, nullptr),
404  output.param.global_bias) ||
405  assign_value("num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
406  output.task_param.num_class) ||
407  assign_value("num_feature", std::atoi(str), output.num_feature));
408 }
409 
410 /******************************************************************************
411  * LearnerHandler
412  * ***************************************************************************/
413 bool LearnerHandler::StartObject() {
414  // "attributes" key is not documented in schema
415  return (push_key_handler<LearnerParamHandler, treelite::ModelImpl<float, float>>(
416  "learner_model_param", *output.model) ||
417  push_key_handler<GradientBoosterHandler, treelite::ModelImpl<float, float>>(
418  "gradient_booster", *output.model) ||
419  push_key_handler<ObjectiveHandler, std::string>("objective", objective) ||
420  push_key_handler<IgnoreHandler>("attributes"));
421 }
422 
423 bool LearnerHandler::EndObject(std::size_t) {
424  xgboost::SetPredTransform(objective, &output.model->param);
425  output.objective_name = objective;
426  return pop_handler();
427 }
428 
429 bool LearnerHandler::StartArray() {
430  return (push_key_handler<IgnoreHandler>("feature_names") ||
431  push_key_handler<IgnoreHandler>("feature_types"));
432 }
433 
434 /******************************************************************************
435  * XGBoostCheckpointHandler
436  * ***************************************************************************/
437 
438 bool XGBoostCheckpointHandler::StartArray() {
439  return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
440  "version", output.version);
441 }
442 
443 bool XGBoostCheckpointHandler::StartObject() {
444  return push_key_handler<LearnerHandler, XGBoostModelHandle>("learner", output);
445 }
446 
447 /******************************************************************************
448  * XGBoostModelHandler
449  * ***************************************************************************/
450 bool XGBoostModelHandler::StartArray() {
451  return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
452  "version", output.version);
453 }
454 
455 bool XGBoostModelHandler::StartObject() {
456  return (push_key_handler<LearnerHandler, XGBoostModelHandle>("learner", output) ||
457  push_key_handler<IgnoreHandler>("Config") ||
458  push_key_handler<XGBoostCheckpointHandler, XGBoostModelHandle>("Model", output));
459 }
460 
461 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
462  if (memberCount != 2) {
463  TREELITE_LOG(ERROR) << "Expected two members in XGBoostModel";
464  return false;
465  }
466  output.model->average_tree_output = false;
467  output.model->task_param.output_type = TaskParam::OutputType::kFloat;
468  output.model->task_param.leaf_vector_size = 1;
469  if (output.model->task_param.num_class > 1) {
470  // multi-class classifier
471  output.model->task_type = TaskType::kMultiClfGrovePerClass;
472  output.model->task_param.grove_per_class = true;
473  } else {
474  // binary classifier or regressor
475  output.model->task_type = TaskType::kBinaryClfRegr;
476  output.model->task_param.grove_per_class = false;
477  }
478  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
479  // 1.0 it's the original value provided by user.
480  const bool need_transform_to_margin = (output.version[0] >= 1);
481  if (need_transform_to_margin) {
482  treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
483  }
484  return pop_handler();
485 }
486 
487 /******************************************************************************
488  * RootHandler
489  * ***************************************************************************/
490 bool RootHandler::StartObject() {
491  handle = {dynamic_cast<treelite::ModelImpl<float, float>*>(output.get()), {}, ""};
492  return push_handler<XGBoostModelHandler, XGBoostModelHandle>(handle);
493 }
494 
495 /******************************************************************************
496  * DelegatedHandler
497  * ***************************************************************************/
498 std::unique_ptr<treelite::Model> DelegatedHandler::get_result() { return std::move(result); }
499 bool DelegatedHandler::Null() { return delegates.top()->Null(); }
500 bool DelegatedHandler::Bool(bool b) { return delegates.top()->Bool(b); }
501 bool DelegatedHandler::Int(int i) { return delegates.top()->Int(i); }
502 bool DelegatedHandler::Uint(unsigned u) { return delegates.top()->Uint(u); }
503 bool DelegatedHandler::Int64(int64_t i) { return delegates.top()->Int64(i); }
504 bool DelegatedHandler::Uint64(uint64_t u) { return delegates.top()->Uint64(u); }
505 bool DelegatedHandler::Double(double d) { return delegates.top()->Double(d); }
506 bool DelegatedHandler::String(const char *str, std::size_t length, bool copy) {
507  return delegates.top()->String(str, length, copy);
508 }
509 bool DelegatedHandler::StartObject() { return delegates.top()->StartObject(); }
510 bool DelegatedHandler::Key(const char *str, std::size_t length, bool copy) {
511  return delegates.top()->Key(str, length, copy);
512 }
513 bool DelegatedHandler::EndObject(std::size_t memberCount) {
514  return delegates.top()->EndObject(memberCount);
515 }
516 bool DelegatedHandler::StartArray() { return delegates.top()->StartArray(); }
517 bool DelegatedHandler::EndArray(std::size_t elementCount) {
518  return delegates.top()->EndArray(elementCount);
519 }
520 
521 
522 } // namespace details
523 } // namespace treelite
524 
525 namespace {
526 template <typename StreamType, typename ErrorHandlerFunc>
527 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
528  ErrorHandlerFunc error_handler) {
529  std::shared_ptr<treelite::details::DelegatedHandler> handler =
531  rapidjson::Reader reader;
532 
533  rapidjson::ParseResult result
534  = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
535  if (!result) {
536  const auto error_code = result.Code();
537  const size_t offset = result.Offset();
538  std::string diagnostic = error_handler(offset);
539  TREELITE_LOG(FATAL) << "Provided JSON could not be parsed as XGBoost model. "
540  << "Parsing error at offset " << offset << ": "
541  << rapidjson::GetParseError_En(error_code) << "\n" << diagnostic;
542  }
543  return handler->get_result();
544 }
545 } // anonymous namespace
Some useful math utilities.
Collection of front-end methods to load or construct ensemble model.
model structure for tree ensemble
logging facility for Treelite
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:362
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:82
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
Load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble serialized in JSON format.
Definition: xgboost_json.cc:42