8 #include <rapidjson/document.h> 16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ 29 virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
34 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
/*!
 * \brief construct a handler that will be added to the given delegator's stack
 * \param parent_delegator non-owning (weak) reference to the delegator; it is
 *        stored in the `delegator` member and later lock()ed by push_handler()
 *        to push child handlers onto the same stack
 */
40 explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41 delegator{parent_delegator} {};
/*
 * Default scalar-value callbacks of the rapidjson SAX handler interface.
 * Every one returns false, so a handler rejects any scalar JSON type it does
 * not explicitly override. rapidjson's Reader treats a false return from a
 * handler callback as a request to terminate parsing.
 */
43 virtual bool Null() {
return false; }
44 virtual bool Bool(
bool) {
return false; }
45 virtual bool Int(
int) {
return false; }
46 virtual bool Uint(
unsigned) {
return false; }
47 virtual bool Int64(int64_t) {
return false; }
48 virtual bool Uint64(uint64_t) {
return false; }
49 virtual bool Double(
double) {
return false; }
50 virtual bool String(
const char *, std::size_t,
bool) {
/*! \brief reject a nested JSON object by default (returns false); subclasses that accept objects override this */
53 virtual bool StartObject() {
return false; }
54 virtual bool Key(
const char *str, std::size_t length,
bool) {
55 set_cur_key(str, length);
/*!
 * \brief end of the current object; the member count is ignored and the
 *        result of pop_handler() is returned (pop_handler is defined
 *        elsewhere -- presumably it removes this handler from the
 *        delegator's stack; confirm)
 */
58 virtual bool EndObject(std::size_t) {
return pop_handler(); }
/*! \brief reject a JSON array by default (returns false); subclasses that accept arrays override this */
59 virtual bool StartArray() {
return false; }
/*!
 * \brief end of the current array; the element count is ignored and the
 *        result of pop_handler() is returned
 */
60 virtual bool EndArray(std::size_t) {
return pop_handler(); }
66 template <
typename HandlerType,
typename... ArgsTypes>
67 bool push_handler(ArgsTypes &... args) {
68 if (
auto parent = BaseHandler::delegator.lock()) {
69 parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
81 template <
typename HandlerType,
typename... ArgsTypes>
82 bool push_key_handler(std::string key, ArgsTypes &... args) {
83 if (check_cur_key(key)) {
84 push_handler<HandlerType, ArgsTypes...>(args...);
96 void set_cur_key(
const char *str, std::size_t length);
98 const std::string &get_cur_key();
102 bool check_cur_key(
const std::string &query_key);
108 template <
typename ValueType>
109 bool assign_value(
const std::string &key,
112 template <
typename ValueType>
113 bool assign_value(
const std::string &key,
114 const ValueType &value,
119 std::weak_ptr<Delegator> delegator;
128 bool Null()
override;
129 bool Bool(
bool b)
override;
130 bool Int(
int i)
override;
131 bool Uint(
unsigned u)
override;
132 bool Int64(int64_t i)
override;
133 bool Uint64(uint64_t u)
override;
134 bool Double(
double d)
override;
135 bool String(
const char *str, std::size_t length,
bool copy)
override;
136 bool StartObject()
override;
137 bool Key(
const char *str, std::size_t length,
bool copy)
override;
138 bool StartArray()
override;
150 OutputType &output_param)
151 :
BaseHandler{parent_delegator}, output{output_param} {};
153 OutputType &&output) =
delete;
161 template <
typename ElemType,
typename HandlerType = BaseHandler>
168 bool Bool(ElemType b) {
169 this->output.push_back(b);
172 template <
typename ArgType,
typename IntType = ElemType>
173 typename std::enable_if<std::is_integral<IntType>::value,
bool>::type
174 store_int(ArgType i) {
175 this->output.push_back(static_cast<ElemType>(i));
179 template <
typename ArgType,
typename IntType = ElemType>
180 typename std::enable_if<!std::is_integral<IntType>::value,
bool>::type
/*
 * Integer callbacks: all four rapidjson integer events forward to store_int,
 * instantiated with the callback's own argument type. store_int is
 * overloaded via std::enable_if on whether ElemType is integral; the
 * integral overload static_casts the value to ElemType and appends it to
 * `output`. (The body of the non-integral overload is not shown here --
 * presumably it rejects the value; confirm.)
 */
185 bool Int(
int i)
override {
return store_int<int>(i); }
186 bool Uint(
unsigned u)
override {
return store_int<unsigned>(u); }
187 bool Int64(int64_t i)
override {
return store_int<int64_t>(i); }
188 bool Uint64(uint64_t u)
override {
return store_int<uint64_t>(u); }
192 bool Double(ElemType d) {
193 this->output.push_back(d);
197 template <
typename StringType = ElemType>
198 typename std::enable_if<std::is_same<StringType, std::string>::value,
200 store_string(
const char *str, std::size_t length,
bool copy) {
201 this->output.push_back(ElemType{str, length});
204 template <
typename StringType = ElemType>
205 typename std::enable_if<!std::is_same<StringType, std::string>::value,
207 store_string(
const char *, std::size_t,
bool) {
211 bool String(
const char *str, std::size_t length,
bool copy)
override {
212 return store_string(str, length, copy);
215 bool StartObject(std::true_type) {
216 this->output.emplace_back();
217 return this->
template push_handler<HandlerType, ElemType>(
218 this->output.back());
/*!
 * \brief tag-dispatched StartObject overload for the std::false_type tag:
 *        rejects the object by returning false. The std::true_type overload
 *        instead appends a new element to `output` and pushes a HandlerType
 *        handler for it. The tag is computed in StartObject() from HandlerType
 *        (presumably false when HandlerType is the default BaseHandler, i.e.
 *        no element handler was supplied -- confirm).
 */
221 bool StartObject(std::false_type) {
return false; }
223 bool StartObject()
override {
226 HandlerType>::value>{});
235 bool String(
const char *str, std::size_t length,
bool copy)
override;
242 bool StartArray()
override;
244 bool StartObject()
override;
246 bool Uint(
unsigned u)
override;
248 bool EndObject(std::size_t memberCount)
override;
251 std::vector<double> loss_changes;
252 std::vector<double> sum_hessian;
253 std::vector<double> base_weights;
254 std::vector<int> left_children;
255 std::vector<int> right_children;
256 std::vector<int> parents;
257 std::vector<int> split_indices;
258 std::vector<int> split_type;
259 std::vector<int> categories_segments;
260 std::vector<int> categories_sizes;
261 std::vector<int> categories_nodes;
262 std::vector<int> categories;
263 std::vector<double> split_conditions;
264 std::vector<bool> default_left;
271 bool StartArray()
override;
272 bool StartObject()
override;
279 bool String(
const char *str, std::size_t length,
bool copy)
override;
280 bool StartObject()
override;
287 bool StartObject()
override;
289 bool String(
const char *str, std::size_t length,
bool copy)
override;
296 bool String(
const char *str, std::size_t length,
bool copy)
override;
301 std::string objective_name;
308 bool StartObject()
override;
309 bool EndObject(std::size_t memberCount)
override;
312 std::string objective;
319 bool StartArray()
override;
320 bool StartObject()
override;
321 bool EndObject(std::size_t memberCount)
override;
324 std::vector<unsigned> version;
331 bool StartObject()
override;
338 :
public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
343 static std::shared_ptr<DelegatedHandler>
create() {
346 std::shared_ptr<DelegatedHandler> new_handler =
347 std::make_shared<make_shared_enabler>();
350 new_handler->result));
358 std::shared_ptr<BaseHandler> new_delegate)
override {
359 delegates.push(new_delegate);
367 std::unique_ptr<treelite::Model> get_result();
371 bool Uint(
unsigned u);
372 bool Int64(int64_t i);
373 bool Uint64(uint64_t u);
374 bool Double(
double d);
375 bool String(
const char *str, std::size_t length,
bool copy);
377 bool Key(
const char *str, std::size_t length,
bool copy);
378 bool EndObject(std::size_t memberCount);
380 bool EndArray(std::size_t elementCount);
/*!
 * \brief construct with an empty delegate stack and a fresh model obtained
 *        from treelite::Model::Create<float, float>(). Instances are made via
 *        create(), which builds them through a make_shared_enabler helper
 *        (presumably because this constructor is not publicly accessible --
 *        confirm) and then seeds the stack with the initial root handler.
 */
383 DelegatedHandler() : delegates{}, result{treelite::Model::Create<float, float>()} {};
385 std::stack<std::shared_ptr<BaseHandler>> delegates;
386 std::unique_ptr<treelite::Model> result;
391 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_ void pop_delegate() override
pop the handler off the stack, returning parsing responsibility to the previous handler on the stack
class for handling delegation of JSON handling
handler for Objective objects from the XGBoost schema
handler for XGBoostModel objects from XGBoost schema
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
handler which delegates JSON parsing to stack of delegates
handler for an array of elements of a given type (scalars, strings, or objects)
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
base class for parsing all JSON objects
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
base handler for updating some output object
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
handler for TreeParam objects from XGBoost schema
handler for Learner objects from XGBoost schema
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
handler for root object of XGBoost schema
JSON handler that ignores all delegated input.
handler for GradientBooster objects from the XGBoost schema