Treelite
json_serializer.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <treelite/logging.h>
#include <rapidjson/ostreamwrapper.h>
#include <rapidjson/writer.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <string>
#include <type_traits>
16 
17 namespace {
18 
/*!
 * \brief Write a single floating-point element as a JSON number.
 * \param writer RapidJSON-style writer
 * \param e value to write (float arguments are promoted to double)
 */
template <typename WriterType>
void WriteElement(WriterType& writer, double e) {
  writer.Double(e);
}

/*!
 * \brief Write a single unsigned integer element as a JSON integer.
 *
 * Without this overload, integral elements (offsets, category IDs, integer leaf outputs)
 * would be implicitly converted to double and emitted as floating-point text such as "3.0".
 * \param writer RapidJSON-style writer
 * \param e value to write
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Uint64(static_cast<std::uint64_t>(e));
}

/*!
 * \brief Write a single signed integer element as a JSON integer.
 * \param writer RapidJSON-style writer
 * \param e value to write
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Int64(static_cast<std::int64_t>(e));
}
23 
24 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
25 void WriteNode(WriterType& writer,
27  writer.StartObject();
28 
29  writer.Key("cleft");
30  writer.Int(node.cleft_);
31  writer.Key("cright");
32  writer.Int(node.cright_);
33  writer.Key("split_index");
34  writer.Uint(node.sindex_ & ((1U << 31U) - 1U));
35  writer.Key("default_left");
36  writer.Bool((node.sindex_ >> 31U) != 0);
37  if (node.cleft_ == -1) {
38  writer.Key("leaf_value");
39  writer.Double(node.info_.leaf_value);
40  } else {
41  writer.Key("threshold");
42  writer.Double(node.info_.threshold);
43  }
44  if (node.data_count_present_) {
45  writer.Key("data_count");
46  writer.Uint64(node.data_count_);
47  }
48  if (node.sum_hess_present_) {
49  writer.Key("sum_hess");
50  writer.Double(node.sum_hess_);
51  }
52  if (node.gain_present_) {
53  writer.Key("gain");
54  writer.Double(node.gain_);
55  }
56  writer.Key("split_type");
57  writer.Int(static_cast<int8_t>(node.split_type_));
58  writer.Key("cmp");
59  writer.Int(static_cast<int8_t>(node.cmp_));
60  writer.Key("categories_list_right_child");
61  writer.Bool(node.categories_list_right_child_);
62 
63  writer.EndObject();
64 }
65 
66 template <typename WriterType, typename ElementType>
67 void WriteContiguousArray(WriterType& writer,
69  writer.StartArray();
70  for (std::size_t i = 0; i < array.Size(); ++i) {
71  WriteElement(writer, array[i]);
72  }
73  writer.EndArray();
74 }
75 
76 template <typename WriterType>
77 void SerializeTaskParamToJSON(WriterType& writer, treelite::TaskParam task_param) {
78  writer.StartObject();
79 
80  writer.Key("output_type");
81  writer.Uint(static_cast<uint8_t>(task_param.output_type));
82  writer.Key("grove_per_class");
83  writer.Bool(task_param.grove_per_class);
84  writer.Key("num_class");
85  writer.Uint(task_param.num_class);
86  writer.Key("leaf_vector_size");
87  writer.Uint(task_param.leaf_vector_size);
88 
89  writer.EndObject();
90 }
91 
92 template <typename WriterType>
93 void SerializeModelParamToJSON(WriterType& writer, treelite::ModelParam model_param) {
94  writer.StartObject();
95 
96  writer.Key("pred_transform");
97  std::string pred_transform(model_param.pred_transform);
98  writer.String(pred_transform.data(), pred_transform.size());
99  writer.Key("sigmoid_alpha");
100  writer.Double(model_param.sigmoid_alpha);
101  writer.Key("global_bias");
102  writer.Double(model_param.global_bias);
103 
104  writer.EndObject();
105 }
106 
107 } // anonymous namespace
108 
109 namespace treelite {
110 
111 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
112 void SerializeTreeToJSON(WriterType& writer, const Tree<ThresholdType, LeafOutputType>& tree) {
113  writer.StartObject();
114 
115  writer.Key("num_nodes");
116  writer.Int(tree.num_nodes);
117  writer.Key("leaf_vector");
118  WriteContiguousArray(writer, tree.leaf_vector_);
119  writer.Key("leaf_vector_offset");
120  WriteContiguousArray(writer, tree.leaf_vector_offset_);
121  writer.Key("matching_categories");
122  WriteContiguousArray(writer, tree.matching_categories_);
123  writer.Key("matching_categories_offset");
124  WriteContiguousArray(writer, tree.matching_categories_offset_);
125  writer.Key("nodes");
126  writer.StartArray();
127  for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
128  WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree.nodes_[i]);
129  }
130  writer.EndArray();
131 
132  writer.EndObject();
133 
134  // Sanity check
135  TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
136  TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.leaf_vector_offset_.Size());
137  TREELITE_CHECK_EQ(tree.leaf_vector_offset_.Back(), tree.leaf_vector_.Size());
138  TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
139  TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
140 }
141 
142 template <typename ThresholdType, typename LeafOutputType>
143 void ModelImpl<ThresholdType, LeafOutputType>::SerializeToJSON(std::ostream& fo) const {
144  rapidjson::OStreamWrapper os(fo);
145  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
146 
147  writer.StartObject();
148 
149  writer.Key("num_feature");
150  writer.Int(num_feature);
151  writer.Key("task_type");
152  writer.Uint(static_cast<uint8_t>(task_type));
153  writer.Key("average_tree_output");
154  writer.Bool(average_tree_output);
155  writer.Key("task_param");
156  SerializeTaskParamToJSON(writer, task_param);
157  writer.Key("model_param");
158  SerializeModelParamToJSON(writer, param);
159  writer.Key("trees");
160  writer.StartArray();
161  for (const Tree<ThresholdType, LeafOutputType>& tree : trees) {
162  SerializeTreeToJSON(writer, tree);
163  }
164  writer.EndArray();
165 
166  writer.EndObject();
167 }
168 
169 template void ModelImpl<float, uint32_t>::SerializeToJSON(std::ostream& fo) const;
170 template void ModelImpl<float, float>::SerializeToJSON(std::ostream& fo) const;
171 template void ModelImpl<double, uint32_t>::SerializeToJSON(std::ostream& fo) const;
172 template void ModelImpl<double, double>::SerializeToJSON(std::ostream& fo) const;
173 
174 } // namespace treelite
SplitFeatureType split_type_
feature split type
Definition: tree.h:228
bool gain_present_
whether gain_ field is present
Definition: tree.h:240
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
Definition: tree.h:215
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:168
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:157
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:234
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:593
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:236
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:183
tree node
Definition: tree.h:193
int32_t cleft_
pointer to left and right children
Definition: tree.h:203
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:222
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:176
float global_bias
global bias of the model
Definition: tree.h:600
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:226
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:160
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:238
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values.
Definition: tree.h:208
Info info_
storage for leaf value or decision threshold
Definition: tree.h:210
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:585