Treelite
xgboost_json.h
Go to the documentation of this file.
1 
8 #include <rapidjson/document.h>
9 #include <treelite/tree.h>
10 
11 #include <memory>
12 #include <stack>
13 #include <string>
14 #include <vector>
15 
16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
18 namespace treelite {
19 namespace details {
20 
21 class BaseHandler;
22 
/* \brief class for handling delegation of JSON handling
 *
 * Interface through which a handler can ask for parsing responsibility to be
 * handed to another handler (push) or handed back (pop).
 * NOTE(review): no virtual destructor is declared; confirm that no object is
 * ever destroyed through a plain Delegator pointer.
 */
24 class Delegator {
25  public:
 /* \brief pop stack of delegate handlers */
27  virtual void pop_delegate() = 0;
 /* \brief push new delegate handler onto stack
  * \param new_delegate shared pointer to the handler to be pushed
  */
29  virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
30 };
31 
34  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
35  public:
/* \brief construct handler to be added to given delegator's stack
 * \param parent_delegator weak reference to the delegator; it is copied into
 *        the `delegator` member and locked later whenever this handler pushes
 *        a new handler
 */
40  explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41  delegator{parent_delegator} {};
42 
/* Default SAX event callbacks.
 *
 * Every value event and both Start* events return false, so a derived handler
 * has to override each event it is prepared to accept (rapidjson treats a
 * false return from a handler as a request to stop parsing).
 * Key() is the exception: it records the key and returns true.
 * EndObject()/EndArray() return whatever pop_handler() returns.
 */
43  virtual bool Null() { return false; }
44  virtual bool Bool(bool) { return false; }
45  virtual bool Int(int) { return false; }
46  virtual bool Uint(unsigned) { return false; }
47  virtual bool Int64(int64_t) { return false; }
48  virtual bool Uint64(uint64_t) { return false; }
49  virtual bool Double(double) { return false; }
50  virtual bool String(const char *, std::size_t, bool) {
51  return false;
52  }
53  virtual bool StartObject() { return false; }
 /* \brief remember the key so later value events can be matched against it
  * \param str key characters
  * \param length number of characters in str
  */
54  virtual bool Key(const char *str, std::size_t length, bool) {
55  set_cur_key(str, length);
56  return true;
57  }
 /* end of the object handled here: give parsing back via pop_handler() */
58  virtual bool EndObject(std::size_t) { return pop_handler(); }
59  virtual bool StartArray() { return false; }
 /* end of the array handled here: give parsing back via pop_handler() */
60  virtual bool EndArray(std::size_t) { return pop_handler(); }
61 
62  protected:
63  /* \brief build handler of indicated type and push it onto delegator's stack
64  * \param args ... any args required to build handler
65  */
 /* \tparam HandlerType handler class to build; it is constructed from this
  *         handler's `delegator` followed by args
  * \tparam ArgsTypes types of args; args are taken by lvalue reference, so
  *         temporaries cannot be passed
  * \return true if the weak reference to the delegator could be locked and the
  *         new handler was pushed, false if the delegator has expired
  */
66  template <typename HandlerType, typename... ArgsTypes>
67  bool push_handler(ArgsTypes &... args) {
 /* lock() yields an empty pointer once the delegator is gone */
68  if (auto parent = BaseHandler::delegator.lock()) {
 /* the new handler shares this handler's delegator */
69  parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
70  return true;
71  } else {
72  return false;
73  }
74  }
75 
76  /* \brief if current JSON key is the indicated string, build handler of
77  * indicated type and push it onto delegator's stack
78  * \param key the expected key
79  * \param args ... any args required to build handler
80  */
 /* \return true if the current key equals key, false otherwise
  *
  * NOTE(review): the result of push_handler() is discarded, so this returns
  * true on a key match even when push_handler() reported that the delegator
  * has expired and nothing was pushed. Confirm whether returning the
  * push_handler() result is the intended behaviour.
  */
81  template <typename HandlerType, typename... ArgsTypes>
82  bool push_key_handler(std::string key, ArgsTypes &... args) {
83  if (check_cur_key(key)) {
84  push_handler<HandlerType, ArgsTypes...>(args...);
85  return true;
86  } else {
87  return false;
88  }
89  }
90  /* \brief pop handler off of delegator's stack, relinquishing parsing */
 /* The returned bool is what the default EndObject()/EndArray() hand back to
  * the parser. Only the declaration is visible in this header. */
91  bool pop_handler();
92  /* \brief store current JSON key
93  * \param str the key to store
94  * \param length the length of the str char array
95  */
96  void set_cur_key(const char *str, std::size_t length);
97  /* \brief retrieve current JSON key */
98  const std::string &get_cur_key();
99  /* \brief check if current JSON key is indicated key
100  * \param query_key the value to compare against current JSON key
101  */
102  bool check_cur_key(const std::string &query_key);
103  /* \brief if current JSON key is the indicated string, assign value to output
104  * \param key the JSON key for this output
105  * \param value the value to be assigned
106  * \param output reference to object to which the value should be assigned
107  */
108  template <typename ValueType>
109  bool assign_value(const std::string &key,
110  ValueType &&value,
111  ValueType &output);
 /* Overload of assign_value that takes the value by const reference; the key
  * and output parameters are the same as in the overload above. */
112  template <typename ValueType>
113  bool assign_value(const std::string &key,
114  const ValueType &value,
115  ValueType &output);
116 
117  private:
118  /* \brief the delegator which delegated parsing responsibility to this handler */
 /* Held as weak_ptr, so it must be locked before use (see push_handler). */
119  std::weak_ptr<Delegator> delegator;
120  /* \brief the JSON key for the object currently being parsed */
 /* Written through set_cur_key(), which the default Key() callback calls. */
121  std::string cur_key;
122 };
123 
/* \brief JSON handler that ignores all delegated input.
 *
 * Overrides every value callback plus StartObject, Key and StartArray.
 * EndObject and EndArray are not overridden, so the BaseHandler versions
 * (which return pop_handler()) stay in effect.
 * The overrides are only declared here; their definitions are not part of
 * this header.
 */
125 class IgnoreHandler : public BaseHandler {
126  public:
128  bool Null() override;
129  bool Bool(bool b) override;
130  bool Int(int i) override;
131  bool Uint(unsigned u) override;
132  bool Int64(int64_t i) override;
133  bool Uint64(uint64_t u) override;
134  bool Double(double d) override;
135  bool String(const char *str, std::size_t length, bool copy) override;
136  bool StartObject() override;
137  bool Key(const char *str, std::size_t length, bool copy) override;
138  bool StartArray() override;
139 };
140 
/* \brief base handler for updating some output object
 * \tparam OutputType type of the object that parsing fills in or modifies
 *
 * The handler does not own the output; it only keeps a reference to it.
 */
142 template <typename OutputType> class OutputHandler : public BaseHandler {
143  public:
 /* \brief construct handler to be added to given delegator's stack
  * \param parent_delegator forwarded to the BaseHandler constructor
  * \param output_param object to be updated; bound to the `output` reference
  */
149  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
150  OutputType &output_param)
151  : BaseHandler{parent_delegator}, output{output_param} {};
 /* Deleted rvalue overload: `output` is a reference member, so this
  * presumably exists to stop a temporary from being bound to it. */
152  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
153  OutputType &&output) = delete;
154 
155  protected:
156  /* \brief the output value constructed or modified during parsing */
157  OutputType &output;
158 };
159 
/* \brief handler for array of objects of given type
 * \tparam ElemType element type of the std::vector being filled
 * \tparam HandlerType handler used for elements that are JSON objects; nested
 *         objects are accepted only if it derives from OutputHandler<ElemType>
 *
 * Each accepted JSON value is appended to `output`. Values of a kind that does
 * not fit ElemType make the callback return false.
 */
