Treelite
json_serializer.cc
Go to the documentation of this file.
1 
9 #include <treelite/tree.h>
10 #include <treelite/logging.h>
11 #include <rapidjson/ostreamwrapper.h>
12 #include <rapidjson/writer.h>
13 #include <rapidjson/prettywriter.h>
14 #include <ostream>
15 #include <type_traits>
16 #include <cstdint>
17 #include <cstddef>
18 
19 namespace {
20 
/*!
 * \brief Write an unsigned integer (including bool) as a JSON number via Uint64.
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Uint64(static_cast<std::uint64_t>(e));
}

/*!
 * \brief Write a signed integer as a JSON number via Int64.
 *
 * Kept separate from the unsigned overload so that negative values are not reinterpreted
 * as huge unsigned numbers by a cast to uint64_t.
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Int64(static_cast<std::int64_t>(e));
}
26 
/*!
 * \brief Write a floating-point value as a JSON number; the value is always widened/converted
 *        to double before being handed to the writer.
 */
template <typename WriterType, typename FloatT,
          typename std::enable_if<std::is_floating_point<FloatT>::value, bool>::type = true>
void WriteElement(WriterType& writer, FloatT value) {
  const double as_double = static_cast<double>(value);
  writer.Double(as_double);
}
32 
/*!
 * \brief Write a std::string as a JSON string value.
 *
 * The explicit length is passed alongside the character pointer rather than relying on
 * NUL termination.
 */
template <typename WriterType>
void WriteString(WriterType& writer, const std::string& str) {
  const char* const chars = str.c_str();
  const auto length = str.length();
  writer.String(chars, length);
}
37 
38 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
39 void WriteNode(WriterType& writer,
41  int node_id) {
42  writer.StartObject();
43 
44  writer.Key("node_id");
45  writer.Int(node_id);
46  if (tree.IsLeaf(node_id)) {
47  writer.Key("leaf_value");
48  if (tree.HasLeafVector(node_id)) {
49  writer.StartArray();
50  for (LeafOutputType e : tree.LeafVector(node_id)) {
51  WriteElement(writer, e);
52  }
53  writer.EndArray();
54  } else {
55  WriteElement(writer, tree.LeafValue(node_id));
56  }
57  } else {
58  writer.Key("split_feature_id");
59  writer.Uint(tree.SplitIndex(node_id));
60  writer.Key("default_left");
61  writer.Bool(tree.DefaultLeft(node_id));
62  writer.Key("split_type");
63  auto split_type = tree.SplitType(node_id);
64  WriteString(writer, treelite::SplitFeatureTypeName(split_type));
65  if (split_type == treelite::SplitFeatureType::kNumerical) {
66  writer.Key("comparison_op");
67  WriteString(writer, treelite::OpName(tree.ComparisonOp(node_id)));
68  writer.Key("threshold");
69  writer.Double(tree.Threshold(node_id));
70  } else if (split_type == treelite::SplitFeatureType::kCategorical) {
71  writer.Key("categories_list_right_child");
72  writer.Bool(tree.CategoriesListRightChild(node_id));
73  writer.Key("matching_categories");
74  writer.StartArray();
75  for (uint32_t e : tree.MatchingCategories(node_id)) {
76  writer.Uint(e);
77  }
78  writer.EndArray();
79  }
80  writer.Key("left_child");
81  writer.Int(tree.LeftChild(node_id));
82  writer.Key("right_child");
83  writer.Int(tree.RightChild(node_id));
84  }
85  if (tree.HasDataCount(node_id)) {
86  writer.Key("data_count");
87  writer.Uint64(tree.DataCount(node_id));
88  }
89  if (tree.HasSumHess(node_id)) {
90  writer.Key("sum_hess");
91  writer.Double(tree.SumHess(node_id));
92  }
93  if (tree.HasGain(node_id)) {
94  writer.Key("gain");
95  writer.Double(tree.Gain(node_id));
96  }
97 
98  writer.EndObject();
99 }
100 
101 template <typename WriterType>
102 void SerializeTaskParamToJSON(WriterType& writer, treelite::TaskParam task_param) {
103  writer.StartObject();
104 
105  writer.Key("output_type");
106  WriteString(writer, treelite::OutputTypeToString(task_param.output_type));
107  writer.Key("grove_per_class");
108  writer.Bool(task_param.grove_per_class);
109  writer.Key("num_class");
110  writer.Uint(task_param.num_class);
111  writer.Key("leaf_vector_size");
112  writer.Uint(task_param.leaf_vector_size);
113 
114  writer.EndObject();
115 }
116 
117 template <typename WriterType>
118 void SerializeModelParamToJSON(WriterType& writer, treelite::ModelParam model_param) {
119  writer.StartObject();
120 
121  writer.Key("pred_transform");
122  WriteString(writer, std::string(model_param.pred_transform));
123  writer.Key("sigmoid_alpha");
124  writer.Double(model_param.sigmoid_alpha);
125  writer.Key("ratio_c");
126  writer.Double(model_param.ratio_c);
127  writer.Key("global_bias");
128  writer.Double(model_param.global_bias);
129 
130  writer.EndObject();
131 }
132 
133 } // anonymous namespace
134 
135 namespace treelite {
136 
137 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
138 void DumpTreeAsJSON(WriterType& writer, const Tree<ThresholdType, LeafOutputType>& tree) {
139  writer.StartObject();
140 
141  writer.Key("num_nodes");
142  writer.Int(tree.num_nodes);
143  writer.Key("nodes");
144  writer.StartArray();
145  for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
146  WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree, i);
147  }
148  writer.EndArray();
149 
150  writer.EndObject();
151 
152  // Basic checks
153  TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
154  TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
155  TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
156 }
157 
158 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
159 void DumpModelAsJSON(WriterType& writer,
160  const ModelImpl<ThresholdType, LeafOutputType>& model) {
161  writer.StartObject();
162 
163  writer.Key("num_feature");
164  writer.Int(model.num_feature);
165  writer.Key("task_type");
166  WriteString(writer, TaskTypeToString(model.task_type));
167  writer.Key("average_tree_output");
168  writer.Bool(model.average_tree_output);
169  writer.Key("task_param");
170  SerializeTaskParamToJSON(writer, model.task_param);
171  writer.Key("model_param");
172  SerializeModelParamToJSON(writer, model.param);
173  writer.Key("trees");
174  writer.StartArray();
175  for (const Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
176  DumpTreeAsJSON(writer, tree);
177  }
178  writer.EndArray();
179 
180  writer.EndObject();
181 }
182 
183 template <typename ThresholdType, typename LeafOutputType>
184 void
185 ModelImpl<ThresholdType, LeafOutputType>::DumpAsJSON(std::ostream& fo, bool pretty_print) const {
186  rapidjson::OStreamWrapper os(fo);
187  if (pretty_print) {
188  rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(os);
189  writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
190  DumpModelAsJSON(writer, *this);
191  } else {
192  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
193  DumpModelAsJSON(writer, *this);
194  }
195 }
196 
197 template void ModelImpl<float, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
198 template void ModelImpl<float, float>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
199 template void ModelImpl<double, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
200 template void ModelImpl<double, double>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
201 
202 } // namespace treelite
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:431
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:477
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:506
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:184
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:173
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:401
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:626
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:199
in-memory representation of a decision tree
Definition: tree.h:214
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:192
float global_bias
global bias of the model
Definition: tree.h:641
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:373
double SumHess(int nid) const
get hessian sum
Definition: tree.h:499
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:521
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:470
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:394
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:634
int LeftChild(int nid) const
Getters.
Definition: tree.h:352
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:442
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:484
double Gain(int nid) const
get gain value
Definition: tree.h:513
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:359
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:176
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:380
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:492
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:387
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:424
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:417
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:618