Treelite
json_serializer.cc
Go to the documentation of this file.
1 
9 #include <treelite/tree.h>
10 #include <treelite/logging.h>
11 #include <rapidjson/ostreamwrapper.h>
12 #include <rapidjson/writer.h>
13 #include <rapidjson/prettywriter.h>
14 #include <ostream>
15 #include <type_traits>
16 #include <cstdint>
17 #include <cstddef>
18 
19 namespace {
20 
// Write one integral value as a JSON number.
//
// Signed types are routed through writer.Int64() so that negative values keep
// their sign; unsigned types go through writer.Uint64() exactly as before.
// Casting every integral type to uint64_t would turn a negative input into a
// very large positive number in the output.
//
// \param writer JSON writer exposing Int64() and Uint64()
// \param e      value to emit
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value, bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  // A plain `if` on a compile-time constant keeps this usable without C++17;
  // both branches are valid for every integral T.
  if (std::is_signed<T>::value) {
    writer.Int64(static_cast<int64_t>(e));
  } else {
    writer.Uint64(static_cast<uint64_t>(e));
  }
}
26 
// Write one floating-point value as a JSON number.
//
// The value is converted to double and handed to writer.Double().
//
// \param writer JSON writer exposing Double()
// \param e      value to emit
template <typename WriterType, typename T,
          typename std::enable_if<std::is_floating_point<T>::value, bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  const double as_double = static_cast<double>(e);
  writer.Double(as_double);
}
32 
// Write a std::string as a JSON string value.
//
// The character buffer and its length are passed separately, so the full
// contents of `str` (as reported by size()) are forwarded to the writer.
//
// \param writer JSON writer exposing String(pointer, length)
// \param str    text to emit
template <typename WriterType>
void WriteString(WriterType& writer, const std::string& str) {
  const char* const chars = str.data();
  const auto length = str.size();
  writer.String(chars, length);
}
37 
38 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
39 void WriteNode(WriterType& writer,
41  int node_id) {
42  writer.StartObject();
43 
44  writer.Key("node_id");
45  writer.Int(node_id);
46  if (tree.IsLeaf(node_id)) {
47  writer.Key("leaf_value");
48  if (tree.HasLeafVector(node_id)) {
49  writer.StartArray();
50  for (LeafOutputType e : tree.LeafVector(node_id)) {
51  WriteElement(writer, e);
52  }
53  writer.EndArray();
54  } else {
55  WriteElement(writer, tree.LeafValue(node_id));
56  }
57  } else {
58  writer.Key("split_feature_id");
59  writer.Uint(tree.SplitIndex(node_id));
60  writer.Key("default_left");
61  writer.Bool(tree.DefaultLeft(node_id));
62  writer.Key("split_type");
63  auto split_type = tree.SplitType(node_id);
64  WriteString(writer, treelite::SplitFeatureTypeName(split_type));
65  if (split_type == treelite::SplitFeatureType::kNumerical) {
66  writer.Key("comparison_op");
67  WriteString(writer, treelite::OpName(tree.ComparisonOp(node_id)));
68  writer.Key("threshold");
69  writer.Double(tree.Threshold(node_id));
70  } else if (split_type == treelite::SplitFeatureType::kCategorical) {
71  writer.Key("categories_list_right_child");
72  writer.Bool(tree.CategoriesListRightChild(node_id));
73  writer.Key("matching_categories");
74  writer.StartArray();
75  for (uint32_t e : tree.MatchingCategories(node_id)) {
76  writer.Uint(e);
77  }
78  writer.EndArray();
79  }
80  writer.Key("left_child");
81  writer.Int(tree.LeftChild(node_id));
82  writer.Key("right_child");
83  writer.Int(tree.RightChild(node_id));
84  }
85  if (tree.HasDataCount(node_id)) {
86  writer.Key("data_count");
87  writer.Uint64(tree.DataCount(node_id));
88  }
89  if (tree.HasSumHess(node_id)) {
90  writer.Key("sum_hess");
91  writer.Double(tree.SumHess(node_id));
92  }
93  if (tree.HasGain(node_id)) {
94  writer.Key("gain");
95  writer.Double(tree.Gain(node_id));
96  }
97 
98  writer.EndObject();
99 }
100 
101 template <typename WriterType>
102 void SerializeTaskParamToJSON(WriterType& writer, treelite::TaskParam task_param) {
103  writer.StartObject();
104 
105  writer.Key("output_type");
106  WriteString(writer, treelite::OutputTypeToString(task_param.output_type));
107  writer.Key("grove_per_class");
108  writer.Bool(task_param.grove_per_class);
109  writer.Key("num_class");
110  writer.Uint(task_param.num_class);
111  writer.Key("leaf_vector_size");
112  writer.Uint(task_param.leaf_vector_size);
113 
114  writer.EndObject();
115 }
116 
117 template <typename WriterType>
118 void SerializeModelParamToJSON(WriterType& writer, treelite::ModelParam model_param) {
119  writer.StartObject();
120 
121  writer.Key("pred_transform");
122  WriteString(writer, std::string(model_param.pred_transform));
123  writer.Key("sigmoid_alpha");
124  writer.Double(model_param.sigmoid_alpha);
125  writer.Key("global_bias");
126  writer.Double(model_param.global_bias);
127 
128  writer.EndObject();
129 }
130 
131 } // anonymous namespace
132 
133 namespace treelite {
134 
135 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
136 void DumpTreeAsJSON(WriterType& writer, const Tree<ThresholdType, LeafOutputType>& tree) {
137  writer.StartObject();
138 
139  writer.Key("num_nodes");
140  writer.Int(tree.num_nodes);
141  writer.Key("nodes");
142  writer.StartArray();
143  for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
144  WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree, i);
145  }
146  writer.EndArray();
147 
148  writer.EndObject();
149 
150  // Basic checks
151  TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
152  TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
153  TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
154 }
155 
156 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
157 void DumpModelAsJSON(WriterType& writer,
158  const ModelImpl<ThresholdType, LeafOutputType>& model) {
159  writer.StartObject();
160 
161  writer.Key("num_feature");
162  writer.Int(model.num_feature);
163  writer.Key("task_type");
164  WriteString(writer, TaskTypeToString(model.task_type));
165  writer.Key("average_tree_output");
166  writer.Bool(model.average_tree_output);
167  writer.Key("task_param");
168  SerializeTaskParamToJSON(writer, model.task_param);
169  writer.Key("model_param");
170  SerializeModelParamToJSON(writer, model.param);
171  writer.Key("trees");
172  writer.StartArray();
173  for (const Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
174  DumpTreeAsJSON(writer, tree);
175  }
176  writer.EndArray();
177 
178  writer.EndObject();
179 }
180 
181 template <typename ThresholdType, typename LeafOutputType>
182 void
183 ModelImpl<ThresholdType, LeafOutputType>::DumpAsJSON(std::ostream& fo, bool pretty_print) const {
184  rapidjson::OStreamWrapper os(fo);
185  if (pretty_print) {
186  rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(os);
187  writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
188  DumpModelAsJSON(writer, *this);
189  } else {
190  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
191  DumpModelAsJSON(writer, *this);
192  }
193 }
194 
195 template void ModelImpl<float, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
196 template void ModelImpl<float, float>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
197 template void ModelImpl<double, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
198 template void ModelImpl<double, double>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
199 
200 } // namespace treelite
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:430
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:465
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:494
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:183
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:172
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:400
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:614
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:198
in-memory representation of a decision tree
Definition: tree.h:213
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:191
float global_bias
global bias of the model
Definition: tree.h:621
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:372
double SumHess(int nid) const
get hessian sum
Definition: tree.h:487
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:509
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:458
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:393
int LeftChild(int nid) const
Getters.
Definition: tree.h:351
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:441
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:472
double Gain(int nid) const
get gain value
Definition: tree.h:501
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:358
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:175
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:379
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:480
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:386
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:423
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:416
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:606