Treelite
protobuf.cc
Go to the documentation of this file.
1 
8 #include <dmlc/logging.h>
9 #include <dmlc/registry.h>
10 #include <dmlc/io.h>
11 #include <treelite/tree.h>
12 #include <queue>
13 
14 #ifdef TREELITE_PROTOBUF_SUPPORT
15 
16 #include "tree.pb.h"
17 
18 namespace {
19 
20 enum class NodeType : int8_t {
21  kLeaf, kLeafVector, kNumericalSplit, kCategoricalSplit
22 };
23 
24 inline NodeType GetNodeType(const treelite_protobuf::Node& node) {
25  if (node.has_left_child()) { // node is non-leaf
26  CHECK(node.has_right_child());
27  CHECK(node.has_default_left());
28  CHECK(node.has_split_index());
29  CHECK(node.has_split_type());
30  CHECK(!node.has_leaf_value());
31  CHECK_EQ(node.leaf_vector_size(), 0);
32  const auto split_type = node.split_type();
33  if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
34  // numerical split
35  CHECK(node.has_op());
36  CHECK(node.has_threshold());
37  CHECK_EQ(node.left_categories_size(), 0);
38  return NodeType::kNumericalSplit;
39  } else { // categorical split
40  CHECK(!node.has_op());
41  CHECK(!node.has_threshold());
42  CHECK(node.has_missing_category_to_zero());
43  return NodeType::kCategoricalSplit;
44  }
45  } else { // node is leaf
46  CHECK(!node.has_right_child());
47  CHECK(!node.has_default_left());
48  CHECK(!node.has_split_index());
49  CHECK(!node.has_split_type());
50  CHECK(!node.has_op());
51  CHECK(!node.has_threshold());
52  CHECK(!node.has_gain());
53  CHECK(!node.has_missing_category_to_zero());
54  CHECK_EQ(node.left_categories_size(), 0);
55  if (node.has_leaf_value()) {
56  CHECK_EQ(node.leaf_vector_size(), 0);
57  return NodeType::kLeaf;
58  } else {
59  CHECK_GT(node.leaf_vector_size(), 0);
60  return NodeType::kLeafVector;
61  }
62  }
63 }
64 
65 } // anonymous namespace
66 
67 namespace treelite {
68 namespace frontend {
69 
70 DMLC_REGISTRY_FILE_TAG(protobuf);
71 
72 void LoadProtobufModel(const char* filename, Model* out) {
73  GOOGLE_PROTOBUF_VERIFY_VERSION;
74 
75  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
76  dmlc::istream is(fi.get());
77  treelite_protobuf::Model protomodel;
78  CHECK(protomodel.ParseFromIstream(&is)) << "Ill-formed Protocol Buffers file";
79 
80  Model model;
81  CHECK(protomodel.has_num_feature()) << "num_feature must exist";
82  const auto num_feature = protomodel.num_feature();
83  CHECK_LT(num_feature, std::numeric_limits<int>::max())
84  << "num_feature too big";
85  CHECK_GT(num_feature, 0) << "num_feature must be positive";
86  model.num_feature = static_cast<int>(protomodel.num_feature());
87 
88  CHECK(protomodel.has_num_output_group()) << "num_output_group must exist";
89  const auto num_output_group = protomodel.num_output_group();
90  CHECK_LT(num_output_group, std::numeric_limits<int>::max())
91  << "num_output_group too big";
92  CHECK_GT(num_output_group, 0) << "num_output_group must be positive";
93  model.num_output_group = static_cast<int>(protomodel.num_output_group());
94 
95  CHECK(protomodel.has_random_forest_flag())
96  << "random_forest_flag must exist";
97  model.random_forest_flag = protomodel.random_forest_flag();
98 
99  // extra parameters field
100  const auto& ep = protomodel.extra_params();
101  std::vector<std::pair<std::string, std::string>> cfg;
102  std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
103  InitParamAndCheck(&model.param, cfg);
104 
105  // flag to check consistent use of leaf vector
106  // 0: no leaf should use leaf vector
107  // 1: every leaf should use leaf vector
108  // -1: indeterminate
109  int8_t flag_leaf_vector = -1;
110 
111  const int ntree = protomodel.trees_size();
112  for (int i = 0; i < ntree; ++i) {
113  model.trees.emplace_back();
114  Tree& tree = model.trees.back();
115  tree.Init();
116 
117  CHECK(protomodel.trees(i).has_head());
118  // assign node ID's so that a breadth-wise traversal would yield
119  // the monotonic sequence 0, 1, 2, ...
120  std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
121  // (proto node, ID)
122  Q.push({protomodel.trees(i).head(), 0});
123  while (!Q.empty()) {
124  auto elem = Q.front(); Q.pop();
125  const treelite_protobuf::Node& node = elem.first;
126  int nid = elem.second;
127  const NodeType node_type = GetNodeType(node);
128  if (node_type == NodeType::kLeaf) { // leaf node with a scalar output
129  CHECK(flag_leaf_vector != 1)
130  << "Inconsistent use of leaf vector: if one leaf node does not use"
131  << "a leaf vector, *no other* leaf node can use a leaf vector";
132  flag_leaf_vector = 0; // now no leaf can use leaf vector
133 
134  tree.SetLeaf(nid, static_cast<tl_float>(node.leaf_value()));
135  } else if (node_type == NodeType::kLeafVector) {
136  // leaf node with vector output
137  CHECK(flag_leaf_vector != 0)
138  << "Inconsistent use of leaf vector: if one leaf node uses "
139  << "a leaf vector, *every* leaf node must use a leaf vector as well";
140  flag_leaf_vector = 1; // now every leaf must use leaf vector
141 
142  const int len = node.leaf_vector_size();
143  CHECK_EQ(len, model.num_output_group)
144  << "The length of leaf vector must be identical to the "
145  << "number of output groups";
146  std::vector<tl_float> leaf_vector(len);
147  for (int k = 0; k < len; ++k) {
148  leaf_vector[k] = static_cast<tl_float>(node.leaf_vector(k));
149  }
150  tree.SetLeafVector(nid, leaf_vector);
151  } else if (node_type == NodeType::kNumericalSplit) { // numerical split
152  const auto split_index = node.split_index();
153  const std::string& opname = node.op();
154  CHECK_LT(split_index, model.num_feature)
155  << "split_index must be between 0 and [num_feature] - 1.";
156  CHECK_GE(split_index, 0) << "split_index must be positive.";
157  CHECK_GT(optable.count(opname), 0) << "No operator `"
158  << opname << "\" exists";
159  tree.AddChilds(nid);
160  tree.SetNumericalSplit(nid,
161  static_cast<unsigned>(split_index),
162  static_cast<tl_float>(node.threshold()),
163  node.default_left(),
164  optable.at(opname));
165  Q.push({node.left_child(), tree.LeftChild(nid)});
166  Q.push({node.right_child(), tree.RightChild(nid)});
167  } else { // categorical split
168  const auto split_index = node.split_index();
169  CHECK_LT(split_index, model.num_feature)
170  << "split_index must be between 0 and [num_feature] - 1.";
171  CHECK_GE(split_index, 0) << "split_index must be positive.";
172  const int left_categories_size = node.left_categories_size();
173  std::vector<uint32_t> left_categories;
174  for (int k = 0; k < left_categories_size; ++k) {
175  const auto cat = node.left_categories(k);
176  CHECK(cat <= std::numeric_limits<uint32_t>::max());
177  left_categories.push_back(static_cast<uint32_t>(cat));
178  }
179  tree.AddChilds(nid);
180  tree.SetCategoricalSplit(nid,
181  static_cast<unsigned>(split_index),
182  node.default_left(),
183  node.missing_category_to_zero(),
184  left_categories);
185  Q.push({node.left_child(), tree.LeftChild(nid)});
186  Q.push({node.right_child(), tree.RightChild(nid)});
187  }
188  /* set node statistics */
189  if (node.has_data_count()) {
190  tree.SetDataCount(nid, static_cast<size_t>(node.data_count()));
191  }
192  if (node.has_sum_hess()) {
193  tree.SetSumHess(nid, node.sum_hess());
194  }
195  if (node.has_gain()) {
196  tree.SetGain(nid, node.gain());
197  }
198  }
199  }
200  if (flag_leaf_vector == 0) {
201  if (model.num_output_group > 1) {
202  // multi-class classification with gradient boosted trees
203  CHECK(!model.random_forest_flag)
204  << "To use a random forest for multi-class classification, each leaf "
205  << "node must output a leaf vector specifying a probability "
206  << "distribution";
207  CHECK_EQ(ntree % model.num_output_group, 0)
208  << "For multi-class classifiers with gradient boosted trees, the number "
209  << "of trees must be evenly divisible by the number of output groups";
210  }
211  } else if (flag_leaf_vector == 1) {
212  // multiclass classification with a random forest
213  CHECK(model.random_forest_flag)
214  << "In multi-class classifiers with gradient boosted trees, each leaf "
215  << "node must output a single floating-point value.";
216  } else {
217  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
218  }
219  *out = std::move(model);
220 }
221 
222 void ExportProtobufModel(const char* filename, const Model& model) {
223  GOOGLE_PROTOBUF_VERIFY_VERSION;
224 
225  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "w"));
226  dmlc::ostream os(fi.get());
227  treelite_protobuf::Model protomodel;
228 
229  protomodel.set_num_feature(
230  static_cast<google::protobuf::int32>(model.num_feature));
231 
232  protomodel.set_num_output_group(
233  static_cast<google::protobuf::int32>(model.num_output_group));
234 
235  protomodel.set_random_forest_flag(model.random_forest_flag);
236 
237  // extra parameters field
238  for (const auto& kv : model.param.__DICT__()) {
239  (*protomodel.mutable_extra_params())[kv.first] = kv.second;
240  }
241 
242  // flag to check consistent use of leaf vector
243  // 0: no leaf should use leaf vector
244  // 1: every leaf should use leaf vector
245  // -1: indeterminate
246  int8_t flag_leaf_vector = -1;
247 
248  const int ntree = model.trees.size();
249  for (int i = 0; i < ntree; ++i) {
250  const Tree& tree = model.trees[i];
251  treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
252 
253  std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
254  Q.push({0, proto_tree->mutable_head()});
255  while (!Q.empty()) {
256  auto elem = Q.front(); Q.pop();
257  const int nid = elem.first;
258  treelite_protobuf::Node* proto_node = elem.second;
259  if (tree.IsLeaf(nid)) { // leaf node
260  if (tree.HasLeafVector(nid)) { // leaf node with vector output
261  CHECK(flag_leaf_vector != 0)
262  << "Inconsistent use of leaf vector: if one leaf node uses "
263  << "a leaf vector, *every* leaf node must use a leaf vector as well";
264  flag_leaf_vector = 1; // now every leaf must use leaf vector
265 
266  const auto& leaf_vector = tree.LeafVector(nid);
267  CHECK_EQ(leaf_vector.size(), model.num_output_group)
268  << "The length of leaf vector must be identical to the "
269  << "number of output groups";
270  for (tl_float e : leaf_vector) {
271  proto_node->add_leaf_vector(static_cast<float>(e));
272  }
273  CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
274  } else { // leaf node with scalar output
275  CHECK(flag_leaf_vector != 1)
276  << "Inconsistent use of leaf vector: if one leaf node does not use"
277  << "a leaf vector, *no other* leaf node can use a leaf vector";
278  flag_leaf_vector = 0; // now no leaf can use leaf vector
279 
280  proto_node->set_leaf_value(static_cast<float>(tree.LeafValue(nid)));
281  }
282  } else if (tree.SplitType(nid) == SplitFeatureType::kNumerical) {
283  // numerical split
284  const unsigned split_index = tree.SplitIndex(nid);
285  const tl_float threshold = tree.Threshold(nid);
286  const bool default_left = tree.DefaultLeft(nid);
287  const Operator op = tree.ComparisonOp(nid);
288 
289  proto_node->set_default_left(default_left);
290  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
291  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
292  proto_node->set_op(OpName(op));
293  proto_node->set_threshold(static_cast<float>(threshold));
294  Q.push({tree.LeftChild(nid), proto_node->mutable_left_child()});
295  Q.push({tree.RightChild(nid), proto_node->mutable_right_child()});
296  } else { // categorical split
297  const unsigned split_index = tree.SplitIndex(nid);
298  const auto& left_categories = tree.LeftCategories(nid);
299  const bool default_left = tree.DefaultLeft(nid);
300  const bool missing_category_to_zero = tree.MissingCategoryToZero(nid);
301 
302  proto_node->set_default_left(default_left);
303  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
304  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
305  proto_node->set_missing_category_to_zero(missing_category_to_zero);
306  for (auto e : left_categories) {
307  proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
308  }
309  Q.push({tree.LeftChild(nid), proto_node->mutable_left_child()});
310  Q.push({tree.RightChild(nid), proto_node->mutable_right_child()});
311  }
312  /* set node statistics */
313  if (tree.HasDataCount(nid)) {
314  proto_node->set_data_count(
315  static_cast<google::protobuf::uint64>(tree.DataCount(nid)));
316  }
317  if (tree.HasSumHess(nid)) {
318  proto_node->set_sum_hess(tree.SumHess(nid));
319  }
320  if (tree.HasGain(nid)) {
321  proto_node->set_gain(tree.Gain(nid));
322  }
323  }
324  }
325  CHECK(protomodel.SerializeToOstream(&os))
326  << "Failed to write Protocol Buffers file";
327  os.set_stream(nullptr);
328 }
329 
330 } // namespace frontend
331 } // namespace treelite
332 
333 #else // TREELITE_PROTOBUF_SUPPORT
334 
335 namespace treelite {
336 namespace frontend {
337 
338 DMLC_REGISTRY_FILE_TAG(protobuf);
339 
340 Model LoadProtobufModel(const char* filename) {
341  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
342  return Model(); // should not reach here
343 }
344 
345 void ExportProtobufModel(const char* filename, const Model& model) {
346  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
347 }
348 
349 } // namespace frontend
350 } // namespace treelite
351 
352 #endif // TREELITE_PROTOBUF_SUPPORT
void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: protobuf.cc:345
void LoadProtobufModel(const char *filename, Model *out)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:40
model structure for tree ensemble
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12
Operator
comparison operators
Definition: base.h:24