8 #include <dmlc/logging.h> 9 #include <dmlc/registry.h> 14 #ifdef TREELITE_PROTOBUF_SUPPORT 20 enum class NodeType : int8_t {
// Node kinds produced by GetNodeType():
//   kLeaf             - leaf holding a single leaf_value
//   kLeafVector       - leaf holding a non-empty leaf_vector
//   kNumericalSplit   - split whose split_type is NUMERICAL
//   kCategoricalSplit - any other split (categorical)
// NOTE(review): the enum's closing brace (embedded line 22) is not present
// in this listing.
21 kLeaf, kLeafVector, kNumericalSplit, kCategoricalSplit
// Classify a protobuf tree node and validate that the fields present match
// that node kind. A node with a left child is treated as a split; any other
// node is treated as a leaf. Field-presence violations fail a CHECK.
//
// NOTE(review): the embedded line numbers skip 34-35, 39, 44-45, 58 and 61+.
// The `else` lines and closing braces of this function are not present in
// this listing, so the branch structure described below is inferred from the
// mutually exclusive checks.
24 inline NodeType GetNodeType(
const treelite_protobuf::Node& node) {
// Split path: a node with a left child must also have a right child,
// default direction, feature index and split type, and must not carry a
// leaf payload (neither leaf_value nor leaf_vector).
25 if (node.has_left_child()) {
26 CHECK(node.has_right_child());
27 CHECK(node.has_default_left());
28 CHECK(node.has_split_index());
29 CHECK(node.has_split_type());
30 CHECK(!node.has_leaf_value());
31 CHECK_EQ(node.leaf_vector_size(), 0);
32 const auto split_type = node.split_type();
33 if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
// Numerical split: needs a threshold and must not list any categories.
// NOTE(review): embedded lines 34-35 are missing here, so any further
// checks on this branch are not visible.
36 CHECK(node.has_threshold());
37 CHECK_EQ(node.left_categories_size(), 0);
38 return NodeType::kNumericalSplit;
// Categorical split (presumably the `else` of the NUMERICAL test). It must
// have neither a comparison operator nor a threshold, and must have
// missing_category_to_zero. No visible check rejects an empty
// left_categories list.
40 CHECK(!node.has_op());
41 CHECK(!node.has_threshold());
42 CHECK(node.has_missing_category_to_zero());
43 return NodeType::kCategoricalSplit;
// Leaf path (presumably the `else` of has_left_child()): none of the
// split-related fields may be set. This includes `gain`, so a leaf node
// carrying a gain value is rejected.
46 CHECK(!node.has_right_child());
47 CHECK(!node.has_default_left());
48 CHECK(!node.has_split_index());
49 CHECK(!node.has_split_type());
50 CHECK(!node.has_op());
51 CHECK(!node.has_threshold());
52 CHECK(!node.has_gain());
53 CHECK(!node.has_missing_category_to_zero());
54 CHECK_EQ(node.left_categories_size(), 0);
55 if (node.has_leaf_value()) {
// Scalar leaf: a leaf_value and no leaf_vector.
56 CHECK_EQ(node.leaf_vector_size(), 0);
57 return NodeType::kLeaf;
// Vector leaf: no leaf_value, so a non-empty leaf_vector is required.
59 CHECK_GT(node.leaf_vector_size(), 0);
60 return NodeType::kLeafVector;
// dmlc registry file tag for this translation unit, named `protobuf`. The
// same tag is emitted in the non-Protobuf fallback branch further below.
// NOTE(review): presumably referenced by a matching DMLC_REGISTRY_LINK_TAG
// elsewhere so that the linker keeps this file. Confirm in dmlc/registry.h.
70 DMLC_REGISTRY_FILE_TAG(protobuf);
// LoadProtobufModel(const char *filename, Model *out): read a tree ensemble
// from a Protocol Buffers file, validate it, and move the result into *out.
//
// NOTE(review): the following lines are not present in this listing:
//   - the signature line (embedded line 72);
//   - the `Model model;` declaration;
//   - the loop header of the breadth-first traversal;
//   - most `else` and closing-brace lines.
// The signature above is taken from the doc tooltip at the end of the
// listing. Comments below describe only the statements that are visible.
//
// Protobuf's header/library version check (see the protobuf C++ docs).
73 GOOGLE_PROTOBUF_VERIFY_VERSION;
// Open `filename` for reading through dmlc's Stream abstraction, then parse
// the whole message from it.
75 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
76 dmlc::istream is(fi.get());
77 treelite_protobuf::Model protomodel;
78 CHECK(protomodel.ParseFromIstream(&is)) <<
"Ill-formed Protocol Buffers file";
// num_feature is required and must lie strictly between 0 and INT_MAX so
// that it fits Model's int field. INT_MAX itself is rejected by CHECK_LT.
81 CHECK(protomodel.has_num_feature()) <<
"num_feature must exist";
82 const auto num_feature = protomodel.num_feature();
83 CHECK_LT(num_feature, std::numeric_limits<int>::max())
84 <<
"num_feature too big";
85 CHECK_GT(num_feature, 0) <<
"num_feature must be positive";
// NOTE(review): re-reads protomodel.num_feature() instead of reusing the
// validated local `num_feature`. Equivalent, but redundant.
86 model.num_feature =
static_cast<int>(protomodel.num_feature());
// num_output_group gets the same validation as num_feature.
88 CHECK(protomodel.has_num_output_group()) <<
"num_output_group must exist";
89 const auto num_output_group = protomodel.num_output_group();
90 CHECK_LT(num_output_group, std::numeric_limits<int>::max())
91 <<
"num_output_group too big";
92 CHECK_GT(num_output_group, 0) <<
"num_output_group must be positive";
93 model.num_output_group =
static_cast<int>(protomodel.num_output_group());
// random_forest_flag is required.
95 CHECK(protomodel.has_random_forest_flag())
96 <<
"random_forest_flag must exist";
97 model.random_forest_flag = protomodel.random_forest_flag();
// extra_params are copied as (key, value) string pairs into a vector and
// passed to InitParamAndCheck() to populate model.param.
100 const auto& ep = protomodel.extra_params();
101 std::vector<std::pair<std::string, std::string>> cfg;
102 std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
103 InitParamAndCheck(&model.param, cfg);
// Leaf-representation consistency across ALL trees:
//   -1 = no leaf seen yet
//    0 = leaves carry a scalar leaf_value
//    1 = leaves carry a leaf_vector
// Mixing the two representations fails a CHECK.
109 int8_t flag_leaf_vector = -1;
111 const int ntree = protomodel.trees_size();
112 for (
int i = 0; i < ntree; ++i) {
113 model.trees.emplace_back();
114 Tree& tree = model.trees.back();
// Every tree needs a head (root) node.
117 CHECK(protomodel.trees(i).has_head());
// Breadth-first walk over the protobuf nodes. Each queue entry pairs a
// protobuf node with the node ID it receives in `tree`. The node is held by
// reference into `protomodel`, which is declared above and outlives the
// loop. The head gets ID 0; children get the IDs returned by
// tree.LeftChild() / tree.RightChild().
// NOTE(review): the loop header that drains Q (embedded line 123,
// presumably `while (!Q.empty())`) is not shown in this listing.
120 std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
122 Q.push({protomodel.trees(i).head(), 0});
124 auto elem = Q.front(); Q.pop();
125 const treelite_protobuf::Node& node = elem.first;
126 int nid = elem.second;
// Validate field presence and classify the node.
127 const NodeType node_type = GetNodeType(node);
128 if (node_type == NodeType::kLeaf) {
// Scalar leaf: not allowed once any vector leaf has been seen.
// NOTE(review): the two literals below concatenate without a space and
// print "...does not usea leaf vector...". The first literal should end
// with "use ".
129 CHECK(flag_leaf_vector != 1)
130 <<
"Inconsistent use of leaf vector: if one leaf node does not use" 131 <<
"a leaf vector, *no other* leaf node can use a leaf vector";
132 flag_leaf_vector = 0;
134 tree.SetLeaf(nid, static_cast<tl_float>(node.leaf_value()));
135 }
else if (node_type == NodeType::kLeafVector) {
// Vector leaf: not allowed once any scalar leaf has been seen. Its length
// must equal num_output_group.
137 CHECK(flag_leaf_vector != 0)
138 <<
"Inconsistent use of leaf vector: if one leaf node uses " 139 <<
"a leaf vector, *every* leaf node must use a leaf vector as well";
140 flag_leaf_vector = 1;
142 const int len = node.leaf_vector_size();
143 CHECK_EQ(len, model.num_output_group)
144 <<
"The length of leaf vector must be identical to the " 145 <<
"number of output groups";
146 std::vector<tl_float> leaf_vector(len);
147 for (
int k = 0; k < len; ++k) {
148 leaf_vector[k] =
static_cast<tl_float>(node.leaf_vector(k));
150 tree.SetLeafVector(nid, leaf_vector);
151 }
else if (node_type == NodeType::kNumericalSplit) {
// Numerical split: the feature index must be in [0, num_feature), and the
// operator name must be a key of `optable`.
// NOTE(review): "split_index must be positive." is inaccurate, because
// CHECK_GE(split_index, 0) also accepts 0 ("non-negative").
// NOTE(review): the operator message opens with a backtick but closes
// with a double quote, printing: No operator `<op>" exists
152 const auto split_index = node.split_index();
153 const std::string& opname = node.op();
154 CHECK_LT(split_index, model.num_feature)
155 <<
"split_index must be between 0 and [num_feature] - 1.";
156 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
157 CHECK_GT(
optable.count(opname), 0) <<
"No operator `" 158 << opname <<
"\" exists";
// NOTE(review): embedded line 159 (before the call) and lines 163-164
// (the remaining SetNumericalSplit arguments) are not shown here.
160 tree.SetNumericalSplit(nid,
161 static_cast<unsigned>(split_index),
162 static_cast<tl_float>(node.threshold()),
// Enqueue both children under the IDs that `tree` assigned to them.
165 Q.push({node.left_child(), tree.LeftChild(nid)});
166 Q.push({node.right_child(), tree.RightChild(nid)});
// Categorical split (presumably the final `else`; its header, embedded
// line 167, is not shown). It uses the same split_index range check, then
// copies the left categories into a uint32_t vector.
168 const auto split_index = node.split_index();
169 CHECK_LT(split_index, model.num_feature)
170 <<
"split_index must be between 0 and [num_feature] - 1.";
171 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
172 const int left_categories_size = node.left_categories_size();
173 std::vector<uint32_t> left_categories;
174 for (
int k = 0; k < left_categories_size; ++k) {
175 const auto cat = node.left_categories(k);
// NOTE(review): if left_categories is already a uint32 field in the
// .proto schema, this CHECK can never fail. Confirm against the schema.
176 CHECK(cat <= std::numeric_limits<uint32_t>::max());
177 left_categories.push_back(static_cast<uint32_t>(cat));
// NOTE(review): embedded lines 179, 182 and 184 (around and inside this
// call) are not shown in this listing.
180 tree.SetCategoricalSplit(nid,
181 static_cast<unsigned>(split_index),
183 node.missing_category_to_zero(),
185 Q.push({node.left_child(), tree.LeftChild(nid)});
186 Q.push({node.right_child(), tree.RightChild(nid)});
// Optional per-node statistics, copied when present. GetNodeType() CHECKs
// `!node.has_gain()` for leaves, so SetGain() is only reached for splits.
189 if (node.has_data_count()) {
190 tree.SetDataCount(nid, static_cast<size_t>(node.data_count()));
192 if (node.has_sum_hess()) {
193 tree.SetSumHess(nid, node.sum_hess());
195 if (node.has_gain()) {
196 tree.SetGain(nid, node.gain());
// Whole-model checks, run once every tree has been read.
//
// Case flag == 0 (scalar leaves) with several output groups:
// random_forest_flag must be false, and the number of trees must divide
// evenly among the output groups.
// NOTE(review): the first message below appears truncated at
// "probability ", because its continuation (embedded line 206) is not
// shown.
200 if (flag_leaf_vector == 0) {
201 if (model.num_output_group > 1) {
203 CHECK(!model.random_forest_flag)
204 <<
"To use a random forest for multi-class classification, each leaf " 205 <<
"node must output a leaf vector specifying a probability " 207 CHECK_EQ(ntree % model.num_output_group, 0)
208 <<
"For multi-class classifiers with gradient boosted trees, the number " 209 <<
"of trees must be evenly divisible by the number of output groups";
211 }
else if (flag_leaf_vector == 1) {
// Case flag == 1 (vector leaves): random_forest_flag must be true.
213 CHECK(model.random_forest_flag)
214 <<
"In multi-class classifiers with gradient boosted trees, each leaf " 215 <<
"node must output a single floating-point value.";
// Case flag == -1: no leaf was ever seen, e.g. when ntree == 0 and the
// tree loop never runs. This is presumably the final `else` (embedded
// line 216 is not shown).
217 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
// The result is written to *out only after all of the checks above.
219 *out = std::move(model);
// ExportProtobufModel(const char *filename, const Model &model): serialize
// `model` to `filename` in Protocol Buffers format.
//
// NOTE(review): the following lines are not present in this listing:
//   - the signature line (embedded line 222);
//   - the loop header over the traversal queue;
//   - most `else` and closing-brace lines.
// The signature above is taken from the doc tooltip at the end of the
// listing.
//
// Protobuf's header/library version check (see the protobuf C++ docs).
223 GOOGLE_PROTOBUF_VERIFY_VERSION;
// Open `filename` for writing through dmlc's Stream abstraction.
225 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"w"));
226 dmlc::ostream os(fi.get());
227 treelite_protobuf::Model protomodel;
// Model-level header fields.
229 protomodel.set_num_feature(
230 static_cast<google::protobuf::int32>(model.num_feature));
232 protomodel.set_num_output_group(
233 static_cast<google::protobuf::int32>(model.num_output_group));
235 protomodel.set_random_forest_flag(model.random_forest_flag);
// Every (key, value) of model.param.__DICT__() goes into extra_params.
238 for (
const auto& kv : model.param.__DICT__()) {
239 (*protomodel.mutable_extra_params())[kv.first] = kv.second;
// Same leaf-representation consistency flag as in LoadProtobufModel:
// -1 = undetermined, 0 = scalar leaves, 1 = vector leaves.
246 int8_t flag_leaf_vector = -1;
// NOTE(review): implicit size_t -> int narrowing of trees.size().
248 const int ntree = model.trees.size();
249 for (
int i = 0; i < ntree; ++i) {
250 const Tree& tree = model.trees[i];
251 treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
// Breadth-first walk starting at node 0. Each Tree node ID is paired with
// the protobuf node it is written into, so only nodes reachable from node 0
// are exported.
// NOTE(review): the loop header that drains Q (embedded line 255) is not
// shown in this listing.
253 std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
254 Q.push({0, proto_tree->mutable_head()});
256 auto elem = Q.front(); Q.pop();
257 const int nid = elem.first;
258 treelite_protobuf::Node* proto_node = elem.second;
259 if (tree.IsLeaf(nid)) {
260 if (tree.HasLeafVector(nid)) {
// Vector leaf: not allowed once a scalar leaf has been written. Its length
// must equal num_output_group; each element is appended as a float.
261 CHECK(flag_leaf_vector != 0)
262 <<
"Inconsistent use of leaf vector: if one leaf node uses " 263 <<
"a leaf vector, *every* leaf node must use a leaf vector as well";
264 flag_leaf_vector = 1;
266 const auto& leaf_vector = tree.LeafVector(nid);
// NOTE(review): this CHECK_EQ, and the one after the loop below, compare
// a size_t against an int and may trigger -Wsign-compare.
267 CHECK_EQ(leaf_vector.size(), model.num_output_group)
268 <<
"The length of leaf vector must be identical to the " 269 <<
"number of output groups";
270 for (tl_float e : leaf_vector) {
271 proto_node->add_leaf_vector(static_cast<float>(e));
273 CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
// Scalar leaf (presumably the `else` of HasLeafVector; embedded line 274
// is not shown): not allowed once a vector leaf has been written.
// NOTE(review): the two literals below concatenate as "...does not usea
// leaf vector...". The first literal is missing a trailing space.
275 CHECK(flag_leaf_vector != 1)
276 <<
"Inconsistent use of leaf vector: if one leaf node does not use" 277 <<
"a leaf vector, *no other* leaf node can use a leaf vector";
278 flag_leaf_vector = 0;
280 proto_node->set_leaf_value(static_cast<float>(tree.LeafValue(nid)));
282 }
else if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
// Numerical split. Writes, in order: default direction, feature index,
// split type, comparison operator (as its string name from OpName()) and
// threshold. Then enqueues both children with fresh protobuf child nodes.
284 const unsigned split_index = tree.SplitIndex(nid);
285 const tl_float threshold = tree.Threshold(nid);
286 const bool default_left = tree.DefaultLeft(nid);
287 const Operator op = tree.ComparisonOp(nid);
289 proto_node->set_default_left(default_left);
290 proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
291 proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
292 proto_node->set_op(
OpName(op));
293 proto_node->set_threshold(static_cast<float>(threshold));
294 Q.push({tree.LeftChild(nid), proto_node->mutable_left_child()});
295 Q.push({tree.RightChild(nid), proto_node->mutable_right_child()});
// Categorical split (presumably the final `else`; embedded line 296 is not
// shown). Writes, in order: default direction, feature index, split type,
// missing_category_to_zero and every left category. Then enqueues both
// children.
297 const unsigned split_index = tree.SplitIndex(nid);
298 const auto& left_categories = tree.LeftCategories(nid);
299 const bool default_left = tree.DefaultLeft(nid);
300 const bool missing_category_to_zero = tree.MissingCategoryToZero(nid);
302 proto_node->set_default_left(default_left);
303 proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
304 proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
305 proto_node->set_missing_category_to_zero(missing_category_to_zero);
306 for (
auto e : left_categories) {
307 proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
309 Q.push({tree.LeftChild(nid), proto_node->mutable_left_child()});
310 Q.push({tree.RightChild(nid), proto_node->mutable_right_child()});
// Optional per-node statistics, written when present.
// NOTE(review): this block may also run for leaf nodes (the closing braces
// on embedded lines 311-312 are not shown). If so, a leaf with a gain value
// is written with `gain` set. GetNodeType() in the loader CHECKs
// `!node.has_gain()` for leaves, so such a file would fail to load back.
// Confirm, then either skip gain for leaves here or relax the loader.
313 if (tree.HasDataCount(nid)) {
314 proto_node->set_data_count(
315 static_cast<google::protobuf::uint64>(tree.DataCount(nid)));
317 if (tree.HasSumHess(nid)) {
318 proto_node->set_sum_hess(tree.SumHess(nid));
320 if (tree.HasGain(nid)) {
321 proto_node->set_gain(tree.Gain(nid));
// Serialize the complete message into the output stream.
325 CHECK(protomodel.SerializeToOstream(&os))
326 <<
"Failed to write Protocol Buffers file";
// NOTE(review): presumably detaches `os` from `fi` so that buffered bytes
// are flushed before `fi` is destroyed. Confirm against dmlc/io.h.
327 os.set_stream(
nullptr);
333 #else // TREELITE_PROTOBUF_SUPPORT 338 DMLC_REGISTRY_FILE_TAG(protobuf);
// Fallback used when TREELITE_PROTOBUF_SUPPORT is not defined. The two
// bodies below are presumably the stub implementations of LoadProtobufModel
// and ExportProtobufModel; their signature lines (around embedded lines 339
// and 344) are not shown. Each one only reports a fatal error via
// LOG(FATAL).
// NOTE(review): in this listing the registry tag (embedded line 338) sits on
// the `#else` line after `//`, so as written here it is commented out.
341 LOG(FATAL) <<
"Treelite was not compiled with Protobuf!";
346 LOG(FATAL) <<
"Treelite was not compiled with Protobuf!";
352 #endif // TREELITE_PROTOBUF_SUPPORT void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
void LoadProtobufModel(const char *filename, Model *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
std::string OpName(Operator op)
get string representation of comparison operator
model structure for tree ensemble
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Operator
comparison operators