// RapidJSON stream/writer headers and <type_traits>, followed by the WriteElement overload that
// std::enable_if selects when T is an integral type.
// NOTE(review): this listing carries fused source-line numbers and appears to omit lines (no
// closing brace is visible for this function) -- confirm against the original file.
11 #include <rapidjson/ostreamwrapper.h> 12 #include <rapidjson/writer.h> 13 #include <rapidjson/prettywriter.h> 15 #include <type_traits> 21 template <
typename WriterType,
typename T,
22 typename std::enable_if<std::is_integral<T>::value,
bool>::type =
true>
23 void WriteElement(WriterType& writer, T e) {
// Every integral value is cast to uint64_t before being handed to writer.Uint64(), so a
// negative value of a signed T would be emitted as a large unsigned number.
24 writer.Uint64(static_cast<uint64_t>(e));
/*!
 * \brief Write a floating-point value through the writer as a JSON number.
 *
 * This overload participates in overload resolution only when T is a floating-point type.
 * The value is converted to double before being passed to writer.Double().
 *
 * \param writer Writer object exposing a Double(double) member
 * \param e Value to emit
 */
template <typename WriterType, typename T,
          typename std::enable_if<std::is_floating_point<T>::value, bool>::type = true>
void WriteElement(WriterType& writer, T e) {
  writer.Double(static_cast<double>(e));
}
/*!
 * \brief Write a std::string through the writer as a JSON string value.
 *
 * The string is passed as a (pointer, length) pair, so the explicit length -- not a
 * terminating NUL -- determines how many characters are handed to writer.String().
 *
 * \param writer Writer object exposing a String(const char*, length) member
 * \param str String to emit
 */
template <typename WriterType>
void WriteString(WriterType& writer, const std::string& str) {
  writer.String(str.data(), str.size());
}
// Writes the JSON fields describing one node of a tree. Keys visible in this listing:
// node_id, leaf_value, split_feature_id, default_left, split_type, comparison_op, threshold,
// categories_list_right_child, matching_categories, left_child, right_child, data_count,
// sum_hess.
// NOTE(review): the listing omits source lines (the rest of the parameter list, most value
// writes, array/object delimiters and closing braces); `tree` and `node_id` are used below but
// their declarations are not visible -- confirm against the original file.
38 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
39 void WriteNode(WriterType& writer,
44 writer.Key(
"node_id");
// Leaf nodes: under "leaf_value", either each element of tree.LeafVector(node_id) is written
// with WriteElement, or the single tree.LeafValue(node_id) is.
// NOTE(review): the condition choosing between the two forms is not visible here; presumably
// it is HasLeafVector(node_id) -- confirm.
46 if (tree.
IsLeaf(node_id)) {
47 writer.Key(
"leaf_value");
50 for (LeafOutputType e : tree.
LeafVector(node_id)) {
51 WriteElement(writer, e);
55 WriteElement(writer, tree.
LeafValue(node_id));
// Fields for the split condition (the branch structure leading here is not visible).
58 writer.Key(
"split_feature_id");
60 writer.Key(
"default_left");
// The split type is written as a string obtained from treelite::SplitFeatureTypeName().
62 writer.Key(
"split_type");
63 auto split_type = tree.
SplitType(node_id);
64 WriteString(writer, treelite::SplitFeatureTypeName(split_type));
// Numerical split: the comparison operator is written as a string via treelite::OpName(),
// followed by the "threshold" key.
65 if (split_type == treelite::SplitFeatureType::kNumerical) {
66 writer.Key(
"comparison_op");
67 WriteString(writer, treelite::OpName(tree.
ComparisonOp(node_id)));
68 writer.Key(
"threshold");
// Categorical split: keys for the right-child flag and the list of matching categories.
70 }
else if (split_type == treelite::SplitFeatureType::kCategorical) {
71 writer.Key(
"categories_list_right_child");
73 writer.Key(
"matching_categories");
80 writer.Key(
"left_child");
82 writer.Key(
"right_child");
// Node statistics. NOTE(review): any guards around these writes are not visible here.
86 writer.Key(
"data_count");
90 writer.Key(
"sum_hess");
91 writer.Double(tree.
SumHess(node_id));
// The Key() call that precedes the gain value is not visible in this listing.
95 writer.Double(tree.
Gain(node_id));
// Writes a JSON object with the keys output_type, grove_per_class, num_class and
// leaf_vector_size, reading from `task_param`.
// NOTE(review): the signature line is missing from this listing; the function is presumably
// SerializeTaskParamToJSON, which DumpModelAsJSON calls with (writer, model.task_param) --
// confirm. Most value writes and the end of the object are also not visible.
101 template <
typename WriterType>
103 writer.StartObject();
// output_type is written as a string obtained from treelite::OutputTypeToString().
105 writer.Key(
"output_type");
106 WriteString(writer, treelite::OutputTypeToString(task_param.
output_type));
107 writer.Key(
"grove_per_class");
109 writer.Key(
"num_class");
111 writer.Key(
"leaf_vector_size");
// Writes a JSON object with the keys pred_transform, sigmoid_alpha, ratio_c and global_bias,
// reading from `model_param`.
// NOTE(review): the signature line is missing from this listing; the function is presumably
// SerializeModelParamToJSON, which DumpModelAsJSON calls with (writer, model.param) --
// confirm. Only the ratio_c value write is visible; the others and the end of the object are
// omitted.
117 template <
typename WriterType>
119 writer.StartObject();
121 writer.Key(
"pred_transform");
123 writer.Key(
"sigmoid_alpha");
125 writer.Key(
"ratio_c");
// ratio_c is emitted through writer.Double().
126 writer.Double(model_param.
ratio_c);
127 writer.Key(
"global_bias");
// Writes one tree as a JSON object: num_nodes, has_categorical_split, then one WriteNode call
// per entry of tree.nodes_.
// NOTE(review): the listing omits source lines (e.g. whatever key/array delimiters surround
// the node loop, and the closing of the object and function) -- confirm against the original.
137 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
138 void DumpTreeAsJSON(WriterType& writer,
const Tree<ThresholdType, LeafOutputType>& tree) {
139 writer.StartObject();
141 writer.Key(
"num_nodes");
142 writer.Int(tree.num_nodes);
143 writer.Key(
"has_categorical_split");
144 writer.Bool(tree.has_categorical_split_);
// Each node index in [0, tree.nodes_.Size()) is serialized with WriteNode.
147 for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
148 WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree, i);
// Consistency checks on the tree's internal arrays: the node count matches num_nodes, the
// category offset array has one more entry than there are nodes, and its last entry equals
// the size of the matching-categories array.
155 TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
156 TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
157 TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
// Writes a whole model as a JSON object: num_feature, task_type, average_tree_output,
// task_param, model_param, and then each tree in model.trees via DumpTreeAsJSON.
// NOTE(review): the listing omits source lines around the tree loop (its key, any array
// delimiters, and the closing of the object and function) -- confirm against the original.
160 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
161 void DumpModelAsJSON(WriterType& writer,
162 const ModelImpl<ThresholdType, LeafOutputType>& model) {
163 writer.StartObject();
165 writer.Key(
"num_feature");
166 writer.Int(model.num_feature);
// The task type is written as a string obtained from TaskTypeToString().
167 writer.Key(
"task_type");
168 WriteString(writer, TaskTypeToString(model.task_type));
169 writer.Key(
"average_tree_output");
170 writer.Bool(model.average_tree_output);
// The two parameter groups are delegated to their own serializer functions.
171 writer.Key(
"task_param");
172 SerializeTaskParamToJSON(writer, model.task_param);
173 writer.Key(
"model_param");
174 SerializeModelParamToJSON(writer, model.param);
// Trees are visited by const reference, in the order they are stored in model.trees.
177 for (
const Tree<ThresholdType, LeafOutputType>& tree : model.trees) {
178 DumpTreeAsJSON(writer, tree);
185 template <
typename ThresholdType,
typename LeafOutputType>
187 ModelImpl<ThresholdType, LeafOutputType>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const {
188 rapidjson::OStreamWrapper os(fo);
190 rapidjson::PrettyWriter<rapidjson::OStreamWrapper> writer(os);
191 writer.SetFormatOptions(rapidjson::PrettyFormatOptions::kFormatSingleLineArray);
192 DumpModelAsJSON(writer, *
this);
194 rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
195 DumpModelAsJSON(writer, *
this);
199 template void ModelImpl<float, uint32_t>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
200 template void ModelImpl<float, float>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
201 template void ModelImpl<double, uint32_t>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
202 template void ModelImpl<double, double>::DumpAsJSON(std::ostream& fo,
bool pretty_print)
const;
Operator ComparisonOp(int nid) const
get comparison operator
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
in-memory representation of a decision tree
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
double SumHess(int nid) const
get hessian sum
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the left child node
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
int LeftChild(int nid) const
index of the node's left child
std::uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_child flag (CategoriesListRightChild(nid)) to tell which child node the list is associated with.
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function