Treelite
xgboost_json.h
Go to the documentation of this file.
1 
#include <rapidjson/document.h>
#include <treelite/tree.h>

#include <memory>
#include <stack>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
15 
16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
18 namespace treelite {
19 namespace details {
20 
21 class BaseHandler;
22 
24 class Delegator {
25  public:
27  virtual void pop_delegate() = 0;
29  virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
30 };
31 
34  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
35  public:
40  explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41  delegator{parent_delegator} {};
42 
43  virtual bool Null() { return false; }
44  virtual bool Bool(bool) { return false; }
45  virtual bool Int(int) { return false; }
46  virtual bool Uint(unsigned) { return false; }
47  virtual bool Int64(int64_t) { return false; }
48  virtual bool Uint64(uint64_t) { return false; }
49  virtual bool Double(double) { return false; }
50  virtual bool String(const char *, std::size_t, bool) {
51  return false;
52  }
53  virtual bool StartObject() { return false; }
54  virtual bool Key(const char *str, std::size_t length, bool) {
55  set_cur_key(str, length);
56  return true;
57  }
58  virtual bool EndObject(std::size_t) { return pop_handler(); }
59  virtual bool StartArray() { return false; }
60  virtual bool EndArray(std::size_t) { return pop_handler(); }
61 
62  protected:
63  /* \brief build handler of indicated type and push it onto delegator's stack
64  * \param args ... any args required to build handler
65  */
66  template <typename HandlerType, typename... ArgsTypes>
67  bool push_handler(ArgsTypes &... args) {
68  if (auto parent = BaseHandler::delegator.lock()) {
69  parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
70  return true;
71  } else {
72  return false;
73  }
74  }
75 
76  /* \brief if current JSON key is the indicated string, build handler of
77  * indicated type and push it onto delegator's stack
78  * \param key the expected key
79  * \param args ... any args required to build handler
80  */
81  template <typename HandlerType, typename... ArgsTypes>
82  bool push_key_handler(std::string key, ArgsTypes &... args) {
83  if (check_cur_key(key)) {
84  push_handler<HandlerType, ArgsTypes...>(args...);
85  return true;
86  } else {
87  return false;
88  }
89  }
90  /* \brief pop handler off of delegator's stack, relinquishing parsing */
91  bool pop_handler();
92  /* \brief store current JSON key
93  * \param str the key to store
94  * \param length the length of the str char array
95  */
96  void set_cur_key(const char *str, std::size_t length);
97  /* \brief retrieve current JSON key */
98  const std::string &get_cur_key();
99  /* \brief check if current JSON key is indicated key
100  * \param query_key the value to compare against current JSON key
101  */
102  bool check_cur_key(const std::string &query_key);
103  /* \brief if current JSON key is the indicated string, assign value to output
104  * \param key the JSON key for this output
105  * \param value the value to be assigned
106  * \param output reference to object to which the value should be assigned
107  */
108  template <typename ValueType>
109  bool assign_value(const std::string &key,
110  ValueType &&value,
111  ValueType &output);
112  template <typename ValueType>
113  bool assign_value(const std::string &key,
114  const ValueType &value,
115  ValueType &output);
116 
117  private:
118  /* \brief the delegator which delegated parsing responsibility to this handler */
119  std::weak_ptr<Delegator> delegator;
120  /* \brief the JSON key for the object currently being parsed */
121  std::string cur_key;
122 };
123 
125 class IgnoreHandler : public BaseHandler {
126  public:
128  bool Null() override;
129  bool Bool(bool b) override;
130  bool Int(int i) override;
131  bool Uint(unsigned u) override;
132  bool Int64(int64_t i) override;
133  bool Uint64(uint64_t u) override;
134  bool Double(double d) override;
135  bool String(const char *str, std::size_t length, bool copy) override;
136  bool StartObject() override;
137  bool Key(const char *str, std::size_t length, bool copy) override;
138  bool StartArray() override;
139 };
140 
142 template <typename OutputType> class OutputHandler : public BaseHandler {
143  public:
149  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
150  OutputType &output_param)
151  : BaseHandler{parent_delegator}, output{output_param} {};
152  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
153  OutputType &&output) = delete;
154 
155  protected:
156  /* \brief the output value constructed or modified during parsing */
157  OutputType &output;
158 };
159 
161 template <typename ElemType, typename HandlerType = BaseHandler>
162 class ArrayHandler : public OutputHandler<std::vector<ElemType>> {
163  public:
165 
166  /* Note: This method will only be instantiated (and therefore override the
167  * base `bool Bool(bool)` method) if ElemType is bool. */
168  bool Bool(ElemType b) {
169  this->output.push_back(b);
170  return true;
171  }
172  template <typename ArgType, typename IntType = ElemType>
173  typename std::enable_if<std::is_integral<IntType>::value, bool>::type
174  store_int(ArgType i) {
175  this->output.push_back(static_cast<ElemType>(i));
176  return true;
177  }
178 
179  template <typename ArgType, typename IntType = ElemType>
180  typename std::enable_if<!std::is_integral<IntType>::value, bool>::type
181  store_int(ArgType) {
182  return false;
183  }
184 
185  bool Int(int i) override { return store_int<int>(i); }
186  bool Uint(unsigned u) override { return store_int<unsigned>(u); }
187  bool Int64(int64_t i) override { return store_int<int64_t>(i); }
188  bool Uint64(uint64_t u) override { return store_int<uint64_t>(u); }
189 
190  /* Note: This method will only be instantiated (and therefore override the
191  * base `bool Double(double)` method) if ElemType is double. */
192  bool Double(ElemType d) {
193  this->output.push_back(d);
194  return true;
195  }
196 
197  template <typename StringType = ElemType>
198  typename std::enable_if<std::is_same<StringType, std::string>::value,
199  bool>::type
200  store_string(const char *str, std::size_t length, bool copy) {
201  this->output.push_back(ElemType{str, length});
202  return true;
203  }
204  template <typename StringType = ElemType>
205  typename std::enable_if<!std::is_same<StringType, std::string>::value,
206  bool>::type
207  store_string(const char *, std::size_t, bool) {
208  return false;
209  }
210 
211  bool String(const char *str, std::size_t length, bool copy) override {
212  return store_string(str, length, copy);
213  }
214 
215  bool StartObject(std::true_type) {
216  this->output.emplace_back();
217  return this->template push_handler<HandlerType, ElemType>(
218  this->output.back());
219  }
220 
221  bool StartObject(std::false_type) { return false; }
222 
223  bool StartObject() override {
224  return StartObject(
225  std::integral_constant<bool, std::is_base_of<OutputHandler<ElemType>,
226  HandlerType>::value>{});
227  }
228 };
229 
232  public:
234 
235  bool String(const char *str, std::size_t length, bool copy) override;
236 };
237 
240  public:
242  bool StartArray() override;
243 
244  bool StartObject() override;
245 
246  bool Uint(unsigned u) override;
247 
248  bool EndObject(std::size_t memberCount) override;
249 
250  private:
251  std::vector<double> loss_changes;
252  std::vector<double> sum_hessian;
253  std::vector<double> base_weights;
254  std::vector<int> left_children;
255  std::vector<int> right_children;
256  std::vector<int> parents;
257  std::vector<int> split_indices;
258  std::vector<int> split_type;
259  std::vector<int> categories_segments;
260  std::vector<int> categories_sizes;
261  std::vector<int> categories_nodes;
262  std::vector<int> categories;
263  std::vector<double> split_conditions;
264  std::vector<bool> default_left;
265  int num_nodes = 0;
266 };
267 
269 class GBTreeModelHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
271  bool StartArray() override;
272  bool StartObject() override;
273 };
274 
276 class GradientBoosterHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
277  public:
279  bool String(const char *str, std::size_t length, bool copy) override;
280  bool StartArray() override;
281  bool StartObject() override;
282  bool EndObject(std::size_t memberCount) override;
283  private:
284  std::string name;
285  std::vector<double> weight_drop;
286 };
287 
289 class ObjectiveHandler : public OutputHandler<std::string> {
291 
292  bool StartObject() override;
293 
294  bool String(const char *str, std::size_t length, bool copy) override;
295 };
296 
298 class LearnerParamHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
299  public:
301  bool String(const char *str, std::size_t length, bool copy) override;
302 };
303 
/*! \brief version numbers read from the model (presumably [major, minor, patch] — confirm) */
std::vector<unsigned> version;
/*! \brief name of the model's objective */
std::string objective_name;
308 };
309 
311 class LearnerHandler : public OutputHandler<XGBoostModelHandle> {
312  public:
314  bool StartObject() override;
315  bool EndObject(std::size_t memberCount) override;
316  bool StartArray() override;
317 
318  private:
319  std::string objective;
320 };
321 
323 class XGBoostCheckpointHandler : public OutputHandler<XGBoostModelHandle> {
324  public:
326  bool StartArray() override;
327  bool StartObject() override;
328 };
329 
331 class XGBoostModelHandler : public OutputHandler<XGBoostModelHandle> {
332  public:
334  bool StartArray() override;
335  bool StartObject() override;
336  bool EndObject(std::size_t memberCount) override;
337 };
338 
340 class RootHandler : public OutputHandler<std::unique_ptr<treelite::Model>> {
341  public:
343  bool StartObject() override;
344  private:
345  XGBoostModelHandle handle;
346 };
347 
350  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
351  public Delegator {
352  public:
354  static std::shared_ptr<DelegatedHandler> create_empty() {
355  struct make_shared_enabler : public DelegatedHandler {};
356  std::shared_ptr<DelegatedHandler> new_handler =
357  std::make_shared<make_shared_enabler>();
358  return new_handler;
359  }
360 
362  static std::shared_ptr<DelegatedHandler> create() {
363  std::shared_ptr<DelegatedHandler> new_handler = create_empty();
364  new_handler->push_delegate(std::make_shared<RootHandler>(
365  new_handler,
366  new_handler->result));
367  return new_handler;
368  }
369 
374  std::shared_ptr<BaseHandler> new_delegate) override {
375  delegates.push(new_delegate);
376  }
380  void pop_delegate() override {
381  delegates.pop();
382  }
383  std::unique_ptr<treelite::Model> get_result();
384  bool Null();
385  bool Bool(bool b);
386  bool Int(int i);
387  bool Uint(unsigned u);
388  bool Int64(int64_t i);
389  bool Uint64(uint64_t u);
390  bool Double(double d);
391  bool String(const char *str, std::size_t length, bool copy);
392  bool StartObject();
393  bool Key(const char *str, std::size_t length, bool copy);
394  bool EndObject(std::size_t memberCount);
395  bool StartArray();
396  bool EndArray(std::size_t elementCount);
397 
398  private:
399  DelegatedHandler() : delegates{}, result{treelite::Model::Create<float, float>()} {};
400 
401  std::stack<std::shared_ptr<BaseHandler>> delegates;
402  std::unique_ptr<treelite::Model> result;
403 };
404 
405 } // namespace details
406 } // namespace treelite
407 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
void pop_delegate() override
pop handler off of stack, returning parsing responsibility to previous handler on stack ...
Definition: xgboost_json.h:380
class for handling delegation of JSON handling
Definition: xgboost_json.h:24
handler for Objective objects from XGBoost schema
Definition: xgboost_json.h:289
handler for XGBoostModel objects from XGBoost schema
Definition: xgboost_json.h:331
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
Definition: xgboost_json.h:239
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
Definition: xgboost_json.h:269
handler which delegates JSON parsing to stack of delegates
Definition: xgboost_json.h:349
handler for array of objects of given type
Definition: xgboost_json.h:162
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:362
base class for parsing all JSON objects
Definition: xgboost_json.h:33
static std::shared_ptr< DelegatedHandler > create_empty()
create DelegatedHandler with empty stack
Definition: xgboost_json.h:354
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:40
base handler for updating some output object
Definition: xgboost_json.h:142
handler for XGBoost checkpoint
Definition: xgboost_json.h:323
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:149
handler for TreeParam objects from XGBoost schema
Definition: xgboost_json.h:231
handler for Learner objects from XGBoost schema
Definition: xgboost_json.h:311
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
Definition: xgboost_json.h:373
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
Definition: xgboost_json.h:298
handler for root object of XGBoost schema
Definition: xgboost_json.h:340
JSON handler that ignores all delegated input.
Definition: xgboost_json.h:125
handler for GradientBooster objects from XGBoost schema
Definition: xgboost_json.h:276