9 #include <dmlc/logging.h> 12 #ifdef TREELITE_PROTOBUF_SUPPORT 18 enum class NodeType : int8_t {
19 kLeaf, kLeafVector, kNumericalSplit, kCategoricalSplit
22 inline NodeType GetNodeType(
const treelite_protobuf::Node& node) {
23 if (node.has_left_child()) {
24 CHECK(node.has_right_child());
25 CHECK(node.has_default_left());
26 CHECK(node.has_split_index());
27 CHECK(node.has_split_type());
28 CHECK(!node.has_leaf_value());
29 CHECK_EQ(node.leaf_vector_size(), 0);
30 const auto split_type = node.split_type();
31 if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
34 CHECK(node.has_threshold());
35 CHECK_EQ(node.left_categories_size(), 0);
36 return NodeType::kNumericalSplit;
38 CHECK(!node.has_op());
39 CHECK(!node.has_threshold());
40 CHECK(node.has_missing_category_to_zero());
41 return NodeType::kCategoricalSplit;
44 CHECK(!node.has_right_child());
45 CHECK(!node.has_default_left());
46 CHECK(!node.has_split_index());
47 CHECK(!node.has_split_type());
48 CHECK(!node.has_op());
49 CHECK(!node.has_threshold());
50 CHECK(!node.has_gain());
51 CHECK(!node.has_missing_category_to_zero());
52 CHECK_EQ(node.left_categories_size(), 0);
53 if (node.has_leaf_value()) {
54 CHECK_EQ(node.leaf_vector_size(), 0);
55 return NodeType::kLeaf;
57 CHECK_GT(node.leaf_vector_size(), 0);
58 return NodeType::kLeafVector;
68 DMLC_REGISTRY_FILE_TAG(protobuf);
71 GOOGLE_PROTOBUF_VERIFY_VERSION;
73 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
74 dmlc::istream is(fi.get());
75 treelite_protobuf::Model protomodel;
76 CHECK(protomodel.ParseFromIstream(&is)) <<
"Ill-formed Protocol Buffers file";
79 CHECK(protomodel.has_num_feature()) <<
"num_feature must exist";
80 const auto num_feature = protomodel.num_feature();
81 CHECK_LT(num_feature, std::numeric_limits<int>::max())
82 <<
"num_feature too big";
83 CHECK_GT(num_feature, 0) <<
"num_feature must be positive";
84 model.num_feature =
static_cast<int>(protomodel.num_feature());
86 CHECK(protomodel.has_num_output_group()) <<
"num_output_group must exist";
87 const auto num_output_group = protomodel.num_output_group();
88 CHECK_LT(num_output_group, std::numeric_limits<int>::max())
89 <<
"num_output_group too big";
90 CHECK_GT(num_output_group, 0) <<
"num_output_group must be positive";
91 model.num_output_group =
static_cast<int>(protomodel.num_output_group());
93 CHECK(protomodel.has_random_forest_flag())
94 <<
"random_forest_flag must exist";
95 model.random_forest_flag = protomodel.random_forest_flag();
98 const auto& ep = protomodel.extra_params();
99 std::vector<std::pair<std::string, std::string>> cfg;
100 std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
101 InitParamAndCheck(&model.param, cfg);
107 int8_t flag_leaf_vector = -1;
109 const int ntree = protomodel.trees_size();
110 for (
int i = 0; i < ntree; ++i) {
111 model.trees.emplace_back();
112 Tree& tree = model.trees.back();
115 CHECK(protomodel.trees(i).has_head());
118 std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
120 Q.push({protomodel.trees(i).head(), 0});
122 auto elem = Q.front(); Q.pop();
123 const treelite_protobuf::Node& node = elem.first;
124 int id = elem.second;
125 const NodeType node_type = GetNodeType(node);
126 if (node_type == NodeType::kLeaf) {
127 CHECK(flag_leaf_vector != 1)
128 <<
"Inconsistent use of leaf vector: if one leaf node does not use" 129 <<
"a leaf vector, *no other* leaf node can use a leaf vector";
130 flag_leaf_vector = 0;
132 tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
133 }
else if (node_type == NodeType::kLeafVector) {
135 CHECK(flag_leaf_vector != 0)
136 <<
"Inconsistent use of leaf vector: if one leaf node uses " 137 <<
"a leaf vector, *every* leaf node must use a leaf vector as well";
138 flag_leaf_vector = 1;
140 const int len = node.leaf_vector_size();
141 CHECK_EQ(len, model.num_output_group)
142 <<
"The length of leaf vector must be identical to the " 143 <<
"number of output groups";
144 std::vector<tl_float> leaf_vector(len);
145 for (
int i = 0; i < len; ++i) {
146 leaf_vector[i] =
static_cast<tl_float>(node.leaf_vector(i));
148 tree[id].set_leaf_vector(leaf_vector);
149 }
else if (node_type == NodeType::kNumericalSplit) {
150 const auto split_index = node.split_index();
151 const std::string opname = node.op();
152 CHECK_LT(split_index, model.num_feature)
153 <<
"split_index must be between 0 and [num_feature] - 1.";
154 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
155 CHECK_GT(
optable.count(opname), 0) <<
"No operator `" 156 << opname <<
"\" exists";
158 tree[id].set_numerical_split(static_cast<unsigned>(split_index),
159 static_cast<tl_float>(node.threshold()),
162 Q.push({node.left_child(), tree[id].cleft()});
163 Q.push({node.right_child(), tree[id].cright()});
165 const auto split_index = node.split_index();
166 CHECK_LT(split_index, model.num_feature)
167 <<
"split_index must be between 0 and [num_feature] - 1.";
168 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
169 const int left_categories_size = node.left_categories_size();
170 std::vector<uint32_t> left_categories;
171 for (
int i = 0; i < left_categories_size; ++i) {
172 const auto cat = node.left_categories(i);
173 CHECK(cat <= std::numeric_limits<uint32_t>::max());
174 left_categories.push_back(static_cast<uint32_t>(cat));
177 tree[id].set_categorical_split(static_cast<unsigned>(split_index),
179 node.missing_category_to_zero(),
181 Q.push({node.left_child(), tree[id].cleft()});
182 Q.push({node.right_child(), tree[id].cright()});
185 if (node.has_data_count()) {
186 tree[id].set_data_count(static_cast<size_t>(node.data_count()));
188 if (node.has_sum_hess()) {
189 tree[id].set_sum_hess(node.sum_hess());
191 if (node.has_gain()) {
192 tree[id].set_gain(node.gain());
196 if (flag_leaf_vector == 0) {
197 if (model.num_output_group > 1) {
199 CHECK(!model.random_forest_flag)
200 <<
"To use a random forest for multi-class classification, each leaf " 201 <<
"node must output a leaf vector specifying a probability " 203 CHECK_EQ(ntree % model.num_output_group, 0)
204 <<
"For multi-class classifiers with gradient boosted trees, the number " 205 <<
"of trees must be evenly divisible by the number of output groups";
207 }
else if (flag_leaf_vector == 1) {
209 CHECK(model.random_forest_flag)
210 <<
"In multi-class classifiers with gradient boosted trees, each leaf " 211 <<
"node must output a single floating-point value.";
213 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
219 GOOGLE_PROTOBUF_VERIFY_VERSION;
221 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"w"));
222 dmlc::ostream os(fi.get());
223 treelite_protobuf::Model protomodel;
225 protomodel.set_num_feature(
226 static_cast<google::protobuf::int32>(model.num_feature));
228 protomodel.set_num_output_group(
229 static_cast<google::protobuf::int32>(model.num_output_group));
231 protomodel.set_random_forest_flag(model.random_forest_flag);
234 for (
const auto& kv : model.param.__DICT__()) {
235 (*protomodel.mutable_extra_params())[kv.first] = kv.second;
242 int8_t flag_leaf_vector = -1;
244 const int ntree = model.trees.size();
245 for (
int i = 0; i < ntree; ++i) {
246 const Tree& tree = model.trees[i];
247 treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
249 std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
250 Q.push({0, proto_tree->mutable_head()});
252 auto elem = Q.front(); Q.pop();
253 const int nid = elem.first;
254 treelite_protobuf::Node* proto_node = elem.second;
255 if (tree[nid].is_leaf()) {
256 if (tree[nid].has_leaf_vector()) {
257 CHECK(flag_leaf_vector != 0)
258 <<
"Inconsistent use of leaf vector: if one leaf node uses " 259 <<
"a leaf vector, *every* leaf node must use a leaf vector as well";
260 flag_leaf_vector = 1;
262 const auto& leaf_vector = tree[nid].leaf_vector();
263 CHECK_EQ(leaf_vector.size(), model.num_output_group)
264 <<
"The length of leaf vector must be identical to the " 265 <<
"number of output groups";
266 for (tl_float e : leaf_vector) {
267 proto_node->add_leaf_vector(static_cast<double>(e));
269 CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
271 CHECK(flag_leaf_vector != 1)
272 <<
"Inconsistent use of leaf vector: if one leaf node does not use" 273 <<
"a leaf vector, *no other* leaf node can use a leaf vector";
274 flag_leaf_vector = 0;
276 proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
278 }
else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
280 const unsigned split_index = tree[nid].split_index();
281 const tl_float threshold = tree[nid].threshold();
282 const bool default_left = tree[nid].default_left();
283 const Operator op = tree[nid].comparison_op();
285 proto_node->set_default_left(default_left);
286 proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
287 proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
288 proto_node->set_op(
OpName(op));
289 proto_node->set_threshold(static_cast<double>(threshold));
290 Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
291 Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
293 const unsigned split_index = tree[nid].split_index();
294 const auto& left_categories = tree[nid].left_categories();
295 const bool default_left = tree[nid].default_left();
296 const bool missing_category_to_zero = tree[nid].missing_category_to_zero();
298 proto_node->set_default_left(default_left);
299 proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
300 proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
301 proto_node->set_missing_category_to_zero(missing_category_to_zero);
302 for (
auto e : left_categories) {
303 proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
305 Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
306 Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
309 if (tree[nid].has_data_count()) {
310 proto_node->set_data_count(
311 static_cast<google::protobuf::uint64>(tree[nid].data_count()));
313 if (tree[nid].has_sum_hess()) {
314 proto_node->set_sum_hess(tree[nid].sum_hess());
316 if (tree[nid].has_gain()) {
317 proto_node->set_gain(tree[nid].gain());
321 CHECK(protomodel.SerializeToOstream(&os))
322 <<
"Failed to write Protocol Buffers file";
323 os.set_stream(
nullptr);
329 #else // TREELITE_PROTOBUF_SUPPORT 334 DMLC_REGISTRY_FILE_TAG(protobuf);
337 LOG(FATAL) <<
"Treelite was not compiled with Protobuf!";
342 LOG(FATAL) <<
"Treelite was not compiled with Protobuf!";
348 #endif // TREELITE_PROTOBUF_SUPPORT void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
thin wrapper for tree ensemble model
std::string OpName(Operator op)
get string representation of comparison operator
Model LoadProtobufModel(const char *filename)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
double tl_float
float type to be used internally
Operator
comparison operators