161 template <typename ElemType, typename HandlerType = BaseHandler>
162 class ArrayHandler : public OutputHandler<std::vector<ElemType>> {
163  public:
165 
166  /* Note: This method will only be instantiated (and therefore override the
167  * base `bool Bool(bool)` method) if ElemType is bool. */
168  bool Bool(ElemType b) {
169  this->output.push_back(b);
170  return true;
171  }
 /* \brief append an integer JSON value; chosen when ElemType is integral
  * \param i value as delivered by the parser; converted with static_cast, so
  *        a value outside the range of ElemType is not detected
  * \return always true
  */
172  template <typename ArgType, typename IntType = ElemType>
173  typename std::enable_if<std::is_integral<IntType>::value, bool>::type
174  store_int(ArgType i) {
175  this->output.push_back(static_cast<ElemType>(i));
176  return true;
177  }
178 
 /* \brief counterpart chosen when ElemType is not integral
  * \return always false; nothing is stored
  */
179  template <typename ArgType, typename IntType = ElemType>
180  typename std::enable_if<!std::is_integral<IntType>::value, bool>::type
181  store_int(ArgType) {
182  return false;
183  }
184 
 /* All four integer callbacks funnel into store_int. */
185  bool Int(int i) override { return store_int<int>(i); }
186  bool Uint(unsigned u) override { return store_int<unsigned>(u); }
187  bool Int64(int64_t i) override { return store_int<int64_t>(i); }
188  bool Uint64(uint64_t u) override { return store_int<uint64_t>(u); }
189 
190  /* Note: This method will only be instantiated (and therefore override the
191  * base `bool Double(double)` method) if ElemType is double. */
192  bool Double(ElemType d) {
193  this->output.push_back(d);
194  return true;
195  }
196 
 /* \brief append a string JSON value; chosen when ElemType is std::string
  * \param str string characters
  * \param length number of characters in str
  * \param copy not used here
  * \return always true
  */
197  template <typename StringType = ElemType>
198  typename std::enable_if<std::is_same<StringType, std::string>::value,
199  bool>::type
200  store_string(const char *str, std::size_t length, bool copy) {
201  this->output.push_back(ElemType{str, length});
202  return true;
203  }
 /* \brief counterpart chosen when ElemType is not std::string
  * \return always false; nothing is stored
  */
204  template <typename StringType = ElemType>
205  typename std::enable_if<!std::is_same<StringType, std::string>::value,
206  bool>::type
207  store_string(const char *, std::size_t, bool) {
208  return false;
209  }
210 
211  bool String(const char *str, std::size_t length, bool copy) override {
212  return store_string(str, length, copy);
213  }
214 
 /* \brief nested object, HandlerType can fill an ElemType
  *
  * Appends a default-constructed element and pushes a HandlerType that
  * receives a reference to that new element.
  * \return result of push_handler
  */
215  bool StartObject(std::true_type) {
216  this->output.emplace_back();
217  return this->template push_handler<HandlerType, ElemType>(
218  this->output.back());
219  }
220 
 /* \brief nested object, HandlerType cannot fill an ElemType: rejected */
221  bool StartObject(std::false_type) { return false; }
222 
 /* Picks one of the two overloads above at compile time, depending on whether
  * HandlerType derives from OutputHandler<ElemType>. */
223  bool StartObject() override {
224  return StartObject(
225  std::integral_constant<bool, std::is_base_of<OutputHandler<ElemType>,
226  HandlerType>::value>{});
227  }
228 };
229 
232  public:
234 
/* String callback of the handler for TreeParam objects from XGBoost schema
 * (the class header, listing line 231, is absent from this view).
 * Declaration only; the definition is not part of this header. */
