Treelite
xgboost_json.cc
Go to the documentation of this file.
1 
9 #include "xgboost/xgboost_json.h"
10 
11 #include <dmlc/registry.h>
12 #include <dmlc/io.h>
13 #include <fmt/format.h>
14 #include <rapidjson/error/en.h>
15 #include <rapidjson/document.h>
16 #include <rapidjson/filereadstream.h>
17 #include <treelite/tree.h>
18 #include <treelite/frontend.h>
19 #include <treelite/math.h>
20 
21 #include <algorithm>
22 #include <cstdio>
23 #include <cstdlib>
24 #include <iostream>
25 #include <memory>
26 #include <queue>
27 #include <string>
28 #include <utility>
29 
30 #include "xgboost/xgboost.h"
31 
32 namespace {
33 
34 template <typename StreamType, typename ErrorHandlerFunc>
35 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
36  ErrorHandlerFunc error_handler);
37 
38 } // anonymous namespace
39 
40 namespace treelite {
41 namespace frontend {
42 
43 DMLC_REGISTRY_FILE_TAG(xgboost_json);
44 
45 std::unique_ptr<treelite::Model> LoadXGBoostJSONModel(const char* filename) {
46  char read_buffer[65536];
47 
48 #ifdef _WIN32
49  FILE* fp = std::fopen(filename, "rb");
50 #else
51  FILE* fp = std::fopen(filename, "r");
52 #endif
53  if (!fp) {
54  LOG(FATAL) << "Failed to open file '" << filename << "': " << std::strerror(errno);
55  }
56 
57  auto input_stream = std::make_unique<rapidjson::FileReadStream>(
58  fp, read_buffer, sizeof(read_buffer));
59  auto error_handler = [fp](size_t offset) -> std::string {
60  size_t cur = (offset >= 50 ? (offset - 50) : 0);
61  std::fseek(fp, cur, SEEK_SET);
62  int c;
63  std::ostringstream oss, oss2;
64  for (int i = 0; i < 100; ++i) {
65  c = std::fgetc(fp);
66  if (c == EOF) {
67  break;
68  }
69  oss << static_cast<char>(c);
70  if (cur == offset) {
71  oss2 << "^";
72  } else {
73  oss2 << "~";
74  }
75  ++cur;
76  }
77  std::fclose(fp);
78  return oss.str() + "\n" + oss2.str();
79  };
80  auto parsed_model = ParseStream(std::move(input_stream), error_handler);
81  std::fclose(fp);
82  return parsed_model;
83 }
84 
85 std::unique_ptr<treelite::Model> LoadXGBoostJSONModelString(const char* json_str, size_t length) {
86  auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
87  auto error_handler = [json_str](size_t offset) -> std::string {
88  size_t cur = (offset >= 50 ? (offset - 50) : 0);
89  std::ostringstream oss, oss2;
90  for (int i = 0; i < 100; ++i) {
91  if (!json_str[cur]) {
92  break;
93  }
94  oss << json_str[cur];
95  if (cur == offset) {
96  oss2 << "^";
97  } else {
98  oss2 << "~";
99  }
100  ++cur;
101  }
102  return oss.str() + "\n" + oss2.str();
103  };
104  return ParseStream(std::move(input_stream), error_handler);
105 }
106 
107 } // namespace frontend
108 
109 namespace details {
110 
111 /******************************************************************************
112  * BaseHandler
113  * ***************************************************************************/
114 
115 bool BaseHandler::pop_handler() {
116  if (auto parent = delegator.lock()) {
117  parent->pop_delegate();
118  return true;
119  } else {
120  return false;
121  }
122 }
123 
124 void BaseHandler::set_cur_key(const char *str, std::size_t length) {
125  cur_key = std::string{str, length};
126 }
127 
128 const std::string &BaseHandler::get_cur_key() { return cur_key; }
129 
130 bool BaseHandler::check_cur_key(const std::string &query_key) {
131  return cur_key == query_key;
132 }
133 
134 template <typename ValueType>
135 bool BaseHandler::assign_value(const std::string &key,
136  ValueType &&value,
137  ValueType &output) {
138  if (check_cur_key(key)) {
139  output = value;
140  return true;
141  } else {
142  return false;
143  }
144 }
145 
146 template <typename ValueType>
147 bool BaseHandler::assign_value(const std::string &key,
148  const ValueType &value,
149  ValueType &output) {
150  if (check_cur_key(key)) {
151  output = value;
152  return true;
153  } else {
154  return false;
155  }
156 }
157 
158 /******************************************************************************
159  * IgnoreHandler
160  * ***************************************************************************/
161 bool IgnoreHandler::Null() { return true; }
162 bool IgnoreHandler::Bool(bool) { return true; }
163 bool IgnoreHandler::Int(int) { return true; }
164 bool IgnoreHandler::Uint(unsigned) { return true; }
165 bool IgnoreHandler::Int64(int64_t) { return true; }
166 bool IgnoreHandler::Uint64(uint64_t) { return true; }
167 bool IgnoreHandler::Double(double) { return true; }
168 bool IgnoreHandler::String(const char *, std::size_t, bool) {
169  return true; }
170 bool IgnoreHandler::StartObject() { return push_handler<IgnoreHandler>(); }
171 bool IgnoreHandler::Key(const char *, std::size_t, bool) {
172  return true; }
173 bool IgnoreHandler::StartArray() { return push_handler<IgnoreHandler>(); }
174 
175 /******************************************************************************
176  * TreeParamHandler
177  * ***************************************************************************/
178 bool TreeParamHandler::String(const char *str, std::size_t, bool) {
179  // Key "num_deleted" deprecated but still present in some xgboost output
180  return (check_cur_key("num_feature") ||
181  assign_value("num_nodes", std::atoi(str), output) ||
182  check_cur_key("size_leaf_vector") || check_cur_key("num_deleted"));
183 }
184 
185 /******************************************************************************
186  * RegTreeHandler
187  * ***************************************************************************/
188 bool RegTreeHandler::StartArray() {
189  /* Keys "categories" and "split_type" not currently documented in schema but
190  * will be used for upcoming categorical split feature */
191  return (
192  push_key_handler<ArrayHandler<double>>("loss_changes", loss_changes) ||
193  push_key_handler<ArrayHandler<double>>("sum_hessian", sum_hessian) ||
194  push_key_handler<ArrayHandler<double>>("base_weights", base_weights) ||
195  push_key_handler<ArrayHandler<int>>("categories_segments", categories_segments) ||
196  push_key_handler<ArrayHandler<int>>("categories_sizes", categories_sizes) ||
197  push_key_handler<ArrayHandler<int>>("categories_nodes", categories_nodes) ||
198  push_key_handler<ArrayHandler<int>>("categories", categories) ||
199  push_key_handler<IgnoreHandler>("leaf_child_counts") ||
200  push_key_handler<ArrayHandler<int>>("left_children", left_children) ||
201  push_key_handler<ArrayHandler<int>>("right_children", right_children) ||
202  push_key_handler<ArrayHandler<int>>("parents", parents) ||
203  push_key_handler<ArrayHandler<int>>("split_indices", split_indices) ||
204  push_key_handler<ArrayHandler<int>>("split_type", split_type) ||
205  push_key_handler<ArrayHandler<double>>("split_conditions", split_conditions) ||
206  push_key_handler<ArrayHandler<bool>>("default_left", default_left));
207 }
208 
209 bool RegTreeHandler::StartObject() {
210  return push_key_handler<TreeParamHandler, int>("tree_param", num_nodes);
211 }
212 
213 bool RegTreeHandler::Uint(unsigned) { return check_cur_key("id"); }
214 
215 bool RegTreeHandler::EndObject(std::size_t) {
216  output.Init();
217  if (split_type.empty()) {
218  split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
219  }
220  if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
221  LOG(ERROR) << "Field loss_changes has an incorrect dimension. Expected: " << num_nodes
222  << ", Actual: " << loss_changes.size();
223  return false;
224  }
225  if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
226  LOG(ERROR) << "Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
227  << ", Actual: " << sum_hessian.size();
228  return false;
229  }
230  if (static_cast<size_t>(num_nodes) != base_weights.size()) {
231  LOG(ERROR) << "Field base_weights has an incorrect dimension. Expected: " << num_nodes
232  << ", Actual: " << base_weights.size();
233  return false;
234  }
235  if (static_cast<size_t>(num_nodes) != left_children.size()) {
236  LOG(ERROR) << "Field left_children has an incorrect dimension. Expected: " << num_nodes
237  << ", Actual: " << left_children.size();
238  return false;
239  }
240  if (static_cast<size_t>(num_nodes) != right_children.size()) {
241  LOG(ERROR) << "Field right_children has an incorrect dimension. Expected: " << num_nodes
242  << ", Actual: " << right_children.size();
243  return false;
244  }
245  if (static_cast<size_t>(num_nodes) != parents.size()) {
246  LOG(ERROR) << "Field parents has an incorrect dimension. Expected: " << num_nodes
247  << ", Actual: " << parents.size();
248  return false;
249  }
250  if (static_cast<size_t>(num_nodes) != split_indices.size()) {
251  LOG(ERROR) << "Field split_indices has an incorrect dimension. Expected: " << num_nodes
252  << ", Actual: " << split_indices.size();
253  return false;
254  }
255  if (static_cast<size_t>(num_nodes) != split_type.size()) {
256  LOG(ERROR) << "Field split_type has an incorrect dimension. Expected: " << num_nodes
257  << ", Actual: " << split_type.size();
258  return false;
259  }
260  if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
261  LOG(ERROR) << "Field split_conditions has an incorrect dimension. Expected: " << num_nodes
262  << ", Actual: " << split_conditions.size();
263  return false;
264  }
265  if (static_cast<size_t>(num_nodes) != default_left.size()) {
266  LOG(ERROR) << "Field default_left has an incorrect dimension. Expected: " << num_nodes
267  << ", Actual: " << default_left.size();
268  return false;
269  }
270 
271  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
272  if (num_nodes > 0) {
273  Q.push({0, 0});
274  }
275  while (!Q.empty()) {
276  int old_id, new_id;
277  std::tie(old_id, new_id) = Q.front();
278  Q.pop();
279 
280  if (left_children[old_id] == -1) {
281  output.SetLeaf(new_id, split_conditions[old_id]);
282  } else {
283  output.AddChilds(new_id);
284  if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
285  auto categorical_split_loc
286  = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
287  CHECK(categorical_split_loc != categories_nodes.end())
288  << "Could not find record for the categorical split in node " << old_id;
289  auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
290  int offset = categories_segments[categorical_split_id];
291  int num_categories = categories_sizes[categorical_split_id];
292  std::vector<uint32_t> right_categories;
293  right_categories.reserve(num_categories);
294  for (int i = 0; i < num_categories; ++i) {
295  right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
296  }
297  output.SetCategoricalSplit(
298  new_id, split_indices[old_id], default_left[old_id], right_categories, true);
299  } else {
300  output.SetNumericalSplit(
301  new_id, split_indices[old_id], split_conditions[old_id],
302  default_left[old_id], treelite::Operator::kLT);
303  }
304  output.SetGain(new_id, loss_changes[old_id]);
305  Q.push({left_children[old_id], output.LeftChild(new_id)});
306  Q.push({right_children[old_id], output.RightChild(new_id)});
307  }
308  output.SetSumHess(new_id, sum_hessian[old_id]);
309  }
310  return pop_handler();
311 }
312 
313 /******************************************************************************
314  * GBTreeHandler
315  * ***************************************************************************/
316 bool GBTreeModelHandler::StartArray() {
317  return (push_key_handler<ArrayHandler<treelite::Tree<float, float>, RegTreeHandler>,
318  std::vector<treelite::Tree<float, float>>>(
319  "trees", output.trees) ||
320  push_key_handler<IgnoreHandler>("tree_info"));
321 }
322 
323 bool GBTreeModelHandler::StartObject() {
324  return push_key_handler<IgnoreHandler>("gbtree_model_param");
325 }
326 
327 /******************************************************************************
328  * GradientBoosterHandler
329  * ***************************************************************************/
330 bool GradientBoosterHandler::String(const char *str,
331  std::size_t length,
332  bool) {
333  if (!check_cur_key("name")) {
334  return false;
335  }
336  fmt::string_view name{str, length};
337  if (name != "gbtree") {
338  LOG(ERROR) << "Only GBTree-type boosters are currently supported.";
339  return false;
340  } else {
341  return true;
342  }
343 }
344 bool GradientBoosterHandler::StartObject() {
345  if (push_key_handler<GBTreeModelHandler, treelite::ModelImpl<float, float>>("model", output)) {
346  return true;
347  } else {
348  LOG(ERROR) << "Key \"" << get_cur_key()
349  << "\" not recognized. Is this a GBTree-type booster?";
350  return false;
351  }
352 }
353 
354 /******************************************************************************
355  * ObjectiveHandler
356  * ***************************************************************************/
357 bool ObjectiveHandler::StartObject() {
358  return (push_key_handler<IgnoreHandler>("reg_loss_param") ||
359  push_key_handler<IgnoreHandler>("poisson_regression_param") ||
360  push_key_handler<IgnoreHandler>("tweedie_regression_param") ||
361  push_key_handler<IgnoreHandler>("softmax_multiclass_param") ||
362  push_key_handler<IgnoreHandler>("lambda_rank_param") ||
363  push_key_handler<IgnoreHandler>("aft_loss_param"));
364 }
365 
366 bool ObjectiveHandler::String(const char *str, std::size_t length, bool) {
367  return assign_value("name", std::string{str, length}, output);
368 }
369 
370 /******************************************************************************
371  * LearnerParamHandler
372  * ***************************************************************************/
373 bool LearnerParamHandler::String(const char *str,
374  std::size_t,
375  bool) {
376  return (assign_value("base_score", strtof(str, nullptr),
377  output.param.global_bias) ||
378  assign_value("num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
379  output.task_param.num_class) ||
380  assign_value("num_feature", std::atoi(str), output.num_feature));
381 }
382 
383 /******************************************************************************
384  * LearnerHandler
385  * ***************************************************************************/
386 bool LearnerHandler::StartObject() {
387  // "attributes" key is not documented in schema
388  return (push_key_handler<LearnerParamHandler, treelite::ModelImpl<float, float>>(
389  "learner_model_param", *output.model) ||
390  push_key_handler<GradientBoosterHandler, treelite::ModelImpl<float, float>>(
391  "gradient_booster", *output.model) ||
392  push_key_handler<ObjectiveHandler, std::string>("objective", objective) ||
393  push_key_handler<IgnoreHandler>("attributes"));
394 }
395 
396 bool LearnerHandler::EndObject(std::size_t) {
397  xgboost::SetPredTransform(objective, &output.model->param);
398  output.objective_name = objective;
399  return pop_handler();
400 }
401 
402 /******************************************************************************
403  * XGBoostModelHandler
404  * ***************************************************************************/
405 bool XGBoostModelHandler::StartArray() {
406  return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
407  "version", version);
408 }
409 
410 bool XGBoostModelHandler::StartObject() {
411  return push_key_handler<LearnerHandler, XGBoostModelHandle>("learner", output);
412 }
413 
414 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
415  if (memberCount != 2) {
416  LOG(ERROR) << "Expected two members in XGBoostModel";
417  return false;
418  }
419  output.model->average_tree_output = false;
420  output.model->task_param.output_type = TaskParameter::OutputType::kFloat;
421  output.model->task_param.leaf_vector_size = 1;
422  if (output.model->task_param.num_class > 1) {
423  // multi-class classifier
424  output.model->task_type = TaskType::kMultiClfGrovePerClass;
425  output.model->task_param.grove_per_class = true;
426  } else {
427  // binary classifier or regressor
428  output.model->task_type = TaskType::kBinaryClfRegr;
429  output.model->task_param.grove_per_class = false;
430  }
431  // Before XGBoost 1.0.0, the global bias saved in model is a transformed value. After
432  // 1.0 it's the original value provided by user.
433  const bool need_transform_to_margin = (version[0] >= 1);
434  if (need_transform_to_margin) {
435  treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
436  }
437  return pop_handler();
438 }
439 
440 /******************************************************************************
441  * RootHandler
442  * ***************************************************************************/
443 bool RootHandler::StartObject() {
444  handle = {dynamic_cast<treelite::ModelImpl<float, float>*>(output.get()), ""};
445  return push_handler<XGBoostModelHandler, XGBoostModelHandle>(handle);
446 }
447 
448 /******************************************************************************
449  * DelegatedHandler
450  * ***************************************************************************/
451 std::unique_ptr<treelite::Model> DelegatedHandler::get_result() { return std::move(result); }
452 bool DelegatedHandler::Null() { return delegates.top()->Null(); }
453 bool DelegatedHandler::Bool(bool b) { return delegates.top()->Bool(b); }
454 bool DelegatedHandler::Int(int i) { return delegates.top()->Int(i); }
455 bool DelegatedHandler::Uint(unsigned u) { return delegates.top()->Uint(u); }
456 bool DelegatedHandler::Int64(int64_t i) { return delegates.top()->Int64(i); }
457 bool DelegatedHandler::Uint64(uint64_t u) { return delegates.top()->Uint64(u); }
458 bool DelegatedHandler::Double(double d) { return delegates.top()->Double(d); }
459 bool DelegatedHandler::String(const char *str, std::size_t length, bool copy) {
460  return delegates.top()->String(str, length, copy);
461 }
462 bool DelegatedHandler::StartObject() { return delegates.top()->StartObject(); }
463 bool DelegatedHandler::Key(const char *str, std::size_t length, bool copy) {
464  return delegates.top()->Key(str, length, copy);
465 }
466 bool DelegatedHandler::EndObject(std::size_t memberCount) {
467  return delegates.top()->EndObject(memberCount);
468 }
469 bool DelegatedHandler::StartArray() { return delegates.top()->StartArray(); }
470 bool DelegatedHandler::EndArray(std::size_t elementCount) {
471  return delegates.top()->EndArray(elementCount);
472 }
473 
474 
475 } // namespace details
476 } // namespace treelite
477 
478 namespace {
479 template <typename StreamType, typename ErrorHandlerFunc>
480 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
481  ErrorHandlerFunc error_handler) {
482  std::shared_ptr<treelite::details::DelegatedHandler> handler =
484  rapidjson::Reader reader;
485 
486  rapidjson::ParseResult result
487  = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
488  if (!result) {
489  const auto error_code = result.Code();
490  const size_t offset = result.Offset();
491  std::string diagnostic = error_handler(offset);
492  LOG(FATAL) << "Provided JSON could not be parsed as XGBoost model. Parsing error at offset "
493  << offset << ": " << rapidjson::GetParseError_En(error_code) << "\n"
494  << diagnostic;
495  }
496  return handler->get_result();
497 }
498 } // anonymous namespace
Some useful math utilities.
Collection of front-end methods to load or construct ensemble model.
model structure for tree ensemble
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:343
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Definition: xgboost_json.cc:85
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...
Definition: xgboost_json.cc:45