11 #include <dmlc/registry.h> 13 #include <fmt/format.h> 14 #include <rapidjson/error/en.h> 15 #include <rapidjson/document.h> 16 #include <rapidjson/filereadstream.h> 34 template <
typename StreamType,
typename ErrorHandlerFunc>
35 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
36 ErrorHandlerFunc error_handler);
43 DMLC_REGISTRY_FILE_TAG(xgboost_json);
46 char read_buffer[65536];
49 FILE* fp = std::fopen(filename,
"rb");
51 FILE* fp = std::fopen(filename,
"r");
54 LOG(FATAL) <<
"Failed to open file '" << filename <<
"': " << std::strerror(errno);
57 auto input_stream = std::make_unique<rapidjson::FileReadStream>(
58 fp, read_buffer,
sizeof(read_buffer));
59 auto error_handler = [fp](
size_t offset) -> std::string {
60 size_t cur = (offset >= 50 ? (offset - 50) : 0);
61 std::fseek(fp, cur, SEEK_SET);
63 std::ostringstream oss, oss2;
64 for (
int i = 0; i < 100; ++i) {
69 oss << static_cast<char>(c);
78 return oss.str() +
"\n" + oss2.str();
80 auto parsed_model = ParseStream(std::move(input_stream), error_handler);
86 auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
87 auto error_handler = [json_str](
size_t offset) -> std::string {
88 size_t cur = (offset >= 50 ? (offset - 50) : 0);
89 std::ostringstream oss, oss2;
90 for (
int i = 0; i < 100; ++i) {
102 return oss.str() +
"\n" + oss2.str();
104 return ParseStream(std::move(input_stream), error_handler);
115 bool BaseHandler::pop_handler() {
116 if (
auto parent = delegator.lock()) {
117 parent->pop_delegate();
124 void BaseHandler::set_cur_key(
const char *str, std::size_t length) {
125 cur_key = std::string{str, length};
// Returns a read-only reference to cur_key, the member that set_cur_key() assigns.
// The reference aliases the member itself, so it reflects later reassignments.
128 const std::string &BaseHandler::get_cur_key() {
return cur_key; }
130 bool BaseHandler::check_cur_key(
const std::string &query_key) {
131 return cur_key == query_key;
134 template <
typename ValueType>
135 bool BaseHandler::assign_value(
const std::string &key,
138 if (check_cur_key(key)) {
146 template <
typename ValueType>
147 bool BaseHandler::assign_value(
const std::string &key,
148 const ValueType &value,
150 if (check_cur_key(key)) {
// Null-value callback: nothing is recorded and the call always returns true.
// NOTE(review): true is presumably the SAX "keep parsing" signal — confirm against the reader.
161 bool IgnoreHandler::Null() {
return true; }
// Boolean-value callback: the argument is unnamed and unused; always returns true.
162 bool IgnoreHandler::Bool(
bool) {
return true; }
// Signed-integer callback: the argument is unnamed and unused; always returns true.
163 bool IgnoreHandler::Int(
int) {
return true; }
// Unsigned-integer callback: the argument is unnamed and unused; always returns true.
164 bool IgnoreHandler::Uint(
unsigned) {
return true; }
// 64-bit signed-integer callback: the argument is unnamed and unused; always returns true.
165 bool IgnoreHandler::Int64(int64_t) {
return true; }
// 64-bit unsigned-integer callback: the argument is unnamed and unused; always returns true.
166 bool IgnoreHandler::Uint64(uint64_t) {
return true; }
// Floating-point callback: the argument is unnamed and unused; always returns true.
167 bool IgnoreHandler::Double(
double) {
return true; }
168 bool IgnoreHandler::String(
const char *, std::size_t,
bool) {
// Start-of-object callback: installs a fresh IgnoreHandler through push_handler and
// returns whatever push_handler reports.
// NOTE(review): presumably this causes the whole nested object to be skipped as well —
// confirm against push_handler's definition, which is not visible here.
170 bool IgnoreHandler::StartObject() {
return push_handler<IgnoreHandler>(); }
171 bool IgnoreHandler::Key(
const char *, std::size_t,
bool) {
// Start-of-array callback: installs a fresh IgnoreHandler through push_handler and
// returns whatever push_handler reports.
// NOTE(review): presumably this causes the whole nested array to be skipped as well —
// confirm against push_handler's definition, which is not visible here.
173 bool IgnoreHandler::StartArray() {
return push_handler<IgnoreHandler>(); }
178 bool TreeParamHandler::String(
const char *str, std::size_t,
bool) {
180 return (check_cur_key(
"num_feature") ||
181 assign_value(
"num_nodes", std::atoi(str), output) ||
182 check_cur_key(
"size_leaf_vector") || check_cur_key(
"num_deleted"));
188 bool RegTreeHandler::StartArray() {
192 push_key_handler<ArrayHandler<double>>(
"loss_changes", loss_changes) ||
193 push_key_handler<ArrayHandler<double>>(
"sum_hessian", sum_hessian) ||
194 push_key_handler<ArrayHandler<double>>(
"base_weights", base_weights) ||
195 push_key_handler<ArrayHandler<int>>(
"categories_segments", categories_segments) ||
196 push_key_handler<ArrayHandler<int>>(
"categories_sizes", categories_sizes) ||
197 push_key_handler<ArrayHandler<int>>(
"categories_nodes", categories_nodes) ||
198 push_key_handler<ArrayHandler<int>>(
"categories", categories) ||
199 push_key_handler<IgnoreHandler>(
"leaf_child_counts") ||
200 push_key_handler<ArrayHandler<int>>(
"left_children", left_children) ||
201 push_key_handler<ArrayHandler<int>>(
"right_children", right_children) ||
202 push_key_handler<ArrayHandler<int>>(
"parents", parents) ||
203 push_key_handler<ArrayHandler<int>>(
"split_indices", split_indices) ||
204 push_key_handler<ArrayHandler<int>>(
"split_type", split_type) ||
205 push_key_handler<ArrayHandler<double>>(
"split_conditions", split_conditions) ||
206 push_key_handler<ArrayHandler<bool>>(
"default_left", default_left));
209 bool RegTreeHandler::StartObject() {
210 return push_key_handler<TreeParamHandler, int>(
"tree_param", num_nodes);
// Unsigned-integer callback: the value itself is dropped (unnamed parameter).
// Returns true only when the current key equals "id"; any other key yields false.
213 bool RegTreeHandler::Uint(
unsigned) {
return check_cur_key(
"id"); }
215 bool RegTreeHandler::EndObject(std::size_t) {
217 if (split_type.empty()) {
218 split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
220 if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
221 LOG(ERROR) <<
"Field loss_changes has an incorrect dimension. Expected: " << num_nodes
222 <<
", Actual: " << loss_changes.size();
225 if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
226 LOG(ERROR) <<
"Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
227 <<
", Actual: " << sum_hessian.size();
230 if (static_cast<size_t>(num_nodes) != base_weights.size()) {
231 LOG(ERROR) <<
"Field base_weights has an incorrect dimension. Expected: " << num_nodes
232 <<
", Actual: " << base_weights.size();
235 if (static_cast<size_t>(num_nodes) != left_children.size()) {
236 LOG(ERROR) <<
"Field left_children has an incorrect dimension. Expected: " << num_nodes
237 <<
", Actual: " << left_children.size();
240 if (static_cast<size_t>(num_nodes) != right_children.size()) {
241 LOG(ERROR) <<
"Field right_children has an incorrect dimension. Expected: " << num_nodes
242 <<
", Actual: " << right_children.size();
245 if (static_cast<size_t>(num_nodes) != parents.size()) {
246 LOG(ERROR) <<
"Field parents has an incorrect dimension. Expected: " << num_nodes
247 <<
", Actual: " << parents.size();
250 if (static_cast<size_t>(num_nodes) != split_indices.size()) {
251 LOG(ERROR) <<
"Field split_indices has an incorrect dimension. Expected: " << num_nodes
252 <<
", Actual: " << split_indices.size();
255 if (static_cast<size_t>(num_nodes) != split_type.size()) {
256 LOG(ERROR) <<
"Field split_type has an incorrect dimension. Expected: " << num_nodes
257 <<
", Actual: " << split_type.size();
260 if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
261 LOG(ERROR) <<
"Field split_conditions has an incorrect dimension. Expected: " << num_nodes
262 <<
", Actual: " << split_conditions.size();
265 if (static_cast<size_t>(num_nodes) != default_left.size()) {
266 LOG(ERROR) <<
"Field default_left has an incorrect dimension. Expected: " << num_nodes
267 <<
", Actual: " << default_left.size();
271 std::queue<std::pair<int, int>> Q;
277 std::tie(old_id, new_id) = Q.front();
280 if (left_children[old_id] == -1) {
281 output.SetLeaf(new_id, split_conditions[old_id]);
283 output.AddChilds(new_id);
284 if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
285 auto categorical_split_loc
286 = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
287 CHECK(categorical_split_loc != categories_nodes.end())
288 <<
"Could not find record for the categorical split in node " << old_id;
289 auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
290 int offset = categories_segments[categorical_split_id];
291 int num_categories = categories_sizes[categorical_split_id];
292 std::vector<uint32_t> right_categories;
293 right_categories.reserve(num_categories);
294 for (
int i = 0; i < num_categories; ++i) {
295 right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
297 output.SetCategoricalSplit(
298 new_id, split_indices[old_id], default_left[old_id], right_categories,
true);
300 output.SetNumericalSplit(
301 new_id, split_indices[old_id], split_conditions[old_id],
302 default_left[old_id], treelite::Operator::kLT);
304 output.SetGain(new_id, loss_changes[old_id]);
305 Q.push({left_children[old_id], output.LeftChild(new_id)});
306 Q.push({right_children[old_id], output.RightChild(new_id)});
308 output.SetSumHess(new_id, sum_hessian[old_id]);
310 return pop_handler();
316 bool GBTreeModelHandler::StartArray() {
319 "trees", output.trees) ||
320 push_key_handler<IgnoreHandler>(
"tree_info"));
323 bool GBTreeModelHandler::StartObject() {
324 return push_key_handler<IgnoreHandler>(
"gbtree_model_param");
330 bool GradientBoosterHandler::String(
const char *str,
333 if (assign_value(
"name", std::string{str, length}, name)) {
334 if (name ==
"gbtree" || name ==
"dart") {
337 LOG(ERROR) <<
"Only GBTree or DART boosters are currently supported.";
344 bool GradientBoosterHandler::StartObject() {
352 LOG(ERROR) <<
"Key \"" << get_cur_key()
353 <<
"\" not recognized. Is this a GBTree-type booster?";
357 bool GradientBoosterHandler::StartArray() {
358 return push_key_handler<ArrayHandler<double>, std::vector<double>>(
"weight_drop", weight_drop);
360 bool GradientBoosterHandler::EndObject(std::size_t memberCount) {
361 if (name ==
"dart" && !weight_drop.empty()) {
363 CHECK_EQ(output.trees.size(), weight_drop.size());
364 for (
size_t i = 0; i < output.trees.size(); ++i) {
365 for (
int nid = 0; nid < output.trees[i].num_nodes; ++nid) {
366 if (output.trees[i].IsLeaf(nid)) {
367 output.trees[i].SetLeaf(nid, weight_drop[i] * output.trees[i].LeafValue(nid));
372 return pop_handler();
378 bool ObjectiveHandler::StartObject() {
379 return (push_key_handler<IgnoreHandler>(
"reg_loss_param") ||
380 push_key_handler<IgnoreHandler>(
"poisson_regression_param") ||
381 push_key_handler<IgnoreHandler>(
"tweedie_regression_param") ||
382 push_key_handler<IgnoreHandler>(
"softmax_multiclass_param") ||
383 push_key_handler<IgnoreHandler>(
"lambda_rank_param") ||
384 push_key_handler<IgnoreHandler>(
"aft_loss_param"));
387 bool ObjectiveHandler::String(
const char *str, std::size_t length,
bool) {
388 return assign_value(
"name", std::string{str, length}, output);
394 bool LearnerParamHandler::String(
const char *str,
397 return (assign_value(
"base_score", strtof(str,
nullptr),
398 output.param.global_bias) ||
399 assign_value(
"num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
400 output.task_param.num_class) ||
401 assign_value(
"num_feature", std::atoi(str), output.num_feature));
407 bool LearnerHandler::StartObject() {
410 "learner_model_param", *output.model) ||
412 "gradient_booster", *output.model) ||
413 push_key_handler<ObjectiveHandler, std::string>(
"objective", objective) ||
414 push_key_handler<IgnoreHandler>(
"attributes"));
417 bool LearnerHandler::EndObject(std::size_t) {
418 xgboost::SetPredTransform(objective, &output.model->param);
419 output.objective_name = objective;
420 return pop_handler();
423 bool LearnerHandler::StartArray() {
424 return (push_key_handler<IgnoreHandler>(
"feature_names") ||
425 push_key_handler<IgnoreHandler>(
"feature_types"));
431 bool XGBoostModelHandler::StartArray() {
432 return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
436 bool XGBoostModelHandler::StartObject() {
437 return push_key_handler<LearnerHandler, XGBoostModelHandle>(
"learner", output);
440 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
441 if (memberCount != 2) {
442 LOG(ERROR) <<
"Expected two members in XGBoostModel";
445 output.model->average_tree_output =
false;
446 output.model->task_param.output_type = TaskParameter::OutputType::kFloat;
447 output.model->task_param.leaf_vector_size = 1;
448 if (output.model->task_param.num_class > 1) {
450 output.model->task_type = TaskType::kMultiClfGrovePerClass;
451 output.model->task_param.grove_per_class =
true;
454 output.model->task_type = TaskType::kBinaryClfRegr;
455 output.model->task_param.grove_per_class =
false;
459 const bool need_transform_to_margin = (version[0] >= 1);
460 if (need_transform_to_margin) {
461 treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
463 return pop_handler();
469 bool RootHandler::StartObject() {
471 return push_handler<XGBoostModelHandler, XGBoostModelHandle>(handle);
// Hands ownership of `result` to the caller by moving it out of this object.
// NOTE(review): `result` is left in a moved-from state afterwards, so this looks
// intended to be called once per parse — confirm with callers.
477 std::unique_ptr<treelite::Model> DelegatedHandler::get_result() {
return std::move(result); }
// Forwards the null event to the handler currently on top of `delegates` and
// returns that handler's answer. No emptiness check is done before top().
478 bool DelegatedHandler::Null() {
return delegates.top()->Null(); }
// Forwards the boolean value `b` to the handler currently on top of `delegates` and
// returns that handler's answer. No emptiness check is done before top().
479 bool DelegatedHandler::Bool(
bool b) {
return delegates.top()->Bool(b); }
// Forwards the signed integer `i` to the handler currently on top of `delegates` and
// returns that handler's answer. No emptiness check is done before top().
480 bool DelegatedHandler::Int(
int i) {
return delegates.top()->Int(i); }
// Forwards the unsigned integer `u` to the handler currently on top of `delegates` and
// returns that handler's answer. No emptiness check is done before top().
481 bool DelegatedHandler::Uint(
unsigned u) {
return delegates.top()->Uint(u); }
// Forwards the 64-bit signed integer `i` to the handler currently on top of `delegates`
// and returns that handler's answer. No emptiness check is done before top().
482 bool DelegatedHandler::Int64(int64_t i) {
return delegates.top()->Int64(i); }
// Forwards the 64-bit unsigned integer `u` to the handler currently on top of `delegates`
// and returns that handler's answer. No emptiness check is done before top().
483 bool DelegatedHandler::Uint64(uint64_t u) {
return delegates.top()->Uint64(u); }
// Forwards the floating-point value `d` to the handler currently on top of `delegates`
// and returns that handler's answer. No emptiness check is done before top().
484 bool DelegatedHandler::Double(
double d) {
return delegates.top()->Double(d); }
485 bool DelegatedHandler::String(
const char *str, std::size_t length,
bool copy) {
486 return delegates.top()->String(str, length, copy);
// Forwards the start-of-object event to the handler currently on top of `delegates`
// and returns that handler's answer. No emptiness check is done before top().
// NOTE(review): the callee may itself change the stack (handlers in this file call
// push_handler from StartObject) — the top element is read before that happens.
488 bool DelegatedHandler::StartObject() {
return delegates.top()->StartObject(); }
489 bool DelegatedHandler::Key(
const char *str, std::size_t length,
bool copy) {
490 return delegates.top()->Key(str, length, copy);
492 bool DelegatedHandler::EndObject(std::size_t memberCount) {
493 return delegates.top()->EndObject(memberCount);
// Forwards the start-of-array event to the handler currently on top of `delegates`
// and returns that handler's answer. No emptiness check is done before top().
// NOTE(review): the callee may itself change the stack (handlers in this file call
// push_handler from StartArray) — the top element is read before that happens.
495 bool DelegatedHandler::StartArray() {
return delegates.top()->StartArray(); }
496 bool DelegatedHandler::EndArray(std::size_t elementCount) {
497 return delegates.top()->EndArray(elementCount);
505 template <
typename StreamType,
typename ErrorHandlerFunc>
506 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
507 ErrorHandlerFunc error_handler) {
508 std::shared_ptr<treelite::details::DelegatedHandler> handler =
510 rapidjson::Reader reader;
512 rapidjson::ParseResult result
513 = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
515 const auto error_code = result.Code();
516 const size_t offset = result.Offset();
517 std::string diagnostic = error_handler(offset);
518 LOG(FATAL) <<
"Provided JSON could not be parsed as XGBoost model. Parsing error at offset " 519 << offset <<
": " << rapidjson::GetParseError_En(error_code) <<
"\n" 522 return handler->get_result();
Some useful math utilities.
Collection of front-end methods to load or construct ensemble model.
model structure for tree ensemble
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree en...