11 #include <rapidjson/ostreamwrapper.h> 12 #include <rapidjson/writer.h> 13 #include <rapidjson/prettywriter.h> 15 #include <type_traits> 21 template <
typename WriterType,
typename T,
22 typename std::enable_if<std::is_integral<T>::value,
bool>::type =
true>
/*!
 * \brief Write an integral value to the JSON writer.
 *
 * This overload participates in overload resolution only when
 * std::is_integral<T>::value is true (see the enable_if parameter above).
 *
 * \param writer Writer that receives the value through its Uint64() member
 * \param e Integral value to write
 */
23 void WriteElement(WriterType& writer, T e) {
// The value is converted to uint64_t before being handed to the writer.
// NOTE(review): a negative value of a signed T would wrap around in this cast.
// The explicit instantiations visible in this file only use uint32_t as an
// integral LeafOutputType -- confirm that no signed type reaches this overload.
24 writer.Uint64(static_cast<uint64_t>(e));
/*!
 * \brief Write a floating-point value to the JSON writer.
 *
 * This overload participates in overload resolution only when
 * std::is_floating_point<T>::value is true. The value is converted to double
 * and passed to the writer's Double() member.
 *
 * \tparam WriterType Type exposing a Double(double) member function
 * \tparam T Floating-point type of the value
 * \param writer Writer that receives the value
 * \param e Floating-point value to write
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_floating_point<T>::value, bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Double(static_cast<double>(e));
}
/*!
 * \brief Write a std::string to the JSON writer as a string value.
 *
 * The string is passed as a (pointer, length) pair, so the full contents of
 * str are forwarded, as given by str.data() and str.size().
 *
 * \tparam WriterType Type exposing a String(const char*, length) member function
 * \param writer Writer that receives the string
 * \param str String to write
 */
template <typename WriterType>
void WriteString(WriterType& writer, const std::string& str) {
  writer.String(str.data(), str.size());
}
/*!
 * \brief Write the JSON fields describing one tree node (node_id) to the writer.
 *
 * NOTE(review): this listing is elided -- the remainder of the parameter list and
 * several statements between the numbered lines are not visible. The comments
 * below describe only what the visible lines do.
 */
38 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
39 void WriteNode(WriterType& writer,
44 writer.Key(
"node_id");
// Leaf nodes: the "leaf_value" key is followed either by every element of
// LeafVector(node_id) or by the single LeafValue(node_id), each emitted through
// WriteElement. The condition choosing between the two is elided here
// (presumably HasLeafVector(node_id) -- confirm).
46 if (tree.
IsLeaf(node_id)) {
47 writer.Key(
"leaf_value");
50 for (LeafOutputType e : tree.
LeafVector(node_id)) {
51 WriteElement(writer, e);
55 WriteElement(writer, tree.
LeafValue(node_id));
// Internal (non-leaf) nodes: split description. The value-writing calls for
// "split_feature_id" and "default_left" are elided in this listing.
58 writer.Key(
"split_feature_id");
60 writer.Key(
"default_left");
62 writer.Key(
"split_type");
// The split type is written by name and then decides which extra keys follow.
63 auto split_type = tree.
SplitType(node_id);
64 WriteString(writer, treelite::SplitFeatureTypeName(split_type));
// Numerical split: comparison operator (by name) and threshold.
65 if (split_type == treelite::SplitFeatureType::kNumerical) {
66 writer.Key(
"comparison_op");
67 WriteString(writer, treelite::OpName(tree.
ComparisonOp(node_id)));
68 writer.Key(
"threshold");
70 }
else if (split_type == treelite::SplitFeatureType::kCategorical) {
// Categorical split: which child the category list belongs to, and the list itself.
71 writer.Key(
"categories_list_right_child");
73 writer.Key(
"matching_categories");
// Child links of the internal node.
80 writer.Key(
"left_child");
82 writer.Key(
"right_child");
// Node statistics. The guards around these keys are elided here
// (presumably HasDataCount / HasSumHess / HasGain -- confirm).
86 writer.Key(
"data_count");
90 writer.Key(
"sum_hess");
91 writer.Double(tree.
SumHess(node_id));
95 writer.Double(tree.
Gain(node_id));
/*!
 * \brief Write the task parameters as a JSON object.
 *
 * NOTE(review): the function signature is elided in this listing; this is
 * presumably SerializeTaskParamToJSON(writer, task_param), which is the helper
 * invoked for the "task_param" key when dumping a model -- confirm.
 */
101 template <
typename WriterType>
103 writer.StartObject();
// "output_type" is written as a string via OutputTypeToString.
105 writer.Key(
"output_type");
106 WriteString(writer, treelite::OutputTypeToString(task_param.
output_type));
// The value-writing calls for the remaining keys are elided in this listing.
107 writer.Key(
"grove_per_class");
109 writer.Key(
"num_class");
111 writer.Key(
"leaf_vector_size");
/*!
 * \brief Write the model parameters as a JSON object.
 *
 * NOTE(review): the function signature is elided in this listing; this is
 * presumably SerializeModelParamToJSON(writer, model_param), which is the helper
 * invoked for the "model_param" key when dumping a model -- confirm.
 */
117 template <
typename WriterType>
119 writer.StartObject();
// The value-writing calls for "pred_transform", "sigmoid_alpha" and
// "global_bias" are elided in this listing; only ratio_c's is visible.
121 writer.Key(
"pred_transform");
123 writer.Key(
"sigmoid_alpha");
125 writer.Key(
"ratio_c");
126 writer.Double(model_param.
ratio_c);
127 writer.Key(
"global_bias");
/*!
 * \brief Write one tree as a JSON object: its node count followed by its nodes.
 *
 * NOTE(review): some statements between the numbered lines are elided in this
 * listing (e.g. whatever opens and closes the node list).
 *
 * \param writer Writer that receives the JSON output
 * \param tree Tree to dump
 */
137 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
138 void DumpTreeAsJSON(WriterType& writer,
const Tree<ThresholdType, LeafOutputType>& tree) {
139 writer.StartObject();
141 writer.Key(
"num_nodes");
142 writer.Int(tree.num_nodes);
// Emit every node in index order; template arguments are spelled out explicitly.
145 for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
146 WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree, i);
// Consistency checks on the tree's internal arrays:
//  - nodes_ holds exactly num_nodes entries;
//  - matching_categories_offset_ has one more entry than nodes_;
//  - the last offset equals the size of matching_categories_.
153 TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
154 TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
155 TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
/*!
 * \brief Write a whole model as a JSON object: model-level fields, then each tree.
 *
 * NOTE(review): some statements between the numbered lines are elided in this
 * listing (e.g. whatever opens and closes the list of trees).
 *
 * \param writer Writer that receives the JSON output
 * \param model Model to dump
 */
158 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
159 void DumpModelAsJSON(WriterType& writer,
160 const ModelImpl<ThresholdType, LeafOutputType>& model) {
161 writer.StartObject();
// Scalar model-level fields.
163 writer.Key(
"num_feature");
164 writer.Int(model.num_feature);
165 writer.Key(
"task_type");
166 WriteString(writer, TaskTypeToString(model.task_type));
167 writer.Key(
"average_tree_output");
168 writer.Bool(model.average_tree_output);
// Nested parameter objects, each produced by its own serializer helper.
169 writer.Key(
"task_param");
170 SerializeTaskParamToJSON(writer, model.task_param);
171 writer.Key(
"model_param");
172 SerializeModelParamToJSON(writer, model.param);
// Dump every tree of the model in order.
175 for (
const Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
176 DumpTreeAsJSON(writer, tree);
/*!
 * \brief Dump this model as JSON to an output stream.
 *
 * The stream is wrapped in a rapidjson::OStreamWrapper and the model is written
 * with either a PrettyWriter or a plain Writer.
 * NOTE(review): the return type and the branch condition are elided in this
 * listing; the branch is presumably selected by pretty_print -- confirm.
 *
 * \param fo Output stream that receives the JSON text
 * \param pretty_print Flag passed by the caller (see note above)
 */
183 template <
typename ThresholdType,
typename LeafOutputType>
185 ModelImpl<ThresholdType, LeafOutputType>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const {
186 rapidjson::OStreamWrapper os(fo);
// Pretty-printing path: PrettyWriter configured with kFormatSingleLineArray.
188 rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(os);
189 writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
190 DumpModelAsJSON(writer, *
this);
// Plain path: the default rapidjson::Writer.
192 rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
193 DumpModelAsJSON(writer, *
this);
// Explicit instantiations of ModelImpl::DumpAsJSON for the four
// (ThresholdType, LeafOutputType) combinations listed below.
197 template void ModelImpl<float, uint32_t>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
198 template void ModelImpl<float, float>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
199 template void ModelImpl<double, uint32_t>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
200 template void ModelImpl<double, double>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
Operator ComparisonOp(int nid) const
get comparison operator
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
in-memory representation of a decision tree
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
double SumHess(int nid) const
get hessian sum
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
int LeftChild(int nid) const
index of the node's left child
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function