235  bool String(const char *str, std::size_t length, bool copy) override;
236 };
237 
240  public:
/* Members of the handler for RegTree objects from XGBoost schema (the class
 * header, listing line 239, is absent from this view).
 * The four callbacks are declared only; their definitions are not part of this
 * header. */
242  bool StartArray() override;
243 
244  bool StartObject() override;
245 
246  bool Uint(unsigned u) override;
247 
248  bool EndObject(std::size_t memberCount) override;
249 
250  private:
 /* Buffers held while one tree is parsed.
  * NOTE(review): presumably each vector is filled from the like-named array of
  * the tree JSON and combined into a tree in EndObject - confirm against the
  * definitions. */
251  std::vector<double> loss_changes;
252  std::vector<double> sum_hessian;
253  std::vector<double> base_weights;
254  std::vector<int> left_children;
255  std::vector<int> right_children;
256  std::vector<int> parents;
257  std::vector<int> split_indices;
258  std::vector<int> split_type;
259  std::vector<int> categories_segments;
260  std::vector<int> categories_sizes;
261  std::vector<int> categories_nodes;
262  std::vector<int> categories;
263  std::vector<double> split_conditions;
264  std::vector<bool> default_left;
 /* starts at 0; presumably set by the Uint callback - confirm */
265  int num_nodes = 0;
266 };
267 
/* \brief handler for GBTreeModel objects from XGBoost schema
 *
 * Updates a treelite::ModelImpl<float, float> through the `output` reference
 * inherited from OutputHandler. Both callbacks are declared only; their
 * definitions are not part of this header.
 */
