Treelite
xgboost_json.h
Go to the documentation of this file.
1 
8 #include <rapidjson/document.h>
9 #include <treelite/tree.h>
10 
11 #include <memory>
12 #include <stack>
13 #include <string>
14 #include <vector>
15 
16 #ifndef TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
17 #define TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
18 namespace treelite {
19 namespace details {
20 
class BaseHandler;

/*! \brief class for handling delegation of JSON handling: an object that owns
 *  a stack of handlers and lets handlers hand parsing over to one another */
class Delegator {
 public:
  /*! \brief virtual destructor so that a concrete delegator can be destroyed
   *  safely through a pointer to this interface */
  virtual ~Delegator() = default;
  /*! \brief pop stack of delegate handlers */
  virtual void pop_delegate() = 0;
  /*! \brief push new delegate handler onto stack
   *  \param new_delegate handler to which ongoing parsing is delegated
   */
  virtual void push_delegate(std::shared_ptr<BaseHandler> new_delegate) = 0;
};
31 
34  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, BaseHandler> {
35  public:
40  explicit BaseHandler(std::weak_ptr<Delegator> parent_delegator) :
41  delegator{parent_delegator} {};
42 
43  virtual bool Null() { return false; }
44  virtual bool Bool(bool) { return false; }
45  virtual bool Int(int) { return false; }
46  virtual bool Uint(unsigned) { return false; }
47  virtual bool Int64(int64_t) { return false; }
48  virtual bool Uint64(uint64_t) { return false; }
49  virtual bool Double(double) { return false; }
50  virtual bool String(const char *, std::size_t, bool) {
51  return false;
52  }
53  virtual bool StartObject() { return false; }
54  virtual bool Key(const char *str, std::size_t length, bool) {
55  set_cur_key(str, length);
56  return true;
57  }
58  virtual bool EndObject(std::size_t) { return pop_handler(); }
59  virtual bool StartArray() { return false; }
60  virtual bool EndArray(std::size_t) { return pop_handler(); }
61 
62  protected:
63  /* \brief build handler of indicated type and push it onto delegator's stack
64  * \param args ... any args required to build handler
65  */
66  template <typename HandlerType, typename... ArgsTypes>
67  bool push_handler(ArgsTypes &... args) {
68  if (auto parent = BaseHandler::delegator.lock()) {
69  parent->push_delegate(std::make_shared<HandlerType>(delegator, args...));
70  return true;
71  } else {
72  return false;
73  }
74  }
75 
76  /* \brief if current JSON key is the indicated string, build handler of
77  * indicated type and push it onto delegator's stack
78  * \param key the expected key
79  * \param args ... any args required to build handler
80  */
81  template <typename HandlerType, typename... ArgsTypes>
82  bool push_key_handler(std::string key, ArgsTypes &... args) {
83  if (check_cur_key(key)) {
84  push_handler<HandlerType, ArgsTypes...>(args...);
85  return true;
86  } else {
87  return false;
88  }
89  }
/*! \brief pop handler off of delegator's stack, relinquishing parsing
 *  \return the value that EndObject/EndArray report back to the parser
 */
bool pop_handler();
/*! \brief store current JSON key
 *  \param str pointer to the characters of the key
 *  \param length the length of the str char array
 */
void set_cur_key(const char *str, std::size_t length);
/*! \brief retrieve current JSON key
 *  \return the key most recently stored with set_cur_key
 */
const std::string &get_cur_key();
/*! \brief check if current JSON key is indicated key
 *  \param query_key the value to compare against current JSON key
 *  \return whether the current key matches query_key
 */
bool check_cur_key(const std::string &query_key);
/*! \brief if current JSON key is the indicated string, assign value to output
 *  \param key the JSON key for this output
 *  \param value the value to be assigned. ValueType is deduced from both
 *         value and output, so this overload is only viable for rvalue
 *         arguments; for lvalues deduction conflicts and the const& overload
 *         below is chosen. Presumably value is moved into output — confirm in
 *         the definition.
 *  \param output reference to object to which the value should be assigned
 *  \return presumably true iff the key matched and output was assigned —
 *          confirm in the definition
 */
template <typename ValueType>
bool assign_value(const std::string &key,
                  ValueType &&value,
                  ValueType &output);
/*! \brief const-reference overload of assign_value (same contract); value
 *  is left untouched, presumably copied into output
 */
template <typename ValueType>
bool assign_value(const std::string &key,
                  const ValueType &value,
                  ValueType &output);
116 
117  private:
118  /* \brief the delegator which delegated parsing responsibility to this handler */
119  std::weak_ptr<Delegator> delegator;
120  /* \brief the JSON key for the object currently being parsed */
121  std::string cur_key;
122 };
123 
/*! \brief JSON handler that ignores all delegated input.
 *  All value callbacks plus StartObject, Key and StartArray are overridden
 *  (their definitions are not in this header). EndObject/EndArray are not
 *  overridden, so the BaseHandler versions run and pop this handler off the
 *  delegator's stack. */
125 class IgnoreHandler : public BaseHandler {
126  public:
128  bool Null() override;
129  bool Bool(bool b) override;
130  bool Int(int i) override;
131  bool Uint(unsigned u) override;
132  bool Int64(int64_t i) override;
133  bool Uint64(uint64_t u) override;
134  bool Double(double d) override;
135  bool String(const char *str, std::size_t length, bool copy) override;
// presumably StartObject/StartArray push a nested IgnoreHandler so that the
// whole ignored subtree is skipped — confirm in the definitions
136  bool StartObject() override;
137  bool Key(const char *str, std::size_t length, bool copy) override;
138  bool StartArray() override;
139 };
140 
/*! \brief base handler for updating some output object
 *  \tparam OutputType type of the object modified during parsing. Only a
 *          reference is stored, so the object must outlive the handler. */
142 template <typename OutputType> class OutputHandler : public BaseHandler {
143  public:
/*! \brief construct handler to be added to given delegator's stack
 *  \param parent_delegator the delegator that owns the handler stack
 *  \param output_param the object to be modified during parsing */
149  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
150  OutputType &output_param)
151  : BaseHandler{parent_delegator}, output{output_param} {};
// Binding a temporary would leave `output` dangling as soon as the
// full-expression ends, so construction from an rvalue is disallowed.
152  OutputHandler(std::weak_ptr<Delegator> parent_delegator,
153  OutputType &&output) = delete;
154 
155  protected:
156  /*! \brief the output value constructed or modified during parsing */
157  OutputType &output;
158 };
159 
/*! \brief handler for array of objects of given type
 *  \tparam ElemType element type appended to the output std::vector
 *  \tparam HandlerType handler pushed for each JSON object in the array;
 *          objects are accepted only if it derives from
 *          OutputHandler<ElemType> (see the StartObject overloads below)
 */
161 template <typename ElemType, typename HandlerType = BaseHandler>
162 class ArrayHandler : public OutputHandler<std::vector<ElemType>> {
163  public:
165 
166  /* Note: this declaration exists for every ElemType, but it only overrides
167  * the virtual `bool Bool(bool)` when ElemType is bool. For other element
 * types it is an unrelated overload, so JSON booleans dispatched through
 * BaseHandler reach BaseHandler::Bool and are rejected. */
168  bool Bool(ElemType b) {
169  this->output.push_back(b);
170  return true;
171  }
// Integral JSON values are stored (via static_cast) only when ElemType is
// integral. std::is_integral<bool> is true, so a bool array also accepts
// integer values.
172  template <typename ArgType, typename IntType = ElemType>
173  typename std::enable_if<std::is_integral<IntType>::value, bool>::type
174  store_int(ArgType i) {
175  this->output.push_back(static_cast<ElemType>(i));
176  return true;
177  }
178 
// Non-integral ElemType: integer JSON values are rejected.
// NOTE(review): for a floating-point ElemType this also rejects integer
// literals such as `0`, which rapidjson reports through Int/Uint rather than
// Double; confirm the producer always writes such arrays in floating-point
// notation.
179  template <typename ArgType, typename IntType = ElemType>
180  typename std::enable_if<!std::is_integral<IntType>::value, bool>::type
181  store_int(ArgType) {
182  return false;
183  }
184 
185  bool Int(int i) override { return store_int<int>(i); }
186  bool Uint(unsigned u) override { return store_int<unsigned>(u); }
187  bool Int64(int64_t i) override { return store_int<int64_t>(i); }
188  bool Uint64(uint64_t u) override { return store_int<uint64_t>(u); }
189 
190  /* Note: this declaration exists for every ElemType, but it only overrides
191  * the virtual `bool Double(double)` when ElemType is double. For other
 * element types JSON doubles reach BaseHandler::Double and are rejected. */
192  bool Double(ElemType d) {
193  this->output.push_back(d);
194  return true;
195  }
196 
// Strings are stored only when ElemType is std::string. The `copy` flag does
// not matter here: constructing the std::string always copies the characters.
197  template <typename StringType = ElemType>
198  typename std::enable_if<std::is_same<StringType, std::string>::value,
199  bool>::type
200  store_string(const char *str, std::size_t length, bool copy) {
201  this->output.push_back(ElemType{str, length});
202  return true;
203  }
204  template <typename StringType = ElemType>
205  typename std::enable_if<!std::is_same<StringType, std::string>::value,
206  bool>::type
207  store_string(const char *, std::size_t, bool) {
208  return false;
209  }
210 
211  bool String(const char *str, std::size_t length, bool copy) override {
212  return store_string(str, length, copy);
213  }
214 
// HandlerType can write an ElemType: append a default-constructed element and
// push a HandlerType that fills it in. If the push fails, the appended element
// stays in place and false is returned. The reference to output.back()
// presumably stays valid because the child handler is popped (at its
// EndObject) before the next element is appended — confirm for HandlerTypes
// that override EndObject.
215  bool StartObject(std::true_type) {
216  this->output.emplace_back();
217  return this->template push_handler<HandlerType, ElemType>(
218  this->output.back());
219  }
220 
// HandlerType cannot write an ElemType: JSON objects inside the array are
// rejected.
221  bool StartObject(std::false_type) { return false; }
222 
// Compile-time dispatch to one of the two overloads above.
223  bool StartObject() override {
224  return StartObject(
225  std::integral_constant<bool, std::is_base_of<OutputHandler<ElemType>,
226  HandlerType>::value>{});
227  }
228 };
229 
// owning pointer to the model being assembled; DelegatedHandler's constructor
// creates it with treelite::Model::Create<float, float>()
231  std::unique_ptr<treelite::Model> model_ptr;
// presumably the XGBoost version that wrote the JSON (e.g. major/minor/patch)
// — TODO confirm
233  std::vector<unsigned> version;
// presumably XGBoost's per-tree `tree_info` array — TODO confirm
234  std::vector<int> tree_info;
// name of the training objective; presumably filled from the objective
// section of the JSON — confirm
235  std::string objective_name;
236 };
237 
240  public:
// Only String is overridden here; every other event uses the base class's
// implementation. Presumably XGBoost serializes tree parameters as JSON
// strings — confirm in the definition.
242 
243  bool String(const char *str, std::size_t length, bool copy) override;
244 };
245 
248  public:
// presumably pushes an ArrayHandler for the per-node array named by the
// current key — confirm in the definition
250  bool StartArray() override;
251 
252  bool StartObject() override;
253 
254  bool Uint(unsigned u) override;
255 
// presumably assembles the per-node arrays below into the output tree once
// the whole object has been read — confirm in the definition
256  bool EndObject(std::size_t memberCount) override;
257 
258  private:
// Per-node data read from the JSON; presumably parallel arrays indexed by
// node id — TODO confirm.
259  std::vector<double> loss_changes;
260  std::vector<double> sum_hessian;
261  std::vector<double> base_weights;
262  std::vector<int> left_children;
263  std::vector<int> right_children;
264  std::vector<int> parents;
265  std::vector<int> split_indices;
266  std::vector<int> split_type;
// presumably describe categorical splits — confirm
267  std::vector<int> categories_segments;
268  std::vector<int> categories_sizes;
269  std::vector<int> categories_nodes;
270  std::vector<int> categories;
271  std::vector<double> split_conditions;
// std::vector<bool> is the packed-bit specialisation: elements are proxies,
// not addressable bool objects.
272  std::vector<bool> default_left;
// presumably set from the Uint callback (a node-count field) — confirm
273  int num_nodes = 0;
274 };
275 
/*! \brief handler for GBTreeModel objects from XGBoost schema */
277 class GBTreeModelHandler : public OutputHandler<ParsedXGBoostModel> {
// NOTE(review): no access specifier is visible before these overrides. If
// there is none, they are private here. That is harmless, because they are
// invoked through BaseHandler's public virtual interface — confirm.
279  bool StartArray() override;
280  bool StartObject() override;
281 };
282 
/*! \brief handler for GradientBooster objects from XGBoost schema */
284 class GradientBoosterHandler : public OutputHandler<ParsedXGBoostModel> {
285  public:
287  bool String(const char *str, std::size_t length, bool copy) override;
288  bool StartArray() override;
289  bool StartObject() override;
// presumably post-processes the collected data (e.g. weight_drop) when the
// booster object closes — confirm in the definition
290  bool EndObject(std::size_t memberCount) override;
291  private:
// presumably the booster name read via String (e.g. "gbtree" or "dart") —
// confirm
292  std::string name;
// presumably per-tree weights of a DART booster — confirm
293  std::vector<double> weight_drop;
294 };
295 
/*! \brief handler for Objective objects from XGBoost schema; its output is a
 *  std::string, presumably the objective name */
