8 #include <rapidjson/document.h> 16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 29 virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
34 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
40 explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41 delegator{parent_delegator} {};
43 virtual bool Null() {
return false; }
44 virtual bool Bool(
bool) {
return false; }
45 virtual bool Int(
int) {
return false; }
46 virtual bool Uint(
unsigned) {
return false; }
47 virtual bool Int64(int64_t) {
return false; }
48 virtual bool Uint64(uint64_t) {
return false; }
49 virtual bool Double(
double) {
return false; }
50 virtual bool String(
const char *, std::size_t,
bool) {
53 virtual bool StartObject() {
return false; }
54 virtual bool Key(
const char *str, std::size_t length,
bool) {
55 set_cur_key(str, length);
58 virtual bool EndObject(std::size_t) {
return pop_handler(); }
59 virtual bool StartArray() {
return false; }
60 virtual bool EndArray(std::size_t) {
return pop_handler(); }
66 template <
typename HandlerType,
typename... ArgsTypes>
67 bool push_handler(ArgsTypes &... args) {
68 if (
auto parent = BaseHandler::delegator.lock()) {
69 parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
81 template <
typename HandlerType,
typename... ArgsTypes>
82 bool push_key_handler(std::string key, ArgsTypes &... args) {
83 if (check_cur_key(key)) {
84 push_handler<HandlerType, ArgsTypes...>(args...);
96 void set_cur_key(
const char *str, std::size_t length);
98 const std::string &get_cur_key();
102 bool check_cur_key(
const std::string &query_key);
108 template <
typename ValueType>
109 bool assign_value(
const std::string &key,
112 template <
typename ValueType>
113 bool assign_value(
const std::string &key,
114 const ValueType &value,
119 std::weak_ptr<Delegator> delegator;
128 bool Null()
override;
129 bool Bool(
bool b)
override;
130 bool Int(
int i)
override;
131 bool Uint(
unsigned u)
override;
132 bool Int64(int64_t i)
override;
133 bool Uint64(uint64_t u)
override;
134 bool Double(
double d)
override;
135 bool String(
const char *str, std::size_t length,
bool copy)
override;
136 bool StartObject()
override;
137 bool Key(
const char *str, std::size_t length,
bool copy)
override;
138 bool StartArray()
override;
150 OutputType &output_param)
151 :
BaseHandler{parent_delegator}, output{output_param} {};
153 OutputType &&output) =
delete;
161 template <
typename ElemType,
typename HandlerType = BaseHandler>
168 bool Bool(ElemType b) {
169 this->output.push_back(b);
172 template <
typename ArgType,
typename IntType = ElemType>
173 typename std::enable_if<std::is_integral<IntType>::value,
bool>::type
174 store_int(ArgType i) {
175 this->output.push_back(static_cast<ElemType>(i));
179 template <
typename ArgType,
typename IntType = ElemType>
180 typename std::enable_if<!std::is_integral<IntType>::value,
bool>::type
185 bool Int(
int i)
override {
return store_int<int>(i); }
186 bool Uint(
unsigned u)
override {
return store_int<unsigned>(u); }
187 bool Int64(int64_t i)
override {
return store_int<int64_t>(i); }
188 bool Uint64(uint64_t u)
override {
return store_int<uint64_t>(u); }
192 bool Double(ElemType d) {
193 this->output.push_back(d);
197 template <
typename StringType = ElemType>
198 typename std::enable_if<std::is_same<StringType, std::string>::value,
200 store_string(
const char *str, std::size_t length,
bool copy) {
201 this->output.push_back(ElemType{str, length});
204 template <
typename StringType = ElemType>
205 typename std::enable_if<!std::is_same<StringType, std::string>::value,
207 store_string(
const char *, std::size_t,
bool) {
211 bool String(
const char *str, std::size_t length,
bool copy)
override {
212 return store_string(str, length, copy);
215 bool StartObject(std::true_type) {
216 this->output.emplace_back();
217 return this->
template push_handler<HandlerType, ElemType>(
218 this->output.back());
221 bool StartObject(std::false_type) {
return false; }
223 bool StartObject()
override {
226 HandlerType>::value>{});
231 std::unique_ptr<treelite::Model> model_ptr;
233 std::vector<unsigned> version;
234 std::vector<int> tree_info;
235 std::string objective_name;
243 bool String(
const char *str, std::size_t length,
bool copy)
override;
250 bool StartArray()
override;
252 bool StartObject()
override;
254 bool Uint(
unsigned u)
override;
256 bool EndObject(std::size_t memberCount)
override;
259 std::vector<double> loss_changes;
260 std::vector<double> sum_hessian;
261 std::vector<double> base_weights;
262 std::vector<int> left_children;
263 std::vector<int> right_children;
264 std::vector<int> parents;
265 std::vector<int> split_indices;
266 std::vector<int> split_type;
267 std::vector<int> categories_segments;
268 std::vector<int> categories_sizes;
269 std::vector<int> categories_nodes;
270 std::vector<int> categories;
271 std::vector<double> split_conditions;
272 std::vector<bool> default_left;
279 bool StartArray()
override;
280 bool StartObject()
override;
287 bool String(
const char *str, std::size_t length,
bool copy)
override;
288 bool StartArray()
override;
289 bool StartObject()
override;
290 bool EndObject(std::size_t memberCount)
override;
293 std::vector<double> weight_drop;
300 bool StartObject()
override;
302 bool String(
const char *str, std::size_t length,
bool copy)
override;
309 bool String(
const char *str, std::size_t length,
bool copy)
override;
316 bool StartObject()
override;
317 bool EndObject(std::size_t memberCount)
override;
318 bool StartArray()
override;
321 std::string objective;
328 bool StartArray()
override;
329 bool StartObject()
override;
336 bool StartArray()
override;
337 bool StartObject()
override;
338 bool EndObject(std::size_t memberCount)
override;
345 bool StartObject()
override;
350 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
356 std::shared_ptr<DelegatedHandler> new_handler =
357 std::make_shared<make_shared_enabler>();
362 static std::shared_ptr<DelegatedHandler>
create() {
363 std::shared_ptr<DelegatedHandler> new_handler = create_empty();
364 new_handler->push_delegate(std::make_shared<RootHandler>(
366 new_handler->result));
374 std::shared_ptr<BaseHandler> new_delegate)
override {
375 delegates.push(new_delegate);
387 bool Uint(
unsigned u);
388 bool Int64(int64_t i);
389 bool Uint64(uint64_t u);
390 bool Double(
double d);
391 bool String(
const char *str, std::size_t length,
bool copy);
393 bool Key(
const char *str, std::size_t length,
bool copy);
394 bool EndObject(std::size_t memberCount);
396 bool EndArray(std::size_t elementCount);
401 result{treelite::Model::Create<float, float>(),
nullptr, {}, {},
""}
406 std::stack<std::shared_ptr<BaseHandler>> delegates;
412 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ void pop_delegate() override
pop handler off of stack, returning parsing responsibility to previous handler on stack ...
class for handling delegation of JSON handling
handler for ObjectiveHandler objects from XGBoost schema
handler for XGBoostModel objects from XGBoost schema
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
handler which delegates JSON parsing to stack of delegates
handler for array of objects of given type
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
base class for parsing all JSON objects
static std::shared_ptr< DelegatedHandler > create_empty()
create DelegatedHandler with empty stack
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
base handler for updating some output object
handler for XGBoost checkpoint
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
handler for TreeParam objects from XGBoost schema
handler for Learner objects from XGBoost schema
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
handler for root object of XGBoost schema
JSON handler that ignores all delegated input.
handler for GradientBoosterHandler objects from XGBoost schema