269 class GBTreeModelHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
271  bool StartArray() override;
272  bool StartObject() override;
273 };
274 
/* \brief handler for GradientBooster objects from XGBoost schema
 *
 * Updates a treelite::ModelImpl<float, float> through the `output` reference
 * inherited from OutputHandler. Both callbacks are declared only; their
 * definitions are not part of this header.
 */
276 class GradientBoosterHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
277  public:
279  bool String(const char *str, std::size_t length, bool copy) override;
280  bool StartObject() override;
281 };
282 
/* \brief handler for Objective objects from XGBoost schema
 *
 * The output is a std::string.
 * NOTE(review): presumably the string receives the objective name - confirm
 * against the definition of String().
 */
284 class ObjectiveHandler : public OutputHandler<std::string> {
286 
287  bool StartObject() override;
288 
289  bool String(const char *str, std::size_t length, bool copy) override;
290 };
291 
/* \brief handler for LearnerParam objects from XGBoost schema
 *
 * Updates a treelite::ModelImpl<float, float>. Only String() is overridden, so
 * every other value kind falls through to the rejecting BaseHandler defaults.
 */
293 class LearnerParamHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
294  public:
296  bool String(const char *str, std::size_t length, bool copy) override;
297 };
298 
301  std::string objective_name;
302 };
303 
/* \brief handler for Learner objects from XGBoost schema
 *
 * Works on an XGBoostModelHandle. EndObject is overridden, so leaving the
 * Learner object does more than the default pop_handler() call.
 * NOTE(review): what EndObject does is not visible here; presumably the
 * `objective` string is consumed there - confirm.
 */
305 class LearnerHandler : public OutputHandler<XGBoostModelHandle> {
306  public:
308  bool StartObject() override;
309  bool EndObject(std::size_t memberCount) override;
310 
311  private:
 /* objective string held while the Learner object is parsed */
312  std::string objective;
313 };
314 
/* \brief handler for XGBoostModel objects from XGBoost schema
 *
 * Works on an XGBoostModelHandle. The callbacks are declared only; their
 * definitions are not part of this header.
 */
316 class XGBoostModelHandler : public OutputHandler<XGBoostModelHandle> {
317  public:
319  bool StartArray() override;
320  bool StartObject() override;
321  bool EndObject(std::size_t memberCount) override;
322 
323  private:
 /* NOTE(review): presumably filled from the "version" array of the model
  * JSON - confirm against the StartArray() definition. */
324  std::vector<unsigned> version;
325 };
326 
/* \brief handler for root object of XGBoost schema
 *
 * The output is the std::unique_ptr<treelite::Model> handed in by the creator
 * (DelegatedHandler::create passes its own `result`).
 */
