11 #include <rapidjson/error/en.h> 12 #include <rapidjson/document.h> 13 #include <rapidjson/filereadstream.h>
// Forward declaration of the shared parsing driver used by both public loaders below.
// It runs a rapidjson::Reader over `input_stream`; when parsing fails it calls
// `error_handler(offset)` to build a diagnostic snippet for the error message.
// The definition appears later in this file.
32 template <
typename StreamType,
typename ErrorHandlerFunc>
33 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
34 ErrorHandlerFunc error_handler);
// Body of LoadXGBoostJSONModel(filename): stream the file through a 64 KiB buffer into
// ParseStream. (The function header and several lines are not visible in this listing.)
42 char read_buffer[65536];
// NOTE(review): the two fopen lines below appear to be alternatives chosen by a
// preprocessor condition that is not visible here: binary mode "rb" vs. text mode "r".
45 FILE* fp = std::fopen(filename,
"rb");
47 FILE* fp = std::fopen(filename,
"r");
// If the file cannot be opened (the check itself is not visible here), a FATAL log
// reports the filename together with strerror(errno).
50 TREELITE_LOG(FATAL) <<
"Failed to open file '" << filename <<
"': " << std::strerror(errno);
53 auto input_stream = std::make_unique<rapidjson::FileReadStream>(
54 fp, read_buffer,
sizeof(read_buffer));
// Diagnostic builder for parse errors. It re-reads up to 100 bytes of the same file,
// starting 50 bytes before the error offset (clamped at 0), so the error message can show
// the surrounding JSON text. oss collects the raw characters. oss2 is returned on a second
// line; presumably it is a marker line pointing at the error position (TODO confirm; the
// lines that fill it are not visible here).
55 auto error_handler = [fp](
size_t offset) -> std::string {
56 size_t cur = (offset >= 50 ? (offset - 50) : 0);
// NOTE(review): std::fseek takes a `long` offset. `cur` is size_t, so offsets above
// LONG_MAX (e.g. files > 2 GiB where long is 32-bit) would be narrowed.
57 std::fseek(fp, cur, SEEK_SET);
59 std::ostringstream oss, oss2;
60 for (
int i = 0; i < 100; ++i) {
65 oss << static_cast<char>(c);
74 return oss.str() +
"\n" + oss2.str();
76 auto parsed_model = ParseStream(std::move(input_stream), error_handler);
// Body of LoadXGBoostJSONModelString(json_str, length): parse an in-memory JSON buffer of
// `length` bytes. (The function header is not visible in this listing.)
82 auto input_stream = std::make_unique<rapidjson::MemoryStream>(json_str, length);
// Diagnostic builder for parse errors. It reads up to 100 characters of json_str starting
// 50 characters before the error offset (clamped at 0).
// NOTE(review): the lambda captures only `json_str`, not `length`. Any end-of-buffer stop
// in the (not visible) loop body must therefore rely on something else, such as a NUL
// terminator. A buffer that is not NUL-terminated could then be over-read. Confirm.
83 auto error_handler = [json_str](
size_t offset) -> std::string {
84 size_t cur = (offset >= 50 ? (offset - 50) : 0);
85 std::ostringstream oss, oss2;
86 for (
int i = 0; i < 100; ++i) {
98 return oss.str() +
"\n" + oss2.str();
100 return ParseStream(std::move(input_stream), error_handler);
/*!
 * \brief Ask the owning delegator, if it still exists, to pop its current delegate
 *        via pop_delegate().
 *
 * `delegator.lock()` suggests `delegator` is a std::weak_ptr, so an owner that has
 * already been destroyed is tolerated (TODO confirm). The return statements are not
 * visible in this listing.
 */
111 bool BaseHandler::pop_handler() {
112 if (
auto parent = delegator.lock()) {
113 parent->pop_delegate();
120 void BaseHandler::set_cur_key(
const char *str, std::size_t length) {
121 cur_key = std::string{str, length};
124 const std::string &BaseHandler::get_cur_key() {
return cur_key; }
126 bool BaseHandler::check_cur_key(
const std::string &query_key) {
127 return cur_key == query_key;
/*!
 * \brief If the current key equals `key`, store a value into an output slot.
 *
 * Only the signature start and the key check are visible in this listing.
 * Presumably it assigns and returns true when the key matches, and returns false
 * otherwise (TODO confirm against the full definition).
 */
130 template <
typename ValueType>
131 bool BaseHandler::assign_value(
const std::string &key,
134 if (check_cur_key(key)) {
/*!
 * \brief Overload of assign_value that takes the value by const reference.
 *
 * It guards on the current key matching `key`. The output parameter and the
 * assign/return statements are not visible in this listing (TODO confirm).
 */
142 template <
typename ValueType>
143 bool BaseHandler::assign_value(
const std::string &key,
144 const ValueType &value,
146 if (check_cur_key(key)) {
157 bool IgnoreHandler::Null() {
return true; }
158 bool IgnoreHandler::Bool(
bool) {
return true; }
159 bool IgnoreHandler::Int(
int) {
return true; }
160 bool IgnoreHandler::Uint(
unsigned) {
return true; }
161 bool IgnoreHandler::Int64(int64_t) {
return true; }
162 bool IgnoreHandler::Uint64(uint64_t) {
return true; }
163 bool IgnoreHandler::Double(
double) {
return true; }
// Strings are accepted and discarded (the body is not visible in this listing).
164 bool IgnoreHandler::String(
const char *, std::size_t,
bool) {
// A nested object is handled by pushing another IgnoreHandler, presumably so the whole
// nested value is consumed without being stored.
166 bool IgnoreHandler::StartObject() {
return push_handler<IgnoreHandler>(); }
// Keys inside ignored objects are accepted (the body is not visible in this listing).
167 bool IgnoreHandler::Key(
const char *, std::size_t,
bool) {
// Nested arrays are skipped the same way as nested objects.
169 bool IgnoreHandler::StartArray() {
return push_handler<IgnoreHandler>(); }
174 bool TreeParamHandler::String(
const char *str, std::size_t,
bool) {
176 return (check_cur_key(
"num_feature") ||
177 assign_value(
"num_nodes", std::atoi(str), output) ||
178 check_cur_key(
"size_leaf_vector") || check_cur_key(
"num_deleted"));
/*!
 * \brief Route each per-node array of a tree to a handler for it.
 *
 * Each array is given an ArrayHandler that fills the member vector of the same name.
 * "leaf_child_counts" is skipped via IgnoreHandler. The keys are tried in order, and
 * `||` short-circuits at the first one that matches the current key.
 * NOTE(review): the lines between the signature and the first push_key_handler
 * (presumably the opening `return (`) are not visible in this listing.
 */
184 bool RegTreeHandler::StartArray() {
188 push_key_handler<ArrayHandler<double>>(
"loss_changes", loss_changes) ||
189 push_key_handler<ArrayHandler<double>>(
"sum_hessian", sum_hessian) ||
190 push_key_handler<ArrayHandler<double>>(
"base_weights", base_weights) ||
191 push_key_handler<ArrayHandler<int>>(
"categories_segments", categories_segments) ||
192 push_key_handler<ArrayHandler<int>>(
"categories_sizes", categories_sizes) ||
193 push_key_handler<ArrayHandler<int>>(
"categories_nodes", categories_nodes) ||
194 push_key_handler<ArrayHandler<int>>(
"categories", categories) ||
195 push_key_handler<IgnoreHandler>(
"leaf_child_counts") ||
196 push_key_handler<ArrayHandler<int>>(
"left_children", left_children) ||
197 push_key_handler<ArrayHandler<int>>(
"right_children", right_children) ||
198 push_key_handler<ArrayHandler<int>>(
"parents", parents) ||
199 push_key_handler<ArrayHandler<int>>(
"split_indices", split_indices) ||
200 push_key_handler<ArrayHandler<int>>(
"split_type", split_type) ||
201 push_key_handler<ArrayHandler<double>>(
"split_conditions", split_conditions) ||
202 push_key_handler<ArrayHandler<bool>>(
"default_left", default_left));
205 bool RegTreeHandler::StartObject() {
206 return push_key_handler<TreeParamHandler, int>(
"tree_param", num_nodes);
209 bool RegTreeHandler::Uint(
unsigned) {
return check_cur_key(
"id"); }
/*!
 * \brief Finish one tree object and build the matching treelite tree in `output`.
 *
 * First it validates the per-node arrays collected from the JSON. Each of them must
 * hold exactly num_nodes entries. A mismatch is logged at ERROR level; the statement
 * after each log call is not visible in this listing (presumably `return false;` to
 * abort parsing, TODO confirm).
 */
211 bool RegTreeHandler::EndObject(std::size_t) {
// When no "split_type" array was supplied, treat every node as a numerical split.
213 if (split_type.empty()) {
214 split_type.resize(num_nodes, details::xgboost::FeatureType::kNumerical);
216 if (static_cast<size_t>(num_nodes) != loss_changes.size()) {
217 TREELITE_LOG(ERROR) <<
"Field loss_changes has an incorrect dimension. Expected: " << num_nodes
218 <<
", Actual: " << loss_changes.size();
221 if (static_cast<size_t>(num_nodes) != sum_hessian.size()) {
222 TREELITE_LOG(ERROR) <<
"Field sum_hessian has an incorrect dimension. Expected: " << num_nodes
223 <<
", Actual: " << sum_hessian.size();
// NOTE(review): base_weights and parents are dimension-checked here but are not read
// anywhere in the visible lines of this function.
226 if (static_cast<size_t>(num_nodes) != base_weights.size()) {
227 TREELITE_LOG(ERROR) <<
"Field base_weights has an incorrect dimension. Expected: " << num_nodes
228 <<
", Actual: " << base_weights.size();
231 if (static_cast<size_t>(num_nodes) != left_children.size()) {
232 TREELITE_LOG(ERROR) <<
"Field left_children has an incorrect dimension. Expected: " << num_nodes
233 <<
", Actual: " << left_children.size();
236 if (static_cast<size_t>(num_nodes) != right_children.size()) {
237 TREELITE_LOG(ERROR) <<
"Field right_children has an incorrect dimension. Expected: " 238 << num_nodes <<
", Actual: " << right_children.size();
241 if (static_cast<size_t>(num_nodes) != parents.size()) {
242 TREELITE_LOG(ERROR) <<
"Field parents has an incorrect dimension. Expected: " << num_nodes
243 <<
", Actual: " << parents.size();
246 if (static_cast<size_t>(num_nodes) != split_indices.size()) {
247 TREELITE_LOG(ERROR) <<
"Field split_indices has an incorrect dimension. Expected: " << num_nodes
248 <<
", Actual: " << split_indices.size();
251 if (static_cast<size_t>(num_nodes) != split_type.size()) {
252 TREELITE_LOG(ERROR) <<
"Field split_type has an incorrect dimension. Expected: " << num_nodes
253 <<
", Actual: " << split_type.size();
256 if (static_cast<size_t>(num_nodes) != split_conditions.size()) {
257 TREELITE_LOG(ERROR) <<
"Field split_conditions has an incorrect dimension. Expected: " 258 << num_nodes <<
", Actual: " << split_conditions.size();
261 if (static_cast<size_t>(num_nodes) != default_left.size()) {
262 TREELITE_LOG(ERROR) <<
"Field default_left has an incorrect dimension. Expected: " << num_nodes
263 <<
", Actual: " << default_left.size();
// Breadth-first copy into `output`. Each queue entry is (old_id, new_id): the node index in
// the XGBoost arrays and the matching node index in `output`. The lines that seed the
// queue and start the loop are not visible here.
267 std::queue<std::pair<int, int>> Q;
273 std::tie(old_id, new_id) = Q.front();
// A left child of -1 marks a leaf. For leaves, split_conditions holds the leaf value.
// NOTE(review): child indices taken from the JSON are used as array indices with no
// visible range check against num_nodes, so malformed input could index out of bounds.
276 if (left_children[old_id] == -1) {
277 output.SetLeaf(new_id, split_conditions[old_id]);
279 output.AddChilds(new_id);
280 if (split_type[old_id] == details::xgboost::FeatureType::kCategorical) {
// categories_nodes is searched with a binary search, so it must be sorted in ascending
// order (TODO confirm that the XGBoost writer guarantees this).
281 auto categorical_split_loc
282 = math::binary_search(categories_nodes.begin(), categories_nodes.end(), old_id);
283 TREELITE_CHECK(categorical_split_loc != categories_nodes.end())
284 <<
"Could not find record for the categorical split in node " << old_id;
// The categories of this split occupy categories[offset, offset + num_categories).
// NOTE(review): categories_segments, categories_sizes and categories are indexed without
// bounds checks, so inconsistent array lengths in the input would read out of bounds.
285 auto categorical_split_id = std::distance(categories_nodes.begin(), categorical_split_loc);
286 int offset = categories_segments[categorical_split_id];
287 int num_categories = categories_sizes[categorical_split_id];
288 std::vector<uint32_t> right_categories;
289 right_categories.reserve(num_categories);
290 for (
int i = 0; i < num_categories; ++i) {
291 right_categories.push_back(static_cast<uint32_t>(categories[offset + i]));
// The trailing `true` presumably marks right_categories as the categories sent to the
// right child (confirm against Tree::SetCategoricalSplit).
293 output.SetCategoricalSplit(
294 new_id, split_indices[old_id], default_left[old_id], right_categories,
true);
// Numerical split: the test is `feature < split_condition` (Operator::kLT).
296 output.SetNumericalSplit(
297 new_id, split_indices[old_id], split_conditions[old_id],
298 default_left[old_id], treelite::Operator::kLT);
// This appears to run only for split nodes: record the gain, then queue both children
// with their new ids.
300 output.SetGain(new_id, loss_changes[old_id]);
301 Q.push({left_children[old_id], output.LeftChild(new_id)});
302 Q.push({right_children[old_id], output.RightChild(new_id)});
// Judging by its position after the leaf/split branches, the hessian sum is recorded for
// every node taken from the queue.
304 output.SetSumHess(new_id, sum_hessian[old_id]);
306 return pop_handler();
/*!
 * \brief Handle the "trees" array (into output.model->trees) and the "tree_info" array
 *        (into output.tree_info).
 *
 * The lines that push the handler for "trees" are not visible in this listing.
 */
312 bool GBTreeModelHandler::StartArray() {
315 "trees", output.model->trees) ||
316 push_key_handler<ArrayHandler<int>, std::vector<int>>(
"tree_info", output.tree_info));
319 bool GBTreeModelHandler::StartObject() {
320 return push_key_handler<IgnoreHandler>(
"gbtree_model_param");
/*!
 * \brief Record the booster "name" and accept only "gbtree" and "dart".
 *
 * Other names produce an ERROR log. The remaining lines of the signature and the
 * return statements are not visible in this listing.
 */
326 bool GradientBoosterHandler::String(
const char *str,
329 if (assign_value(
"name", std::string{str, length}, name)) {
330 if (name ==
"gbtree" || name ==
"dart") {
333 TREELITE_LOG(ERROR) <<
"Only GBTree or DART boosters are currently supported.";
/*!
 * \brief Dispatch the booster's nested objects.
 *
 * "model" goes to a GBTreeModelHandler. "gbtree" is handled recursively by another
 * GradientBoosterHandler; presumably this is the layout where a DART booster wraps a
 * gbtree (TODO confirm). Any other key is logged as unrecognized. The return statements
 * are not visible in this listing.
 */
340 bool GradientBoosterHandler::StartObject() {
341 if (push_key_handler<GBTreeModelHandler, ParsedXGBoostModel>(
"model", output)) {
343 }
else if (push_key_handler<GradientBoosterHandler, ParsedXGBoostModel>(
"gbtree", output)) {
347 TREELITE_LOG(ERROR) <<
"Key \"" << get_cur_key()
348 <<
"\" not recognized. Is this a GBTree-type booster?";
352 bool GradientBoosterHandler::StartArray() {
353 return push_key_handler<ArrayHandler<double>, std::vector<double>>(
"weight_drop", weight_drop);
/*!
 * \brief Finish the booster object.
 *
 * For DART models with a non-empty weight_drop, every leaf value of tree i is multiplied
 * by weight_drop[i]. The check enforces exactly one weight per tree. `memberCount` is not
 * used in the visible lines. Some closing lines are not visible in this listing.
 */
355 bool GradientBoosterHandler::EndObject(std::size_t memberCount) {
356 if (name ==
"dart" && !weight_drop.empty()) {
358 auto& trees = output.model->trees;
359 TREELITE_CHECK_EQ(trees.size(), weight_drop.size());
360 for (
size_t i = 0; i < trees.size(); ++i) {
361 for (
int nid = 0; nid < trees[i].num_nodes; ++nid) {
362 if (trees[i].IsLeaf(nid)) {
363 trees[i].SetLeaf(nid, weight_drop[i] * trees[i].LeafValue(nid));
368 return pop_handler();
374 bool ObjectiveHandler::StartObject() {
375 return (push_key_handler<IgnoreHandler>(
"reg_loss_param") ||
376 push_key_handler<IgnoreHandler>(
"poisson_regression_param") ||
377 push_key_handler<IgnoreHandler>(
"tweedie_regression_param") ||
378 push_key_handler<IgnoreHandler>(
"softmax_multiclass_param") ||
379 push_key_handler<IgnoreHandler>(
"lambda_rank_param") ||
380 push_key_handler<IgnoreHandler>(
"aft_loss_param") ||
381 push_key_handler<IgnoreHandler>(
"pseduo_huber_param") ||
382 push_key_handler<IgnoreHandler>(
"pseudo_huber_param"));
385 bool ObjectiveHandler::String(
const char *str, std::size_t length,
bool) {
386 return assign_value(
"name", std::string{str, length}, output);
/*!
 * \brief Handle the string-valued members of "learner_model_param".
 *
 * XGBoost stores these numbers as JSON strings:
 *  - "num_target" must be 1 (multi-target regressors are rejected with a log message).
 *  - "base_score" is parsed with strtof into param.global_bias.
 *  - "num_class" is clamped to at least 1.
 *  - "num_feature" is parsed with std::atoi.
 *  - "boost_from_average" is accepted and ignored.
 * Parts of the signature and the num_target branch are not visible in this listing.
 * NOTE(review): std::atoi and strtof do not report malformed input; non-numeric text
 * silently becomes 0, and out-of-range input to atoi is undefined behavior. strtof is
 * used unqualified here, unlike std::atoi.
 */
392 bool LearnerParamHandler::String(
const char *str,
396 if (assign_value(
"num_target", std::atoi(str), num_target)) {
397 if (num_target != 1) {
399 <<
"num_target must be 1; Treelite doesn't support multi-target regressor yet";
404 return (assign_value(
"base_score", strtof(str,
nullptr),
405 output.param.global_bias) ||
406 assign_value(
"num_class", static_cast<unsigned int>(std::max(std::atoi(str), 1)),
407 output.task_param.num_class) ||
408 assign_value(
"num_feature", std::atoi(str), output.num_feature) ||
409 check_cur_key(
"boost_from_average"));
/*!
 * \brief Dispatch the nested objects of "learner".
 *
 * "learner_model_param" goes to the model. "gradient_booster" goes to the parsed result.
 * "objective" fills the objective name. "attributes" is skipped. The lines that open the
 * return expression are not visible in this listing.
 */
415 bool LearnerHandler::StartObject() {
418 "learner_model_param", *output.model) ||
419 push_key_handler<GradientBoosterHandler, ParsedXGBoostModel>(
420 "gradient_booster", output) ||
421 push_key_handler<ObjectiveHandler, std::string>(
"objective", objective) ||
422 push_key_handler<IgnoreHandler>(
"attributes"));
425 bool LearnerHandler::EndObject(std::size_t) {
426 xgboost::SetPredTransform(objective, &output.model->param);
427 output.objective_name = objective;
428 return pop_handler();
431 bool LearnerHandler::StartArray() {
432 return (push_key_handler<IgnoreHandler>(
"feature_names") ||
433 push_key_handler<IgnoreHandler>(
"feature_types"));
440 bool XGBoostCheckpointHandler::StartArray() {
441 return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
442 "version", output.version);
445 bool XGBoostCheckpointHandler::StartObject() {
446 return push_key_handler<LearnerHandler, ParsedXGBoostModel>(
"learner", output);
452 bool XGBoostModelHandler::StartArray() {
453 return push_key_handler<ArrayHandler<unsigned>, std::vector<unsigned>>(
454 "version", output.version);
457 bool XGBoostModelHandler::StartObject() {
458 return (push_key_handler<LearnerHandler, ParsedXGBoostModel>(
"learner", output) ||
459 push_key_handler<IgnoreHandler>(
"Config") ||
460 push_key_handler<XGBoostCheckpointHandler, ParsedXGBoostModel>(
"Model", output));
463 bool XGBoostModelHandler::EndObject(std::size_t memberCount) {
464 if (memberCount != 2) {
465 TREELITE_LOG(ERROR) <<
"Expected two members in XGBoostModel";
468 output.model->average_tree_output =
false;
469 output.model->task_param.output_type = TaskParam::OutputType::kFloat;
470 output.model->task_param.leaf_vector_size = 1;
471 if (output.model->task_param.num_class > 1) {
473 output.model->task_type = TaskType::kMultiClfGrovePerClass;
474 output.model->task_param.grove_per_class =
true;
477 output.model->task_type = TaskType::kBinaryClfRegr;
478 output.model->task_param.grove_per_class =
false;
482 const bool need_transform_to_margin = (output.version[0] >= 1);
483 if (need_transform_to_margin) {
484 treelite::details::xgboost::TransformGlobalBiasToMargin(&output.model->param);
486 return pop_handler();
492 bool RootHandler::StartObject() {
493 return push_handler<XGBoostModelHandler, ParsedXGBoostModel>(output);
499 ParsedXGBoostModel DelegatedHandler::get_result() {
return std::move(result); }
500 bool DelegatedHandler::Null() {
return delegates.top()->Null(); }
501 bool DelegatedHandler::Bool(
bool b) {
return delegates.top()->Bool(b); }
502 bool DelegatedHandler::Int(
int i) {
return delegates.top()->Int(i); }
503 bool DelegatedHandler::Uint(
unsigned u) {
return delegates.top()->Uint(u); }
504 bool DelegatedHandler::Int64(int64_t i) {
return delegates.top()->Int64(i); }
505 bool DelegatedHandler::Uint64(uint64_t u) {
return delegates.top()->Uint64(u); }
506 bool DelegatedHandler::Double(
double d) {
return delegates.top()->Double(d); }
507 bool DelegatedHandler::String(
const char *str, std::size_t length,
bool copy) {
508 return delegates.top()->String(str, length, copy);
510 bool DelegatedHandler::StartObject() {
return delegates.top()->StartObject(); }
511 bool DelegatedHandler::Key(
const char *str, std::size_t length,
bool copy) {
512 return delegates.top()->Key(str, length, copy);
514 bool DelegatedHandler::EndObject(std::size_t memberCount) {
515 return delegates.top()->EndObject(memberCount);
517 bool DelegatedHandler::StartArray() {
return delegates.top()->StartArray(); }
518 bool DelegatedHandler::EndArray(std::size_t elementCount) {
519 return delegates.top()->EndArray(elementCount);
/*!
 * \brief Parse an XGBoost JSON model from `input_stream` and return the treelite model.
 * \param input_stream rapidjson input stream to read from (the loaders in this file pass
 *        a FileReadStream or a MemoryStream)
 * \param error_handler callable that takes the byte offset of a parse error and returns a
 *        diagnostic snippet, which is appended to the FATAL error message
 * Several lines of this definition are not visible in this listing.
 */
527 template <
typename StreamType,
typename ErrorHandlerFunc>
528 std::unique_ptr<treelite::Model> ParseStream(std::unique_ptr<StreamType> input_stream,
529 ErrorHandlerFunc error_handler) {
530 std::shared_ptr<treelite::details::DelegatedHandler> handler =
532 rapidjson::Reader reader;
// kParseNanAndInfFlag makes the reader accept NaN/Inf/Infinity literals, which strict
// JSON does not allow.
534 rapidjson::ParseResult result
535 = reader.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(*input_stream, *handler);
// Parse failure (the enclosing condition is not visible here) is reported as a FATAL log.
// The message contains the offset, rapidjson's English error text, and the diagnostic
// snippet from error_handler.
537 const auto error_code = result.Code();
538 const size_t offset = result.Offset();
539 std::string diagnostic = error_handler(offset);
540 TREELITE_LOG(FATAL) <<
"Provided JSON could not be parsed as XGBoost model. " 541 <<
"Parsing error at offset " << offset <<
": " 542 << rapidjson::GetParseError_En(error_code) <<
"\n" << diagnostic;
// num_parallel_tree is derived from tree_info (the loop body is not visible here).
549 unsigned num_parallel_tree = 0;
550 for (
int e : parsed.tree_info) {
// With num_parallel_tree > 1, reorder the trees so that trees with the same index modulo
// num_parallel_tree are contiguous: every tree with index % npt == 0 comes first, then
// % npt == 1, and so on. The CHECK confirms that no tree was dropped.
// NOTE(review): new_trees could reserve(num_tree) up front to avoid reallocations.
556 if (num_parallel_tree > 1) {
564 std::vector<treelite::Tree<float, float>> new_trees;
565 std::size_t num_tree = parsed.model->
trees.size();
566 for (std::size_t c = 0; c < num_parallel_tree; ++c) {
567 for (std::size_t tree_id = c; tree_id < num_tree; tree_id += num_parallel_tree) {
568 new_trees.push_back(std::move(parsed.model->
trees[tree_id]));
571 TREELITE_CHECK_EQ(new_trees.size(), num_tree);
572 parsed.model->
trees = std::move(new_trees);
// model_ptr is a member of `parsed`, so an explicit std::move is needed to transfer ownership.
575 return std::move(parsed.model_ptr);
Some useful math utilities.
Collection of front-end methods to load or construct an ensemble model.
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Model structure for a tree ensemble.
Logging facility for Treelite.
unsigned int num_class
The number of classes in the target label.
std::vector< Tree< ThresholdType, LeafOutputType > > trees
Member trees.
static std::shared_ptr< DelegatedHandler > create()
Create a DelegatedHandler with an initial RootHandler on its stack.
TaskParam task_param
Group of parameters that are specific to the particular task type.
Methods for loading models from XGBoost-style JSON.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModelString(const char *json_str, size_t length)
Load an XGBoost model from a JSON string.
Helper functions for loading XGBoost models.
std::unique_ptr< treelite::Model > LoadXGBoostJSONModel(const char *filename)
Load a model file generated by XGBoost (dmlc/xgboost). The model file must contain a decision tree ensemble in the JSON format.