297 class ObjectiveHandler : public OutputHandler<std::string> {
299 
// NOTE(review): no access specifier is visible before these overrides. If
// there is none, they are private here (still callable through BaseHandler's
// public virtual interface) — confirm.
300  bool StartObject() override;
301 
302  bool String(const char *str, std::size_t length, bool copy) override;
303 };
304 
/*! \brief handler for LearnerParam objects from XGBoost schema; writes
 *  directly into the float/float treelite model */
306 class LearnerParamHandler : public OutputHandler<treelite::ModelImpl<float, float>> {
307  public:
// Only string values are handled; presumably XGBoost encodes learner
// parameters as JSON strings — confirm in the definition.
309  bool String(const char *str, std::size_t length, bool copy) override;
310 };
311 
/*! \brief handler for Learner objects from XGBoost schema */
313 class LearnerHandler : public OutputHandler<ParsedXGBoostModel> {
314  public:
316  bool StartObject() override;
317  bool EndObject(std::size_t memberCount) override;
318  bool StartArray() override;
319 
320  private:
// presumably filled by an ObjectiveHandler and copied into the parsed model
// when the learner object closes — confirm in the definitions
321  std::string objective;
322 };
323 
/*! \brief handler for XGBoost checkpoint */
325 class XGBoostCheckpointHandler : public OutputHandler<ParsedXGBoostModel> {
326  public:
// Only StartArray/StartObject are overridden; scalar values fall back to the
// BaseHandler defaults, which reject them.
328  bool StartArray() override;
329  bool StartObject() override;
330 };
331 
/*! \brief handler for XGBoostModel objects from XGBoost schema */
333 class XGBoostModelHandler : public OutputHandler<ParsedXGBoostModel> {
334  public:
336  bool StartArray() override;
337  bool StartObject() override;
// presumably finalises the parsed model when the model object closes —
// confirm in the definition
338  bool EndObject(std::size_t memberCount) override;
339 };
340 
/*! \brief handler for root object of XGBoost schema */
342 class RootHandler : public OutputHandler<ParsedXGBoostModel> {
343  public:
// Only StartObject is overridden, so a root value that is not a JSON object
// is rejected by the BaseHandler defaults.
345  bool StartObject() override;
346 };
347 
350  : public rapidjson::BaseReaderHandler<rapidjson::UTF8<>, DelegatedHandler>,
351  public Delegator {
352  public:
354  static std::shared_ptr<DelegatedHandler> create_empty() {
355  struct make_shared_enabler : public DelegatedHandler {};
356  std::shared_ptr<DelegatedHandler> new_handler =
357  std::make_shared<make_shared_enabler>();
358  return new_handler;
359  }
360 
362  static std::shared_ptr<DelegatedHandler> create() {
363  std::shared_ptr<DelegatedHandler> new_handler = create_empty();
364  new_handler->push_delegate(std::make_shared<RootHandler>(
365  new_handler,
366  new_handler->result));
367  return new_handler;
368  }
369 
374  std::shared_ptr<BaseHandler> new_delegate) override {
// The new handler becomes the top of the stack and takes over parsing until
// it is popped again.
375  delegates.push(new_delegate);
376  }
380  void pop_delegate() override {
381  delegates.pop();
382  }
383  ParsedXGBoostModel get_result();
384  bool Null();
385  bool Bool(bool b);
386  bool Int(int i);
387  bool Uint(unsigned u);
388  bool Int64(int64_t i);
389  bool Uint64(uint64_t u);
390  bool Double(double d);
391  bool String(const char *str, std::size_t length, bool copy);
392  bool StartObject();
393  bool Key(const char *str, std::size_t length, bool copy);
394  bool EndObject(std::size_t memberCount);
395  bool StartArray();
396  bool EndArray(std::size_t elementCount);
397 
398  private:
400  : delegates{},
// model_ptr takes ownership of a freshly created float/float model. The
// second member (presumably `model`) starts as nullptr and is set in the body.
// The remaining members start empty.
401  result{treelite::Model::Create<float, float>(), nullptr, {}, {}, ""}
402  {
// Cache a non-owning, typed view of the model owned by model_ptr.
// dynamic_cast yields nullptr if Model::Create<float, float>() ever returned
// a different ModelImpl specialisation.
403  result.model = dynamic_cast<treelite::ModelImpl<float, float>*>(result.model_ptr.get());
404  }
405 
406  std::stack<std::shared_ptr<BaseHandler>> delegates;
407  ParsedXGBoostModel result;
408 };
409 
410 } // namespace details
411 } // namespace treelite
412 #endif // TREELITE_FRONTEND_XGBOOST_XGBOOST_JSON_H_
void pop_delegate() override
pop handler off of stack, returning parsing responsibility to previous handler on stack ...
Definition: xgboost_json.h:380
class for handling delegation of JSON handling
Definition: xgboost_json.h:24
handler for Objective objects from XGBoost schema
Definition: xgboost_json.h:297
handler for XGBoostModel objects from XGBoost schema
Definition: xgboost_json.h:333
model structure for tree ensemble
handler for RegTree objects from XGBoost schema
Definition: xgboost_json.h:247
virtual void pop_delegate()=0
pop stack of delegate handlers
handler for GBTreeModel objects from XGBoost schema
Definition: xgboost_json.h:277
handler which delegates JSON parsing to stack of delegates
Definition: xgboost_json.h:349
handler for array of objects of given type
Definition: xgboost_json.h:162
static std::shared_ptr< DelegatedHandler > create()
create DelegatedHandler with initial RootHandler on stack
Definition: xgboost_json.h:362
base class for parsing all JSON objects
Definition: xgboost_json.h:33
static std::shared_ptr< DelegatedHandler > create_empty()
create DelegatedHandler with empty stack
Definition: xgboost_json.h:354
BaseHandler(std::weak_ptr< Delegator > parent_delegator)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:40
base handler for updating some output object
Definition: xgboost_json.h:142
handler for XGBoost checkpoint
Definition: xgboost_json.h:325
OutputHandler(std::weak_ptr< Delegator > parent_delegator, OutputType &output_param)
construct handler to be added to given delegator's stack
Definition: xgboost_json.h:149
handler for TreeParam objects from XGBoost schema
Definition: xgboost_json.h:239
handler for Learner objects from XGBoost schema
Definition: xgboost_json.h:313
void push_delegate(std::shared_ptr< BaseHandler > new_delegate) override
push new handler onto stack, delegating ongoing parsing to it
Definition: xgboost_json.h:373
virtual void push_delegate(std::shared_ptr< BaseHandler > new_delegate)=0
push new delegate handler onto stack
handler for LearnerParam objects from XGBoost schema
Definition: xgboost_json.h:306
handler for root object of XGBoost schema
Definition: xgboost_json.h:342
JSON handler that ignores all delegated input.
Definition: xgboost_json.h:125
handler for GradientBooster objects from XGBoost schema
Definition: xgboost_json.h:284