Treelite
json_serializer.cc
Go to the documentation of this file.
1 
#include <treelite/tree.h>
#include <treelite/logging.h>
#include <rapidjson/ostreamwrapper.h>
#include <rapidjson/writer.h>
#include <rapidjson/prettywriter.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <string>
#include <type_traits>
18 
19 namespace {
20 
/*!
 * \brief Write an unsigned integral value (bool included) as a JSON unsigned 64-bit integer
 * \param writer RapidJSON-style writer receiving the value
 * \param e value to write
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_unsigned<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Uint64(static_cast<uint64_t>(e));
}

/*!
 * \brief Write a signed integral value as a JSON signed 64-bit integer
 * \param writer RapidJSON-style writer receiving the value
 * \param e value to write
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_integral<T>::value && std::is_signed<T>::value,
                                  bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  // Routing signed values through Uint64 would wrap negatives (e.g. -1 -> 2^64 - 1).
  writer.Int64(static_cast<int64_t>(e));
}
26 
/*!
 * \brief Write a floating-point value (float or double) as a JSON double
 * \param writer RapidJSON-style writer receiving the value
 * \param e value to write; widened to double before being handed to the writer
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_floating_point<T>::value, bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  const double widened = static_cast<double>(e);
  writer.Double(widened);
}
32 
/*!
 * \brief Write a std::string as a JSON string, passing its explicit length so that
 *        embedded NUL characters are preserved
 * \param writer RapidJSON-style writer receiving the string
 * \param str string to write
 */
template <typename WriterType>
void WriteString(WriterType& writer, const std::string& str) {
  const char* const chars = str.data();
  const auto length = str.size();
  writer.String(chars, length);
}
37 
38 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
39 void WriteNode(WriterType& writer,
41  int node_id) {
42  writer.StartObject();
43 
44  writer.Key("node_id");
45  writer.Int(node_id);
46  if (tree.IsLeaf(node_id)) {
47  writer.Key("leaf_value");
48  if (tree.HasLeafVector(node_id)) {
49  writer.StartArray();
50  for (LeafOutputType e : tree.LeafVector(node_id)) {
51  WriteElement(writer, e);
52  }
53  writer.EndArray();
54  } else {
55  WriteElement(writer, tree.LeafValue(node_id));
56  }
57  } else {
58  writer.Key("split_feature_id");
59  writer.Uint(tree.SplitIndex(node_id));
60  writer.Key("default_left");
61  writer.Bool(tree.DefaultLeft(node_id));
62  writer.Key("split_type");
63  auto split_type = tree.SplitType(node_id);
64  WriteString(writer, treelite::SplitFeatureTypeName(split_type));
65  if (split_type == treelite::SplitFeatureType::kNumerical) {
66  writer.Key("comparison_op");
67  WriteString(writer, treelite::OpName(tree.ComparisonOp(node_id)));
68  writer.Key("threshold");
69  writer.Double(tree.Threshold(node_id));
70  } else if (split_type == treelite::SplitFeatureType::kCategorical) {
71  writer.Key("categories_list_right_child");
72  writer.Bool(tree.CategoriesListRightChild(node_id));
73  writer.Key("matching_categories");
74  writer.StartArray();
75  for (uint32_t e : tree.MatchingCategories(node_id)) {
76  writer.Uint(e);
77  }
78  writer.EndArray();
79  }
80  writer.Key("left_child");
81  writer.Int(tree.LeftChild(node_id));
82  writer.Key("right_child");
83  writer.Int(tree.RightChild(node_id));
84  }
85  if (tree.HasDataCount(node_id)) {
86  writer.Key("data_count");
87  writer.Uint64(tree.DataCount(node_id));
88  }
89  if (tree.HasSumHess(node_id)) {
90  writer.Key("sum_hess");
91  writer.Double(tree.SumHess(node_id));
92  }
93  if (tree.HasGain(node_id)) {
94  writer.Key("gain");
95  writer.Double(tree.Gain(node_id));
96  }
97 
98  writer.EndObject();
99 }
100 
101 template <typename WriterType>
102 void SerializeTaskParamToJSON(WriterType& writer, treelite::TaskParam task_param) {
103  writer.StartObject();
104 
105  writer.Key("output_type");
106  WriteString(writer, treelite::OutputTypeToString(task_param.output_type));
107  writer.Key("grove_per_class");
108  writer.Bool(task_param.grove_per_class);
109  writer.Key("num_class");
110  writer.Uint(task_param.num_class);
111  writer.Key("leaf_vector_size");
112  writer.Uint(task_param.leaf_vector_size);
113 
114  writer.EndObject();
115 }
116 
117 template <typename WriterType>
118 void SerializeModelParamToJSON(WriterType& writer, treelite::ModelParam model_param) {
119  writer.StartObject();
120 
121  writer.Key("pred_transform");
122  WriteString(writer, std::string(model_param.pred_transform));
123  writer.Key("sigmoid_alpha");
124  writer.Double(model_param.sigmoid_alpha);
125  writer.Key("ratio_c");
126  writer.Double(model_param.ratio_c);
127  writer.Key("global_bias");
128  writer.Double(model_param.global_bias);
129 
130  writer.EndObject();
131 }
132 
133 } // anonymous namespace
134 
135 namespace treelite {
136 
137 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
138 void DumpTreeAsJSON(WriterType& writer, const Tree<ThresholdType, LeafOutputType>& tree) {
139  writer.StartObject();
140 
141  writer.Key("num_nodes");
142  writer.Int(tree.num_nodes);
143  writer.Key("has_categorical_split");
144  writer.Bool(tree.has_categorical_split_);
145  writer.Key("nodes");
146  writer.StartArray();
147  for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
148  WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree, i);
149  }
150  writer.EndArray();
151 
152  writer.EndObject();
153 
154  // Basic checks
155  TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
156  TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
157  TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
158 }
159 
160 template <typename WriterType, typename ThresholdType, typename LeafOutputType>
161 void DumpModelAsJSON(WriterType& writer,
162  const ModelImpl<ThresholdType, LeafOutputType>& model) {
163  writer.StartObject();
164 
165  writer.Key("num_feature");
166  writer.Int(model.num_feature);
167  writer.Key("task_type");
168  WriteString(writer, TaskTypeToString(model.task_type));
169  writer.Key("average_tree_output");
170  writer.Bool(model.average_tree_output);
171  writer.Key("task_param");
172  SerializeTaskParamToJSON(writer, model.task_param);
173  writer.Key("model_param");
174  SerializeModelParamToJSON(writer, model.param);
175  writer.Key("trees");
176  writer.StartArray();
177  for (const Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
178  DumpTreeAsJSON(writer, tree);
179  }
180  writer.EndArray();
181 
182  writer.EndObject();
183 }
184 
185 template <typename ThresholdType, typename LeafOutputType>
186 void
187 ModelImpl<ThresholdType, LeafOutputType>::DumpAsJSON(std::ostream& fo, bool pretty_print) const {
188  rapidjson::OStreamWrapper os(fo);
189  if (pretty_print) {
190  rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(os);
191  writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
192  DumpModelAsJSON(writer, *this);
193  } else {
194  rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
195  DumpModelAsJSON(writer, *this);
196  }
197 }
198 
199 template void ModelImpl<float, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
200 template void ModelImpl<float, float>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
201 template void ModelImpl<double, uint32_t>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
202 template void ModelImpl<double, double>::DumpAsJSON(std::ostream& fo, bool pretty_print) const;
203 
204 } // namespace treelite
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:485
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:520
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:549
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:185
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:174
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:455
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:676
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:200
in-memory representation of a decision tree
Definition: tree.h:215
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:193
float global_bias
global bias of the model
Definition: tree.h:691
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:427
double SumHess(int nid) const
get hessian sum
Definition: tree.h:542
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the left child node
Definition: tree.h:564
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:513
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:448
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:684
int LeftChild(int nid) const
index of the node's left child
Definition: tree.h:406
std::uint64_t DataCount(int nid) const
get data count
Definition: tree.h:527
double Gain(int nid) const
get gain value
Definition: tree.h:556
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:413
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:177
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:434
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:535
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:441
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:478
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:471
std::vector< std::uint32_t > MatchingCategories(int nid) const
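Get list of all categories belonging to the left/right child node. See the categories_list_right_child field of each split (queried via CategoriesListRightChild(nid)) to determine whether this list belongs to the right or the left child node.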
Definition: tree.h:496
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function
Definition: tree.h:668