11 #include <rapidjson/ostreamwrapper.h> 12 #include <rapidjson/writer.h> 19 template <
typename WriterType>
20 void WriteElement(WriterType& writer,
double e) {
// Writes the JSON fields of a single tree node through `writer`.
// NOTE(review): this listing skips source lines (the embedded line numbers
// jump, e.g. 25 -> 32), so the rest of the parameter list, any conditions
// selecting between the fields below, and the end of the function are not
// visible here.
24 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
25 void WriteNode(WriterType& writer,
32 writer.Int(node.cright_);
// "split_index" is sindex_ with its top bit masked off (the low 31 bits).
33 writer.Key(
"split_index");
34 writer.Uint(node.
sindex_ & ((1U << 31U) - 1U));
// "default_left" is true when bit 31 of sindex_ is set.
35 writer.Key(
"default_left");
36 writer.Bool((node.
sindex_ >> 31U) != 0);
// "leaf_value" and "threshold" are both read from node.info_.
// NOTE(review): presumably only one of the two is written for a given node
// (leaf vs. split); the selecting condition is on lines not shown -- confirm.
38 writer.Key(
"leaf_value");
39 writer.Double(node.
info_.leaf_value);
41 writer.Key(
"threshold");
42 writer.Double(node.
info_.threshold);
// NOTE(review): the value writes for "data_count" and "sum_hess", and the
// key that precedes the gain_ value, are on lines not shown in this listing.
45 writer.Key(
"data_count");
49 writer.Key(
"sum_hess");
54 writer.Double(node.
gain_);
56 writer.Key(
"split_type");
// cmp_ is written as its int8_t numeric value.
// NOTE(review): listing lines 57-58 are not shown, so the key this value is
// paired with is not visible here.
59 writer.Int(static_cast<int8_t>(node.
cmp_));
60 writer.Key(
"categories_list_right_child");
61 writer.Bool(node.categories_list_right_child_);
// Writes the elements of `array` through `writer`, one WriteElement call per
// element, in index order from 0 to array.Size() - 1.
// NOTE(review): the second parameter's declaration and whatever surrounds
// the loop (listing lines 68-69 and 72 onward) are not shown here.
66 template <
typename WriterType,
typename ElementType>
67 void WriteContiguousArray(WriterType& writer,
70 for (std::size_t i = 0; i < array.Size(); ++i) {
71 WriteElement(writer, array[i]);
// Writes the task-parameter keys "output_type", "grove_per_class",
// "num_class" and "leaf_vector_size" through `writer`.
// NOTE(review): the function signature is not shown in this listing;
// presumably this is SerializeTaskParamToJSON, which SerializeToJSON calls
// with (writer, task_param) -- confirm.
76 template <
typename WriterType>
80 writer.Key(
"output_type");
// output_type is written as its uint8_t numeric value.
81 writer.Uint(static_cast<uint8_t>(task_param.
output_type));
// NOTE(review): the value writes for the three keys below are on lines not
// shown in this listing.
82 writer.Key(
"grove_per_class");
84 writer.Key(
"num_class");
86 writer.Key(
"leaf_vector_size");
// Writes the model-parameter keys "pred_transform", "sigmoid_alpha" and
// "global_bias" through `writer`.
// NOTE(review): the function signature is not shown in this listing;
// presumably this is SerializeModelParamToJSON, which SerializeToJSON calls
// with (writer, param) -- confirm.
92 template <
typename WriterType>
96 writer.Key(
"pred_transform");
// The string is passed with an explicit length (data() + size()).
// NOTE(review): `pred_transform` here is a local defined on a line not shown
// (listing line 97) -- confirm how it is built.
98 writer.String(pred_transform.data(), pred_transform.size());
// NOTE(review): the value writes for the two keys below are on lines not
// shown in this listing.
99 writer.Key(
"sigmoid_alpha");
101 writer.Key(
"global_bias");
// Serializes one tree as a JSON object through `writer`: "num_nodes", the
// "leaf_vector" and "matching_categories" arrays together with their offset
// arrays, and then every entry of tree.nodes_ via WriteNode.
// NOTE(review): this listing skips source lines (e.g. 125-126, 129-134), so
// the key that introduces the node list and the calls that close the
// array/object are not visible here.
111 template <
typename WriterType,
typename ThresholdType,
typename LeafOutputType>
112 void SerializeTreeToJSON(WriterType& writer,
const Tree<ThresholdType, LeafOutputType>& tree) {
113 writer.StartObject();
115 writer.Key(
"num_nodes");
116 writer.Int(tree.num_nodes);
117 writer.Key(
"leaf_vector");
118 WriteContiguousArray(writer, tree.leaf_vector_);
119 writer.Key(
"leaf_vector_offset");
120 WriteContiguousArray(writer, tree.leaf_vector_offset_);
121 writer.Key(
"matching_categories");
122 WriteContiguousArray(writer, tree.matching_categories_);
123 writer.Key(
"matching_categories_offset");
124 WriteContiguousArray(writer, tree.matching_categories_offset_);
// Nodes are written in index order; template arguments are passed explicitly.
127 for (std::size_t i = 0; i < tree.nodes_.Size(); ++i) {
128 WriteNode<WriterType, ThresholdType, LeafOutputType>(writer, tree.nodes_[i]);
// Consistency checks on the tree's storage:
//  - nodes_ has exactly num_nodes entries;
//  - each offset array has one more entry than nodes_, and its last entry
//    equals the size of the array it accompanies.
// NOTE(review): in this listing the checks come after the writes above; if
// TREELITE_CHECK_EQ aborts or throws on failure, partial output will already
// have been emitted -- consider whether validation should run first.
135 TREELITE_CHECK_EQ(tree.nodes_.Size(), tree.num_nodes);
136 TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.leaf_vector_offset_.Size());
137 TREELITE_CHECK_EQ(tree.leaf_vector_offset_.Back(), tree.leaf_vector_.Size());
138 TREELITE_CHECK_EQ(tree.nodes_.Size() + 1, tree.matching_categories_offset_.Size());
139 TREELITE_CHECK_EQ(tree.matching_categories_offset_.Back(), tree.matching_categories_.Size());
// Writes this model to the output stream `fo` as JSON, using a
// rapidjson::Writer layered over a rapidjson::OStreamWrapper.
// The top-level object carries "num_feature", "task_type",
// "average_tree_output", "task_param" and "model_param"; each tree in
// `trees` is then written with SerializeTreeToJSON.
// NOTE(review): this listing skips source lines (e.g. 159-160, 163-168), so
// the key that introduces the tree list and the calls that close the
// array/object are not visible here.
142 template <
typename ThresholdType,
typename LeafOutputType>
143 void ModelImpl<ThresholdType, LeafOutputType>::SerializeToJSON(std::ostream& fo)
const {
144 rapidjson::OStreamWrapper os(fo);
145 rapidjson::Writer<rapidjson::OStreamWrapper> writer(os);
147 writer.StartObject();
149 writer.Key(
"num_feature");
150 writer.Int(num_feature);
// task_type is written as its uint8_t numeric value.
151 writer.Key(
"task_type");
152 writer.Uint(static_cast<uint8_t>(task_type));
153 writer.Key(
"average_tree_output");
154 writer.Bool(average_tree_output);
155 writer.Key(
"task_param");
156 SerializeTaskParamToJSON(writer, task_param);
157 writer.Key(
"model_param");
158 SerializeModelParamToJSON(writer, param);
// Trees are written in container order; each is bound by const reference.
161 for (
const Tree<ThresholdType, LeafOutputType>& tree : trees) {
162 SerializeTreeToJSON(writer, tree);
// Explicit instantiations of ModelImpl::SerializeToJSON for four
// (ThresholdType, LeafOutputType) pairs: (float, uint32_t), (float, float),
// (double, uint32_t) and (double, double).
169 template void ModelImpl<float, uint32_t>::SerializeToJSON(std::ostream& fo)
const;
170 template void ModelImpl<float, float>::SerializeToJSON(std::ostream& fo)
const;
171 template void ModelImpl<double, uint32_t>::SerializeToJSON(std::ostream& fo)
const;
172 template void ModelImpl<double, double>::SerializeToJSON(std::ostream& fo)
const;
SplitFeatureType split_type_
feature split type
bool gain_present_
whether gain_ field is present
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
bool data_count_present_
whether data_count_ field is present
model structure for tree ensemble
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
int32_t cleft_
pointer to left and right children
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
logging facility for Treelite
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
OutputType output_type
The type of output from each leaf node.
bool sum_hess_present_
whether sum_hess_ field is present
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
Info info_
storage for leaf value or decision threshold
char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH]
name of prediction transform function