8 #include <rapidjson/document.h> 16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 29 virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
34 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
40 explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41 delegator{parent_delegator} {};
43 virtual bool Null() {
return false; }
44 virtual bool Bool(
bool) {
return false; }
45 virtual bool Int(
int) {
return false; }
46 virtual bool Uint(
unsigned) {
return false; }
47 virtual bool Int64(int64_t) {
return false; }
48 virtual bool Uint64(uint64_t) {
return false; }
49 virtual bool Double(
double) {
return false; }
50 virtual bool String(
const char *, std::size_t,
bool) {
53 virtual bool StartObject() {
return false; }
54 virtual bool Key(
const char *str, std::size_t length,
bool) {
55 set_cur_key(str, length);
58 virtual bool EndObject(std::size_t) {
return pop_handler(); }
59 virtual bool StartArray() {
return false; }
60 virtual bool EndArray(std::size_t) {
return pop_handler(); }
66 template <
typename HandlerType,
typename... ArgsTypes>
67 bool push_handler(ArgsTypes &... args) {
68 if (
auto parent = BaseHandler::delegator.lock()) {
69 parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
81 template <
typename HandlerType,
typename... ArgsTypes>
82 bool push_key_handler(std::string key, ArgsTypes &... args) {
83 if (check_cur_key(key)) {
84 push_handler<HandlerType, ArgsTypes...>(args...);
96 void set_cur_key(
const char *str, std::size_t length);
98 const std::string &get_cur_key();
102 bool check_cur_key(
const std::string &query_key);
108 template <
typename ValueType>
109 bool assign_value(
const std::string &key,
112 template <
typename ValueType>
113 bool assign_value(
const std::string &key,
114 const ValueType &value,
119 std::weak_ptr<Delegator> delegator;
128 bool Null()
override;
129 bool Bool(
bool b)
override;
130 bool Int(
int i)
override;
131 bool Uint(
unsigned u)
override;
132 bool Int64(int64_t i)
override;
133 bool Uint64(uint64_t u)
override;
134 bool Double(
double d)
override;
135 bool String(
const char *str, std::size_t length,
bool copy)
override;
136 bool StartObject()
override;
137 bool Key(
const char *str, std::size_t length,
bool copy)
override;
138 bool StartArray()
override;
150 OutputType &output_param)
151 :
BaseHandler{parent_delegator}, output{output_param} {};
153 OutputType &&output) =
delete;
161 template <
typename ElemType,
typename HandlerType = BaseHandler>
168 bool Bool(ElemType b) {
169 this->output.push_back(b);
172 template <
typename ArgType,
typename IntType = ElemType>
173 typename std::enable_if<std::is_integral<IntType>::value,
bool>::type
174 store_int(ArgType i) {
175 this->output.push_back(static_cast<ElemType>(i));
179 template <
typename ArgType,
typename IntType = ElemType>
180 typename std::enable_if<!std::is_integral<IntType>::value,
bool>::type
185 bool Int(
int i)
override {
return store_int<int>(i); }
186 bool Uint(
unsigned u)
override {
return store_int<unsigned>(u); }
187 bool Int64(int64_t i)
override {
return store_int<int64_t>(i); }
188 bool Uint64(uint64_t u)
override {
return store_int<uint64_t>(u); }
192 bool Double(ElemType d) {
193 this->output.push_back(d);
197 template <
typename StringType = ElemType>
198 typename std::enable_if<std::is_same<StringType, std::string>::value,
200 store_string(
const char *str, std::size_t length,
bool copy) {
201 this->output.push_back(ElemType{str, length});
204 template <
typename StringType = ElemType>
205 typename std::enable_if<!std::is_same<StringType, std::string>::value,
207 store_string(
const char *, std::size_t,
bool) {
211 bool String(
const char *str, std::size_t length,
bool copy)
override {
212 return store_string(str, length, copy);
215 bool StartObject(std::true_type) {
216 this->output.emplace_back();
217 return this->
template push_handler<HandlerType, ElemType>(
218 this->output.back());
221 bool StartObject(std::false_type) {
return false; }
223 bool StartObject()
override {
226 HandlerType>::value>{});
235 bool String(
const char *str, std::size_t length,
bool copy)
override;
242 bool StartArray()
override;
244 bool StartObject()
override;
246 bool Uint(
unsigned u)
override;
248 bool EndObject(std::size_t memberCount)
override;
251 std::vector<double> loss_changes;
252 std::vector<double> sum_hessian;
253 std::vector<double> base_weights;
254 std::vector<int> left_children;
255 std::vector<int> right_children;
256 std::vector<int> parents;
257 std::vector<int> split_indices;
258 std::vector<int> split_type;
259 std::vector<int> categories_segments;
260 std::vector<int> categories_sizes;
261 std::vector<int> categories_nodes;
262 std::vector<int> categories;
263 std::vector<double> split_conditions;
264 std::vector<bool> default_left;
271 bool StartArray()
override;
272 bool StartObject()
override;
279 bool String(
const char *str, std::size_t length,
bool copy)
override;
280 bool StartArray()
override;
281 bool StartObject()
override;
282 bool EndObject(std::size_t memberCount)
override;
285 std::vector<double> weight_drop;
292 bool StartObject()
override;
294 bool String(
const char *str, std::size_t length,
bool copy)
override;
301 bool String(
const char *str, std::size_t length,
bool copy)
override;
306 std::string objective_name;
313 bool StartObject()
override;
314 bool EndObject(std::size_t memberCount)
override;
315 bool StartArray()
override;
318 std::string objective;
325 bool StartArray()
override;
326 bool StartObject()
override;
327 bool EndObject(std::size_t memberCount)
override;
330 std::vector<unsigned> version;
337 bool StartObject()
override;
344 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
349 static std::shared_ptr<DelegatedHandler>
create() {
352 std::shared_ptr<DelegatedHandler> new_handler =
353 std::make_shared<make_shared_enabler>();
356 new_handler->result));
364 std::shared_ptr<BaseHandler> new_delegate)
override {
365 delegates.push(new_delegate);
373 std::unique_ptr<treelite::Model> get_result();
377 bool Uint(
unsigned u);
378 bool Int64(int64_t i);
379 bool Uint64(uint64_t u);
380 bool Double(
double d);
381 bool String(
const char *str, std::size_t length,
bool copy);
383 bool Key(
const char *str, std::size_t length,
bool copy);
384 bool EndObject(std::size_t memberCount);
386 bool EndArray(std::size_t elementCount);
389 DelegatedHandler() : delegates{}, result{treelite::Model::Create<float, float>()} {};
391 std::stack<std::shared_ptr<BaseHandler>> delegates;
392 std::unique_ptr<treelite::Model> result;
397 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ void pop_delegate() override
pop handler off of stack, returning parsing responsibility to previous handler on stack ...
class for handling delegation of JSON handling
handler for ObjectiveHandler objects from XGBoost schema
handler for XGBoostModel objects from XGBoost schema
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
handler which delegates JSON parsing to stack of delegates
handler for array of objects of given type
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
base class for parsing all JSON objects
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
base handler for updating some output object
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
handler for TreeParam objects from XGBoost schema
handler for Learner objects from XGBoost schema
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
handler for root object of XGBoost schema
JSON handler that ignores all delegated input.
handler for GradientBoosterHandler objects from XGBoost schema