328 class RootHandler : public OutputHandler<std::unique_ptr<treelite::Model>> {
329  public:
331  bool StartObject() override;
332  private:
 /* handle object owned by this handler; presumably what nested handlers are
  * given to work on - confirm against the StartObject() definition */
333  XGBoostModelHandle handle;
334 };
335 
338  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
339  public Delegator {
340 
341  public:
/* \brief create DelegatedHandler with initial RootHandler on stack
 * \return shared pointer to the new handler; its delegate stack already holds
 *         one RootHandler
 */
343  static std::shared_ptr<DelegatedHandler> create() {
 /* The constructor is private, so std::make_shared cannot build a
  * DelegatedHandler directly. This local subclass is declared inside a member
  * function and therefore may use the private constructor. */
344  struct make_shared_enabler : public DelegatedHandler {};
345 
346  std::shared_ptr<DelegatedHandler> new_handler =
347  std::make_shared<make_shared_enabler>();
 /* The RootHandler gets the new handler as its delegator and writes into the
  * new handler's `result`. */
348  new_handler->push_delegate(std::make_shared<RootHandler>(
349  new_handler,
350  new_handler->result));
351  return new_handler;
352  }
353 
358  std::shared_ptr<BaseHandler> new_delegate) override {
359  delegates.push(new_delegate);
360  }
/* \brief pop handler off of stack, returning parsing responsibility to
 *        previous handler on stack
 *
 * NOTE(review): the stack is popped without an emptiness check; confirm that
 * callers never pop more handlers than were pushed.
 */
364  void pop_delegate() override {
365  delegates.pop();
366  }
/* Accessor for the parse result. Declaration only.
 * NOTE(review): the return type is a unique_ptr by value, so this presumably
 * moves `result` out - confirm against the definition. */
367  std::unique_ptr<treelite::Model> get_result();
 /* SAX event entry points. Declarations only.
  * NOTE(review): presumably each one forwards the event to the handler on top
  * of `delegates` - confirm against the definitions. */
368  bool Null();
369  bool Bool(bool b);
370  bool Int(int i);
371  bool Uint(unsigned u);
372  bool Int64(int64_t i);
373  bool Uint64(uint64_t u);
374  bool Double(double d);
375  bool String(const char *str, std::size_t length, bool copy);
376  bool StartObject();
377  bool Key(const char *str, std::size_t length, bool copy);
378  bool EndObject(std::size_t memberCount);
379  bool StartArray();
380  bool EndArray(std::size_t elementCount);
381 
382  private:
/* Private constructor; instances are obtained through create().
 * Starts with an empty delegate stack and a model obtained from
 * treelite::Model::Create<float, float>(). */
383  DelegatedHandler() : delegates{}, result{treelite::Model::Create<float, float>()} {};
384 
 /* stack of handlers; push_delegate() adds to it and pop_delegate() removes
  * the most recently pushed one */
385  std::stack<std::shared_ptr<BaseHandler>> delegates;
 /* model being built; create() passes it by reference to the RootHandler */
386  std::unique_ptr<treelite::Model> result;
387 };
388 
389 } // namespace details
390 } // namespace treelite
391 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
void pop_delegate() override
pop handler off of stack, returning parsing responsibility to previous handler on stack ...
Definition: xgboost_json.h:364
class for handling delegation of JSON handling
Definition: xgboost_json.h:24
handler for Objective objects from XGBoost schema
Definition: xgboost_json.h:284
handler for XGBoostModel objects from XGBoost schema
Definition: xgboost_json.h:316
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
Definition: xgboost_json.h:239
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
Definition: xgboost_json.h:269
handler which delegates JSON parsing to stack of delegates
Definition: xgboost_json.h:337
handler for array of objects of given type
Definition: xgboost_json.h:162
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:343
base class for parsing all JSON objects
Definition: xgboost_json.h:33
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:40
base handler for updating some output object
Definition: xgboost_json.h:142
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:149
handler for TreeParam objects from XGBoost schema
Definition: xgboost_json.h:231
handler for Learner objects from XGBoost schema
Definition: xgboost_json.h:305
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
Definition: xgboost_json.h:357
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
Definition: xgboost_json.h:293
handler for root object of XGBoost schema
Definition: xgboost_json.h:328
JSON handler that ignores all delegated input.
Definition: xgboost_json.h:125
handler for GradientBooster objects from XGBoost schema
Definition: xgboost_json.h:276