8 #include <rapidjson/document.h> 16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 29 virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
34 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
// Construct a handler that will be added to the given delegator's stack.
// Only a std::weak_ptr to the delegator is stored (member `delegator`), so a
// handler never keeps its delegator alive; callers that need the delegator
// must lock() it first, as push_handler does.
40 explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41 delegator{parent_delegator} {};
// Default scalar-value callbacks. Each one unconditionally returns false, so
// the base handler accepts none of these token types. A subclass that expects
// a particular kind of value overrides the matching method.
// (Per rapidjson's SAX Handler contract, returning false from a callback makes
// the Reader stop with a termination error.)
43 virtual bool Null() {
return false; }
44 virtual bool Bool(
bool) {
return false; }
45 virtual bool Int(
int) {
return false; }
46 virtual bool Uint(
unsigned) {
return false; }
// 64-bit integer variants: same rejecting default as Int/Uint.
47 virtual bool Int64(int64_t) {
return false; }
48 virtual bool Uint64(uint64_t) {
return false; }
49 virtual bool Double(
double) {
return false; }
50 virtual bool String(
const char *, std::size_t,
bool) {
53 virtual bool StartObject() {
return false; }
54 virtual bool Key(
const char *str, std::size_t length,
bool) {
55 set_cur_key(str, length);
// End of the object this handler was responsible for: hand control back by
// calling pop_handler() (declared elsewhere; presumably pops this handler off
// the delegator's stack so the previous handler resumes — confirm). The member
// count reported by rapidjson is ignored.
58 virtual bool EndObject(std::size_t) {
return pop_handler(); }
// The base handler accepts no arrays; subclasses that expect one override this.
59 virtual bool StartArray() {
return false; }
// End of an array: same hand-back via pop_handler() as EndObject; the element
// count is ignored.
60 virtual bool EndArray(std::size_t) {
return pop_handler(); }
66 template <
typename HandlerType,
typename... ArgsTypes>
67 bool push_handler(ArgsTypes &... args) {
68 if (
auto parent = BaseHandler::delegator.lock()) {
69 parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
81 template <
typename HandlerType,
typename... ArgsTypes>
82 bool push_key_handler(std::string key, ArgsTypes &... args) {
83 if (check_cur_key(key)) {
84 push_handler<HandlerType, ArgsTypes...>(args...);
// Record the most recently seen object key; Key() forwards its (str, length)
// arguments here. Definition not visible.
96 void set_cur_key(
const char *str, std::size_t length);
// Accessor for the recorded key (presumably the value stored by set_cur_key —
// definition not visible; confirm).
98 const std::string &get_cur_key();
// Test the recorded key against query_key. push_key_handler only delegates to
// a new handler when this returns true. Presumably an equality check —
// definition not visible; confirm.
102 bool check_cur_key(
const std::string &query_key);
108 template <
typename ValueType>
109 bool assign_value(
const std::string &key,
112 template <
typename ValueType>
113 bool assign_value(
const std::string &key,
114 const ValueType &value,
119 std::weak_ptr<Delegator> delegator;
// Overrides of the base handler's value and start-of-container callbacks
// (definitions not visible here). NOTE(review): the enclosing class header is
// cut from this listing; from context these appear to belong to the handler
// that ignores all delegated input — confirm before relying on that.
128 bool Null()
override;
129 bool Bool(
bool b)
override;
130 bool Int(
int i)
override;
131 bool Uint(
unsigned u)
override;
132 bool Int64(int64_t i)
override;
133 bool Uint64(uint64_t u)
override;
134 bool Double(
double d)
override;
135 bool String(
const char *str, std::size_t length,
bool copy)
override;
// Container starts and keys are overridden as well, so nested content reaches
// this handler rather than the base class's defaults.
136 bool StartObject()
override;
137 bool Key(
const char *str, std::size_t length,
bool copy)
override;
138 bool StartArray()
override;
150 OutputType &output_param)
151 :
BaseHandler{parent_delegator}, output{output_param} {};
153 OutputType &&output) =
delete;
161 template <
typename ElemType,
typename HandlerType = BaseHandler>
168 bool Bool(ElemType b) {
169 this->output.push_back(b);
172 template <
typename ArgType,
typename IntType = ElemType>
173 typename std::enable_if<std::is_integral<IntType>::value,
bool>::type
174 store_int(ArgType i) {
175 this->output.push_back(static_cast<ElemType>(i));
179 template <
typename ArgType,
typename IntType = ElemType>
180 typename std::enable_if<!std::is_integral<IntType>::value,
bool>::type
// Integer callbacks: every integer width rapidjson can report is routed to
// store_int, with the argument type spelled out explicitly. When ElemType is
// integral, store_int appends static_cast<ElemType>(value) to `output`. For a
// non-integral ElemType, the other enable_if overload of store_int is selected
// instead (its body is not visible in this listing).
// NOTE(review): the static_cast narrows silently (e.g. an Int64/Uint64 value
// stored into a std::vector<int> output) with no range check — confirm that
// inputs are guaranteed to fit.
185 bool Int(
int i)
override {
return store_int<int>(i); }
186 bool Uint(
unsigned u)
override {
return store_int<unsigned>(u); }
187 bool Int64(int64_t i)
override {
return store_int<int64_t>(i); }
188 bool Uint64(uint64_t u)
override {
return store_int<uint64_t>(u); }
192 bool Double(ElemType d) {
193 this->output.push_back(d);
197 template <
typename StringType = ElemType>
198 typename std::enable_if<std::is_same<StringType, std::string>::value,
200 store_string(
const char *str, std::size_t length,
bool copy) {
201 this->output.push_back(ElemType{str, length});
204 template <
typename StringType = ElemType>
205 typename std::enable_if<!std::is_same<StringType, std::string>::value,
207 store_string(
const char *, std::size_t,
bool) {
211 bool String(
const char *str, std::size_t length,
bool copy)
override {
212 return store_string(str, length, copy);
215 bool StartObject(std::true_type) {
216 this->output.emplace_back();
217 return this->
template push_handler<HandlerType, ElemType>(
218 this->output.back());
// Tag-dispatch fallback for StartObject(): chosen when the compile-time
// condition built from HandlerType evaluates to false. The condition itself is
// cut from this listing — confirm. An object start is rejected (false) in that
// case, instead of emplacing a new element and delegating to HandlerType as
// the std::true_type overload does.
221 bool StartObject(std::false_type) {
return false; }
223 bool StartObject()
override {
226 HandlerType>::value>{});
235 bool String(
const char *str, std::size_t length,
bool copy)
override;
242 bool StartArray()
override;
244 bool StartObject()
override;
246 bool Uint(
unsigned u)
override;
248 bool EndObject(std::size_t memberCount)
override;
251 std::vector<double> loss_changes;
252 std::vector<double> sum_hessian;
253 std::vector<double> base_weights;
254 std::vector<int> left_children;
255 std::vector<int> right_children;
256 std::vector<int> parents;
257 std::vector<int> split_indices;
258 std::vector<int> split_type;
259 std::vector<int> categories_segments;
260 std::vector<int> categories_sizes;
261 std::vector<int> categories_nodes;
262 std::vector<int> categories;
263 std::vector<double> split_conditions;
264 std::vector<bool> default_left;
271 bool StartArray()
override;
272 bool StartObject()
override;
279 bool String(
const char *str, std::size_t length,
bool copy)
override;
280 bool StartArray()
override;
281 bool StartObject()
override;
282 bool EndObject(std::size_t memberCount)
override;
285 std::vector<double> weight_drop;
292 bool StartObject()
override;
294 bool String(
const char *str, std::size_t length,
bool copy)
override;
301 bool String(
const char *str, std::size_t length,
bool copy)
override;
306 std::vector<unsigned> version;
307 std::string objective_name;
314 bool StartObject()
override;
315 bool EndObject(std::size_t memberCount)
override;
316 bool StartArray()
override;
319 std::string objective;
326 bool StartArray()
override;
327 bool StartObject()
override;
334 bool StartArray()
override;
335 bool StartObject()
override;
336 bool EndObject(std::size_t memberCount)
override;
343 bool StartObject()
override;
350 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
356 std::shared_ptr<DelegatedHandler> new_handler =
357 std::make_shared<make_shared_enabler>();
362 static std::shared_ptr<DelegatedHandler>
create() {
363 std::shared_ptr<DelegatedHandler> new_handler = create_empty();
364 new_handler->push_delegate(std::make_shared<RootHandler>(
366 new_handler->result));
374 std::shared_ptr<BaseHandler> new_delegate)
override {
375 delegates.push(new_delegate);
383 std::unique_ptr<treelite::Model> get_result();
387 bool Uint(
unsigned u);
388 bool Int64(int64_t i);
389 bool Uint64(uint64_t u);
390 bool Double(
double d);
391 bool String(
const char *str, std::size_t length,
bool copy);
393 bool Key(
const char *str, std::size_t length,
bool copy);
394 bool EndObject(std::size_t memberCount);
396 bool EndArray(std::size_t elementCount);
// Default constructor: starts with an empty delegate stack and a freshly
// created treelite::Model (Model::Create<float, float>()) to receive the
// parsed result. NOTE(review): both template arguments are fixed to float here
// (presumably threshold and leaf-output types) — confirm this is intended for
// every input model.
399 DelegatedHandler() : delegates{}, result{treelite::Model::Create<float, float>()} {};
// Stack of active handlers; push_delegate() pushes onto it (presumably the top
// entry receives forwarded parse events — confirm).
401 std::stack<std::shared_ptr<BaseHandler>> delegates;
// Model being populated; create() passes it to the initial RootHandler.
402 std::unique_ptr<treelite::Model> result;
407 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ void pop_delegate() override
pop this handler off the stack, returning parsing responsibility to the previous handler on the stack ...
class for handling delegation of JSON handling
handler for Objective objects from XGBoost schema
handler for XGBoostModel objects from XGBoost schema
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
handler which delegates JSON parsing to stack of delegates
handler for array of objects of given type
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
base class for parsing all JSON objects
static std::shared_ptr< DelegatedHandler > create_empty()
create DelegatedHandler with empty stack
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
base handler for updating some output object
handler for XGBoost checkpoint
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
handler for TreeParam objects from XGBoost schema
handler for Learner objects from XGBoost schema
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
handler for root object of XGBoost schema
JSON handler that ignores all delegated input.
handler for GradientBooster objects from XGBoost schema