11 #include <fmt/format.h> 12 #include <rapidjson/error/en.h> 13 #include <rapidjson/document.h> 14 #include <rapidjson/filereadstream.h> 33 template <
typename StreamType,
typename ErrorHandlerFunc>
34 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
35 ErrorHandlerFunc error_handler);
// The declaration above forward-declares ParseStream, the generic driver used by
// both JSON loaders below. It parses `input_stream` (a RapidJSON stream such as
// FileReadStream or MemoryStream) into a treelite::Model. When parsing fails, it
// calls `error_handler(offset)` to build a diagnostic snippet for the error
// message. The definition appears later in this file.
// Body of LoadXGBoostJSONModel(filename). The function header and several lines
// are missing from this listing. The file is read through a buffered
// rapidjson::FileReadStream and handed to ParseStream.
//
// Fixed 64 KiB buffer backing the FileReadStream created below.
43 char read_buffer[65536];
// Open in binary ("rb") or text ("r") mode. NOTE(review): the two variants are
// presumably selected by a platform #if on lines not shown -- confirm.
46 FILE* fp = std::fopen(filename,
"rb");
48 FILE* fp = std::fopen(filename,
"r");
// Fail fatally with the OS error message when the file cannot be opened. The
// guarding condition (presumably a null check on fp) is not shown.
51 TREELITE_LOG(FATAL) <<
"Failed to open file '" << filename <<
"': " << std::strerror(errno);
54 auto input_stream = std::make_unique<rapidjson::FileReadStream>(
55 fp, read_buffer,
sizeof(read_buffer));
// Diagnostic callback for ParseStream. It seeks the file to at most 50 bytes
// before the failing offset and reads up to 100 characters of context into
// `oss`. `oss2` is presumably a marker line aligned under that snippet; the
// lines that fill it are not shown. Returns snippet + "\n" + marker line.
56 auto error_handler = [fp](
size_t offset) -> std::string {
57 size_t cur = (offset >= 50 ? (offset - 50) : 0);
58 std::fseek(fp, cur, SEEK_SET);
60 std::ostringstream oss, oss2;
61 for (
int i = 0; i < 100; ++i) {
66 oss << static_cast<char>(c);
75 return oss.str() +
"\n" + oss2.str();
// Parse the stream. NOTE(review): `fp` is a raw FILE*. If it is closed only
// after ParseStream returns (not shown), a fatal error raised during parsing
// would skip that cleanup. An owning wrapper such as
// std::unique_ptr<FILE, decltype(&std::fclose)> would guarantee the close.
77 auto parsed_model = ParseStream(std::move(input_stream), error_handler);
// Body of LoadXGBoostJSONModelString(json_str, length). The function header and
// some lines are missing from this listing. Parses an in-memory JSON buffer
// through rapidjson::MemoryStream.
83 auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
// Diagnostic callback. It uses the same window as the file loader (start at most
// 50 bytes before the failing offset, up to 100 characters), presumably reading
// from the captured `json_str`. NOTE(review): the loop body is not shown; confirm
// it stops at `length` so the window never reads past the end of the buffer.
84 auto error_handler = [json_str](
size_t offset) -> std::string {
85 size_t cur = (offset >= 50 ? (offset - 50) : 0);
86 std::ostringstream oss, oss2;
87 for (
int i = 0; i < 100; ++i) {
99 return oss.str() +
"\n" + oss2.str();
101 return ParseStream(std::move(input_stream), error_handler);
// Remove this handler from the delegator's handler stack. `delegator` is locked
// before use (it appears to be a std::weak_ptr), so nothing is popped if the
// delegator no longer exists. The return statements and closing braces are not
// shown in this listing.
112 bool BaseHandler::pop_handler() {
113 if (
auto parent = delegator.lock()) {
114 parent->pop_delegate();
121 void BaseHandler::set_cur_key(
const char *str, std::size_t length) {
122 cur_key = std::string{str, length};
125 const std::string &BaseHandler::get_cur_key() {
return cur_key; }
127 bool BaseHandler::check_cur_key(
const std::string &query_key) {
128 return cur_key == query_key;
// assign_value(key, value, output): two overloads; the second takes `value` by
// const reference. Both begin by testing whether the current JSON key equals
// `key`. The remaining parameters and the assignment itself are not shown here.
// NOTE(review): presumably writes `value` into `output` and returns true on a
// match, false otherwise -- confirm against the full definition.
131 template <
typename ValueType>
132 bool BaseHandler::assign_value(
const std::string &key,
135 if (check_cur_key(key)) {
143 template <
typename ValueType>
144 bool BaseHandler::assign_value(
const std::string &key,
145 const ValueType &value,
147 if (check_cur_key(key)) {
158 bool IgnoreHandler::Null() {
return true; }
159 bool IgnoreHandler::Bool(
bool) {
return true; }
160 bool IgnoreHandler::Int(
int) {
return true; }
161 bool IgnoreHandler::Uint(
unsigned) {
return true; }
162 bool IgnoreHandler::Int64(int64_t) {
return true; }
163 bool IgnoreHandler::Uint64(uint64_t) {
return true; }
164 bool IgnoreHandler::Double(
double) {
return true; }
// String and Key events: the parameters are unnamed, so their contents are not
// used. The bodies (presumably `return true;`) are not shown in this listing.
165 bool IgnoreHandler::String(
const char *, std::size_t,
bool) {
// A nested object or array inside an ignored value is ignored too: push a fresh
// IgnoreHandler so the whole subtree is skipped.
167 bool IgnoreHandler::StartObject() {
return push_handler<IgnoreHandler>(); }
168 bool IgnoreHandler::Key(
const char *, std::size_t,
bool) {
170 bool IgnoreHandler::StartArray() {
return push_handler<IgnoreHandler>(); }
// "tree_param" stores its numeric fields as JSON strings. Only "num_nodes" is
// kept: it is parsed with std::atoi and written to `output`. The keys
// "num_feature", "size_leaf_vector" and "num_deleted" are accepted and ignored.
// Any other key presumably yields false, which makes RapidJSON stop with a
// termination error.
// NOTE(review): std::atoi returns 0 for malformed text without reporting an
// error, so a corrupt "num_nodes" would silently become 0.
175 bool TreeParamHandler::String(
const char *str, std::size_t,
bool) {
177 return (check_cur_key(
"num_feature") ||
178 assign_value(
"num_nodes", std::atoi(str), output) ||
179 check_cur_key(
"size_leaf_vector") || check_cur_key(
"num_deleted"));
// Send each per-node array of a tree (selected by key) to an ArrayHandler that
// fills the matching member vector. "leaf_child_counts" is skipped. The start of
// the return expression is not shown in this listing.
185 bool RegTreeHandler::StartArray() {
189 push_key_handler<ArrayHandler<double>>(
"loss_changes", loss_changes) ||
190 push_key_handler<ArrayHandler<double>>(
"sum_hessian", sum_hessian) ||
191 push_key_handler<ArrayHandler<double>>(
"base_weights", base_weights) ||
// Categorical-split bookkeeping (used by EndObject for categorical nodes).
192 push_key_handler<ArrayHandler<int>>(
"categories_segments", categories_segments) ||
193 push_key_handler<ArrayHandler<int>>(
"categories_sizes", categories_sizes) ||
194 push_key_handler<ArrayHandler<int>>(
"categories_nodes", categories_nodes) ||
195 push_key_handler<ArrayHandler<int>>(
"categories", categories) ||
196 push_key_handler<IgnoreHandler>(
"leaf_child_counts") ||
// Tree topology and split definitions.
197 push_key_handler<ArrayHandler<int>>(
"left_children", left_children) ||
198 push_key_handler<ArrayHandler<int>>(
"right_children", right_children) ||
199 push_key_handler<ArrayHandler<int>>(
"parents", parents) ||
200 push_key_handler<ArrayHandler<int>>(
"split_indices", split_indices) ||
201 push_key_handler<ArrayHandler<int>>(
"split_type", split_type) ||
202 push_key_handler<ArrayHandler<double>>(
"split_conditions", split_conditions) ||
203 push_key_handler<ArrayHandler<bool>>(
"default_left", default_left));
206 bool RegTreeHandler::StartObject() {
207 return push_key_handler<TreeParamHandler, int>(
"tree_param", num_nodes);
210 bool RegTreeHandler::Uint(
unsigned) {
return check_cur_key(
"id"); }
// Called when the tree object closes. Validates the collected per-node arrays,
// then builds the Treelite tree (`output`) from them. Several lines (mostly
// returns and closing braces) are missing from this listing.
212 bool RegTreeHandler::EndObject(std::size_t) {
// Models that provide no "split_type" are treated as all-numerical splits.
214 if (split_type.empty()) {
215 split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
// Every per-node array must have exactly num_nodes entries. A mismatch is logged
// as an ERROR; the statement after each log (presumably `return false;`) is not
// shown.
217 if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
218 TREELITE_LOG(ERROR) <<
"Field loss_changes has an incorrect dimension. Expected: " << num_nodes
219 <<
", Actual: " << loss_changes.size();
222 if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
223 TREELITE_LOG(ERROR) <<
"Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
224 <<
", Actual: " << sum_hessian.size();
227 if (static_cast<size_t>(num_nodes) != base_weights.size()) {
228 TREELITE_LOG(ERROR) <<
"Field base_weights has an incorrect dimension. Expected: " << num_nodes
229 <<
", Actual: " << base_weights.size();
232 if (static_cast<size_t>(num_nodes) != left_children.size()) {
233 TREELITE_LOG(ERROR) <<
"Field left_children has an incorrect dimension. Expected: " << num_nodes
234 <<
", Actual: " << left_children.size();
237 if (static_cast<size_t>(num_nodes) != right_children.size()) {
238 TREELITE_LOG(ERROR) <<
"Field right_children has an incorrect dimension. Expected: " 239 << num_nodes <<
", Actual: " << right_children.size();
242 if (static_cast<size_t>(num_nodes) != parents.size()) {
243 TREELITE_LOG(ERROR) <<
"Field parents has an incorrect dimension. Expected: " << num_nodes
244 <<
", Actual: " << parents.size();
247 if (static_cast<size_t>(num_nodes) != split_indices.size()) {
248 TREELITE_LOG(ERROR) <<
"Field split_indices has an incorrect dimension. Expected: " << num_nodes
249 <<
", Actual: " << split_indices.size();
252 if (static_cast<size_t>(num_nodes) != split_type.size()) {
253 TREELITE_LOG(ERROR) <<
"Field split_type has an incorrect dimension. Expected: " << num_nodes
254 <<
", Actual: " << split_type.size();
257 if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
258 TREELITE_LOG(ERROR) <<
"Field split_conditions has an incorrect dimension. Expected: " 259 << num_nodes <<
", Actual: " << split_conditions.size();
262 if (static_cast<size_t>(num_nodes) != default_left.size()) {
263 TREELITE_LOG(ERROR) <<
"Field default_left has an incorrect dimension. Expected: " << num_nodes
264 <<
", Actual: " << default_left.size();
// NOTE(review): categories_segments, categories_sizes, categories_nodes and
// categories are not length-checked in the visible lines, yet they are indexed
// below for categorical splits. A malformed model could read out of range.
//
// Breadth-first copy of the tree. Q holds (old_id, new_id) pairs that map an
// XGBoost node id to the corresponding node id in `output`. The queue seeding
// and loop header are not shown.
268 std::queue<std::pair<int, int>> Q;
274 std::tie(old_id, new_id) = Q.front();
// A left child of -1 marks a leaf; its value is taken from split_conditions.
277 if (left_children[old_id] == -1) {
278 output.SetLeaf(new_id, split_conditions[old_id]);
// Internal node: allocate both children in the output tree first.
280 output.AddChilds(new_id);
281 if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
// Find this node's record among the categorical splits. Binary search requires
// categories_nodes to be sorted in ascending order -- NOTE(review): confirm
// XGBoost always writes it sorted.
282 auto categorical_split_loc
283 = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
284 TREELITE_CHECK(categorical_split_loc != categories_nodes.end())
285 <<
"Could not find record for the categorical split in node " << old_id;
// categories[offset, offset + num_categories) lists this node's categories. They
// are passed to SetCategoricalSplit as `right_categories`.
286 auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
287 int offset = categories_segments[categorical_split_id];
288 int num_categories = categories_sizes[categorical_split_id];
289 std::vector<uint32_t> right_categories;
290 right_categories.reserve(num_categories);
291 for (
int i = 0; i < num_categories; ++i) {
292 right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
294 output.SetCategoricalSplit(
295 new_id, split_indices[old_id], default_left[old_id], right_categories,
true);
// Numerical split: the threshold is split_conditions[old_id], compared with
// operator kLT.
297 output.SetNumericalSplit(
298 new_id, split_indices[old_id], split_conditions[old_id],
299 default_left[old_id], treelite::Operator::kLT);
// Record the gain and enqueue both children. Each old child id is paired with
// the new child id allocated by AddChilds.
301 output.SetGain(new_id, loss_changes[old_id]);
302 Q.push({left_children[old_id], output.LeftChild(new_id)});
303 Q.push({right_children[old_id], output.RightChild(new_id)});
// Copy the node's sum of Hessians. This appears to apply to leaves and internal
// nodes alike; the surrounding braces are not shown.
305 output.SetSumHess(new_id, sum_hessian[old_id]);
// This tree object is finished; hand control back to the parent handler.
307 return pop_handler();
// Arrays inside the gbtree model. "trees" is collected into output.trees (the
// handler type and the start of the expression are not shown). "tree_info" is
// skipped.
313 bool GBTreeModelHandler::StartArray() {
316 "trees", output.trees) ||
317 push_key_handler<IgnoreHandler>(
"tree_info"));
320 bool GBTreeModelHandler::StartObject() {
321 return push_key_handler<IgnoreHandler>(
"gbtree_model_param");
// The booster "name" is stored in `name`. Only "gbtree" and "dart" are
// supported; any other name logs an ERROR (the return that follows is not
// shown). Part of the parameter list is also missing from this listing.
327 bool GradientBoosterHandler::String(
const char *str,
330 if (assign_value(
"name", std::string{str, length}, name)) {
331 if (name ==
"gbtree" || name ==
"dart") {
334 TREELITE_LOG(ERROR) <<
"Only GBTree or DART boosters are currently supported.";
// Nested objects of the booster. The accepted keys are not shown in this
// listing. An unrecognised key logs an ERROR suggesting a non-GBTree booster.
341 bool GradientBoosterHandler::StartObject() {
349 TREELITE_LOG(ERROR) <<
"Key \"" << get_cur_key()
350 <<
"\" not recognized. Is this a GBTree-type booster?";
354 bool GradientBoosterHandler::StartArray() {
355 return push_key_handler<ArrayHandler<double>, std::vector<double>>(
"weight_drop", weight_drop);
// Called when the booster object closes. For DART boosters with a non-empty
// weight_drop, every leaf value of tree i is scaled by weight_drop[i]; the check
// requires exactly one weight per tree. Control then returns to the parent
// handler. `memberCount` is not used in the visible lines.
357 bool GradientBoosterHandler::EndObject(std::size_t memberCount) {
358 if (name ==
"dart" && !weight_drop.empty()) {
360 TREELITE_CHECK_EQ(output.trees.size(), weight_drop.size());
361 for (
size_t i = 0; i < output.trees.size(); ++i) {
362 for (
int nid = 0; nid < output.trees[i].num_nodes; ++nid) {
363 if (output.trees[i].IsLeaf(nid)) {
364 output.trees[i].SetLeaf(nid, weight_drop[i] * output.trees[i].LeafValue(nid));
369 return pop_handler();
375 bool ObjectiveHandler::StartObject() {
376 return (push_key_handler<IgnoreHandler>(
"reg_loss_param") ||
377 push_key_handler<IgnoreHandler>(
"poisson_regression_param") ||
378 push_key_handler<IgnoreHandler>(
"tweedie_regression_param") ||
379 push_key_handler<IgnoreHandler>(
"softmax_multiclass_param") ||
380 push_key_handler<IgnoreHandler>(
"lambda_rank_param") ||
381 push_key_handler<IgnoreHandler>(
"aft_loss_param"));
384 bool ObjectiveHandler::String(
const char *str, std::size_t length,
bool) {
385 return assign_value(
"name", std::string{str, length}, output);
// learner_model_param values arrive as JSON strings:
//   "base_score"  -> output.param.global_bias        (strtof)
//   "num_class"   -> output.task_param.num_class     (atoi, clamped to >= 1)
//   "num_feature" -> output.num_feature              (atoi)
// NOTE(review): strtof/atoi do not report malformed input; they yield 0, and for
// num_class the clamp then turns that into 1. Each assign_value call converts
// `str` before its key is checked, so conversions for earlier keys run even when
// a later key matches. The rest of the parameter list is not shown.
391 bool LearnerParamHandler::String(
const char *str,
394 return (assign_value(
"base_score", strtof(str,
nullptr),
395 output.param.global_bias) ||
396 assign_value(
"num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
397 output.task_param.num_class) ||
398 assign_value(
"num_feature", std::atoi(str), output.num_feature));
// Nested objects of the learner. "learner_model_param" and "gradient_booster"
// both populate *output.model (their handler types are not shown). "objective"
// is read into `objective`, and "attributes" is skipped.
404 bool LearnerHandler::StartObject() {
407 "learner_model_param", *output.model) ||
409 "gradient_booster", *output.model) ||
410 push_key_handler<ObjectiveHandler, std::string>(
"objective", objective) ||
411 push_key_handler<IgnoreHandler>(
"attributes"));
414 bool LearnerHandler::EndObject(std::size_t) {
415 xgboost::SetPredTransform(objective, &output.model->param);
416 output.objective_name = objective;
417 return pop_handler();
420 bool LearnerHandler::StartArray() {
421 return (push_key_handler<IgnoreHandler>(
"feature_names") ||
422 push_key_handler<IgnoreHandler>(
"feature_types"));
// Top-level arrays are collected as unsigned integers into a
// std::vector<unsigned>. NOTE(review): the key and destination are on a line
// not shown; presumably "version" into `version`, which EndObject reads.
428 bool XGBoostModelHandler::StartArray() {
429 return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
433 bool XGBoostModelHandler::StartObject() {
434 return push_key_handler<LearnerHandler, XGBoostModelHandle>(
"learner", output);
// Finalize the model once the top-level object closes.
437 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
// Exactly two members are expected at this level. Otherwise an ERROR is logged;
// the return that follows is not shown.
438 if (memberCount != 2) {
439 TREELITE_LOG(ERROR) <<
"Expected two members in XGBoostModel";
// Fixed attributes for every model loaded here: tree outputs are not averaged,
// the output type is float, and the leaf vector size is 1.
442 output.model->average_tree_output =
false;
443 output.model->task_param.output_type = TaskParam::OutputType::kFloat;
444 output.model->task_param.leaf_vector_size = 1;
// More than one class: multi-class task with one group of trees per class.
// Otherwise: binary-classification/regression task without per-class groups.
445 if (output.model->task_param.num_class > 1) {
447 output.model->task_type = TaskType::kMultiClfGrovePerClass;
448 output.model->task_param.grove_per_class =
true;
451 output.model->task_type = TaskType::kBinaryClfRegr;
452 output.model->task_param.grove_per_class =
false;
// When version[0] (presumably the XGBoost major version) is >= 1, the global
// bias is converted to margin space. NOTE(review): version[0] is read without a
// visible emptiness check. If the version array can be missing or empty, this
// is an out-of-range read.
456 const bool need_transform_to_margin = (version[0] >= 1);
457 if (need_transform_to_margin) {
458 treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
460 return pop_handler();
// The root JSON object is the XGBoost model itself. Push an XGBoostModelHandler
// that writes into `handle`.
466 bool RootHandler::StartObject() {
468 return push_handler<XGBoostModelHandler, XGBoostModelHandle>(handle);
474 std::unique_ptr<treelite::Model> DelegatedHandler::get_result() {
return std::move(result); }
475 bool DelegatedHandler::Null() {
return delegates.top()->Null(); }
476 bool DelegatedHandler::Bool(
bool b) {
return delegates.top()->Bool(b); }
477 bool DelegatedHandler::Int(
int i) {
return delegates.top()->Int(i); }
478 bool DelegatedHandler::Uint(
unsigned u) {
return delegates.top()->Uint(u); }
479 bool DelegatedHandler::Int64(int64_t i) {
return delegates.top()->Int64(i); }
480 bool DelegatedHandler::Uint64(uint64_t u) {
return delegates.top()->Uint64(u); }
481 bool DelegatedHandler::Double(
double d) {
return delegates.top()->Double(d); }
482 bool DelegatedHandler::String(
const char *str, std::size_t length,
bool copy) {
483 return delegates.top()->String(str, length, copy);
485 bool DelegatedHandler::StartObject() {
return delegates.top()->StartObject(); }
486 bool DelegatedHandler::Key(
const char *str, std::size_t length,
bool copy) {
487 return delegates.top()->Key(str, length, copy);
489 bool DelegatedHandler::EndObject(std::size_t memberCount) {
490 return delegates.top()->EndObject(memberCount);
492 bool DelegatedHandler::StartArray() {
return delegates.top()->StartArray(); }
493 bool DelegatedHandler::EndArray(std::size_t elementCount) {
494 return delegates.top()->EndArray(elementCount);
// Generic driver shared by both loaders. It feeds `input_stream` through a
// rapidjson::Reader into a DelegatedHandler and returns the model that handler
// builds. `error_handler(offset)` is used only to produce a context snippet for
// the fatal error message. The handler's initialiser is on a line not shown;
// presumably DelegatedHandler::create(), which starts with a RootHandler on its
// stack.
502 template <
typename StreamType,
typename ErrorHandlerFunc>
503 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
504 ErrorHandlerFunc error_handler) {
505 std::shared_ptr<treelite::details::DelegatedHandler> handler =
507 rapidjson::Reader reader;
// kParseNanAndInfFlag makes RapidJSON accept NaN / Inf / Infinity literals as
// doubles (presumably because XGBoost may write non-finite values).
509 rapidjson::ParseResult result
510 = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
// On failure (the guarding condition is not shown): report the byte offset, the
// RapidJSON error message, and the snippet returned by error_handler.
512 const auto error_code = result.Code();
513 const size_t offset = result.Offset();
514 std::string diagnostic = error_handler(offset);
515 TREELITE_LOG(FATAL) <<
"Provided JSON could not be parsed as XGBoost model. " 516 <<
"Parsing error at offset " << offset <<
": " 517 << rapidjson::GetParseError_En(error_code) <<
"\n" << diagnostic;
519 return handler->get_result();
Some useful math utilities.
Collection of front-end methods to load or construct ensemble model.
model structure for tree ensemble
logging facility for Treelite
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
load an XGBoost model from a JSON string
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble in the JSON format.