treelite
protobuf.cc
Go to the documentation of this file.
1 
8 #include <dmlc/logging.h>
9 #include <treelite/tree.h>
10 #include <queue>
11 
12 #ifdef TREELITE_PROTOBUF_SUPPORT
13 
14 #include "tree.pb.h"
15 
16 namespace {
17 
// Category of a protobuf tree node, as reported by GetNodeType().
enum class NodeType : int8_t {
  kLeaf,             // leaf node holding a single scalar output
  kLeafVector,       // leaf node holding a vector of outputs
  kNumericalSplit,   // internal node whose split type is NUMERICAL
  kCategoricalSplit  // internal node with any other split type
};
21 
22 inline NodeType GetNodeType(const treelite_protobuf::Node& node) {
23  if (node.has_left_child()) { // node is non-leaf
24  CHECK(node.has_right_child());
25  CHECK(node.has_default_left());
26  CHECK(node.has_split_index());
27  CHECK(node.has_split_type());
28  CHECK(!node.has_leaf_value());
29  CHECK_EQ(node.leaf_vector_size(), 0);
30  const auto split_type = node.split_type();
31  if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
32  // numerical split
33  CHECK(node.has_op());
34  CHECK(node.has_threshold());
35  CHECK_EQ(node.left_categories_size(), 0);
36  return NodeType::kNumericalSplit;
37  } else { // categorical split
38  CHECK(!node.has_op());
39  CHECK(!node.has_threshold());
40  return NodeType::kCategoricalSplit;
41  }
42  } else { // node is leaf
43  CHECK(!node.has_right_child());
44  CHECK(!node.has_default_left());
45  CHECK(!node.has_split_index());
46  CHECK(!node.has_split_type());
47  CHECK(!node.has_op());
48  CHECK(!node.has_threshold());
49  CHECK(!node.has_gain());
50  CHECK_EQ(node.left_categories_size(), 0);
51  if (node.has_leaf_value()) {
52  CHECK_EQ(node.leaf_vector_size(), 0);
53  return NodeType::kLeaf;
54  } else {
55  CHECK_GT(node.leaf_vector_size(), 0);
56  return NodeType::kLeafVector;
57  }
58  }
59 }
60 
61 } // anonymous namespace
62 
63 namespace treelite {
64 namespace frontend {
65 
66 DMLC_REGISTRY_FILE_TAG(protobuf);
67 
68 Model LoadProtobufModel(const char* filename) {
69  GOOGLE_PROTOBUF_VERIFY_VERSION;
70 
71  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
72  dmlc::istream is(fi.get());
73  treelite_protobuf::Model protomodel;
74  CHECK(protomodel.ParseFromIstream(&is)) << "Ill-formed Protocol Buffers file";
75 
76  Model model;
77  CHECK(protomodel.has_num_feature()) << "num_feature must exist";
78  const auto num_feature = protomodel.num_feature();
79  CHECK_LT(num_feature, std::numeric_limits<int>::max())
80  << "num_feature too big";
81  CHECK_GT(num_feature, 0) << "num_feature must be positive";
82  model.num_feature = static_cast<int>(protomodel.num_feature());
83 
84  CHECK(protomodel.has_num_output_group()) << "num_output_group must exist";
85  const auto num_output_group = protomodel.num_output_group();
86  CHECK_LT(num_output_group, std::numeric_limits<int>::max())
87  << "num_output_group too big";
88  CHECK_GT(num_output_group, 0) << "num_output_group must be positive";
89  model.num_output_group = static_cast<int>(protomodel.num_output_group());
90 
91  CHECK(protomodel.has_random_forest_flag())
92  << "random_forest_flag must exist";
93  model.random_forest_flag = protomodel.random_forest_flag();
94 
95  // extra parameters field
96  const auto& ep = protomodel.extra_params();
97  std::vector<std::pair<std::string, std::string>> cfg;
98  std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
99  InitParamAndCheck(&model.param, cfg);
100 
101  // flag to check consistent use of leaf vector
102  // 0: no leaf should use leaf vector
103  // 1: every leaf should use leaf vector
104  // -1: indeterminate
105  int8_t flag_leaf_vector = -1;
106 
107  const int ntree = protomodel.trees_size();
108  for (int i = 0; i < ntree; ++i) {
109  model.trees.emplace_back();
110  Tree& tree = model.trees.back();
111  tree.Init();
112 
113  CHECK(protomodel.trees(i).has_head());
114  // assign node ID's so that a breadth-wise traversal would yield
115  // the monotonic sequence 0, 1, 2, ...
116  std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
117  // (proto node, ID)
118  Q.push({protomodel.trees(i).head(), 0});
119  while (!Q.empty()) {
120  auto elem = Q.front(); Q.pop();
121  const treelite_protobuf::Node& node = elem.first;
122  int id = elem.second;
123  const NodeType node_type = GetNodeType(node);
124  if (node_type == NodeType::kLeaf) { // leaf node with a scalar output
125  CHECK(flag_leaf_vector != 1)
126  << "Inconsistent use of leaf vector: if one leaf node does not use"
127  << "a leaf vector, *no other* leaf node can use a leaf vector";
128  flag_leaf_vector = 0; // now no leaf can use leaf vector
129 
130  tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
131  } else if (node_type == NodeType::kLeafVector) {
132  // leaf node with vector output
133  CHECK(flag_leaf_vector != 0)
134  << "Inconsistent use of leaf vector: if one leaf node uses "
135  << "a leaf vector, *every* leaf node must use a leaf vector as well";
136  flag_leaf_vector = 1; // now every leaf must use leaf vector
137 
138  const int len = node.leaf_vector_size();
139  CHECK_EQ(len, model.num_output_group)
140  << "The length of leaf vector must be identical to the "
141  << "number of output groups";
142  std::vector<tl_float> leaf_vector(len);
143  for (int i = 0; i < len; ++i) {
144  leaf_vector[i] = static_cast<tl_float>(node.leaf_vector(i));
145  }
146  tree[id].set_leaf_vector(leaf_vector);
147  } else if (node_type == NodeType::kNumericalSplit) { // numerical split
148  const auto split_index = node.split_index();
149  const std::string opname = node.op();
150  CHECK_LT(split_index, model.num_feature)
151  << "split_index must be between 0 and [num_feature] - 1.";
152  CHECK_GE(split_index, 0) << "split_index must be positive.";
153  CHECK_GT(optable.count(opname), 0) << "No operator `"
154  << opname << "\" exists";
155  tree.AddChilds(id);
156  tree[id].set_numerical_split(static_cast<unsigned>(split_index),
157  static_cast<tl_float>(node.threshold()),
158  node.default_left(),
159  optable.at(opname.c_str()));
160  Q.push({node.left_child(), tree[id].cleft()});
161  Q.push({node.right_child(), tree[id].cright()});
162  } else { // categorical split
163  const auto split_index = node.split_index();
164  CHECK_LT(split_index, model.num_feature)
165  << "split_index must be between 0 and [num_feature] - 1.";
166  CHECK_GE(split_index, 0) << "split_index must be positive.";
167  const int left_categories_size = node.left_categories_size();
168  std::vector<uint32_t> left_categories;
169  for (int i = 0; i < left_categories_size; ++i) {
170  const auto cat = node.left_categories(i);
171  CHECK(cat <= std::numeric_limits<uint32_t>::max());
172  left_categories.push_back(static_cast<uint32_t>(cat));
173  }
174  tree.AddChilds(id);
175  tree[id].set_categorical_split(static_cast<unsigned>(split_index),
176  node.default_left(),
177  left_categories);
178  Q.push({node.left_child(), tree[id].cleft()});
179  Q.push({node.right_child(), tree[id].cright()});
180  }
181  /* set node statistics */
182  if (node.has_data_count()) {
183  tree[id].set_data_count(static_cast<size_t>(node.data_count()));
184  }
185  if (node.has_sum_hess()) {
186  tree[id].set_sum_hess(node.sum_hess());
187  }
188  if (node.has_gain()) {
189  tree[id].set_gain(node.gain());
190  }
191  }
192  }
193  if (flag_leaf_vector == 0) {
194  if (model.num_output_group > 1) {
195  // multiclass classification with gradient boosted trees
196  CHECK(!model.random_forest_flag)
197  << "To use a random forest for multi-class classification, each leaf "
198  << "node must output a leaf vector specifying a probability "
199  << "distribution";
200  CHECK_EQ(ntree % model.num_output_group, 0)
201  << "For multi-class classifiers with gradient boosted trees, the number "
202  << "of trees must be evenly divisible by the number of output groups";
203  }
204  } else if (flag_leaf_vector == 1) {
205  // multiclass classification with a random forest
206  CHECK(model.random_forest_flag)
207  << "In multi-class classifiers with gradient boosted trees, each leaf "
208  << "node must output a single floating-point value.";
209  } else {
210  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
211  }
212  return model;
213 }
214 
215 void ExportProtobufModel(const char* filename, const Model& model) {
216  GOOGLE_PROTOBUF_VERIFY_VERSION;
217 
218  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "w"));
219  dmlc::ostream os(fi.get());
220  treelite_protobuf::Model protomodel;
221 
222  protomodel.set_num_feature(
223  static_cast<google::protobuf::int32>(model.num_feature));
224 
225  protomodel.set_num_output_group(
226  static_cast<google::protobuf::int32>(model.num_output_group));
227 
228  protomodel.set_random_forest_flag(model.random_forest_flag);
229 
230  // extra parameters field
231  for (const auto& kv : model.param.__DICT__()) {
232  (*protomodel.mutable_extra_params())[kv.first] = kv.second;
233  }
234 
235  // flag to check consistent use of leaf vector
236  // 0: no leaf should use leaf vector
237  // 1: every leaf should use leaf vector
238  // -1: indeterminate
239  int8_t flag_leaf_vector = -1;
240 
241  const int ntree = model.trees.size();
242  for (int i = 0; i < ntree; ++i) {
243  const Tree& tree = model.trees[i];
244  treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
245 
246  std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
247  Q.push({0, proto_tree->mutable_head()});
248  while (!Q.empty()) {
249  auto elem = Q.front(); Q.pop();
250  const int nid = elem.first;
251  treelite_protobuf::Node* proto_node = elem.second;
252  if (tree[nid].is_leaf()) { // leaf node
253  if (tree[nid].has_leaf_vector()) { // leaf node with vector output
254  CHECK(flag_leaf_vector != 0)
255  << "Inconsistent use of leaf vector: if one leaf node uses "
256  << "a leaf vector, *every* leaf node must use a leaf vector as well";
257  flag_leaf_vector = 1; // now every leaf must use leaf vector
258 
259  const auto& leaf_vector = tree[nid].leaf_vector();
260  CHECK_EQ(leaf_vector.size(), model.num_output_group)
261  << "The length of leaf vector must be identical to the "
262  << "number of output groups";
263  for (tl_float e : leaf_vector) {
264  proto_node->add_leaf_vector(static_cast<double>(e));
265  }
266  CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
267  } else { // leaf node with scalar output
268  CHECK(flag_leaf_vector != 1)
269  << "Inconsistent use of leaf vector: if one leaf node does not use"
270  << "a leaf vector, *no other* leaf node can use a leaf vector";
271  flag_leaf_vector = 0; // now no leaf can use leaf vector
272 
273  proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
274  }
275  } else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
276  // numerical split
277  const unsigned split_index = tree[nid].split_index();
278  const tl_float threshold = tree[nid].threshold();
279  const bool default_left = tree[nid].default_left();
280  const Operator op = tree[nid].comparison_op();
281 
282  proto_node->set_default_left(default_left);
283  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
284  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
285  proto_node->set_op(OpName(op));
286  proto_node->set_threshold(static_cast<double>(threshold));
287  Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
288  Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
289  } else { // categorical split
290  const unsigned split_index = tree[nid].split_index();
291  const auto& left_categories = tree[nid].left_categories();
292  const bool default_left = tree[nid].default_left();
293 
294  proto_node->set_default_left(default_left);
295  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
296  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
297  for (auto e : left_categories) {
298  proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
299  }
300  Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
301  Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
302  }
303  /* set node statistics */
304  if (tree[nid].has_data_count()) {
305  proto_node->set_data_count(
306  static_cast<google::protobuf::uint64>(tree[nid].data_count()));
307  }
308  if (tree[nid].has_sum_hess()) {
309  proto_node->set_sum_hess(tree[nid].sum_hess());
310  }
311  if (tree[nid].has_gain()) {
312  proto_node->set_gain(tree[nid].gain());
313  }
314  }
315  }
316  CHECK(protomodel.SerializeToOstream(&os))
317  << "Failed to write Protocol Buffers file";
318  os.set_stream(nullptr);
319 }
320 
321 } // namespace frontend
322 } // namespace treelite
323 
324 #else // TREELITE_PROTOBUF_SUPPORT
325 
326 namespace treelite {
327 namespace frontend {
328 
329 DMLC_REGISTRY_FILE_TAG(protobuf);
330 
331 Model LoadProtobufModel(const char* filename) {
332  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
333  return Model(); // should not reach here
334 }
335 
336 void ExportProtobufModel(const char* filename, const Model& model) {
337  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
338 }
339 
340 } // namespace frontend
341 } // namespace treelite
342 
343 #endif // TREELITE_PROTOBUF_SUPPORT
void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: protobuf.cc:336
thin wrapper for tree ensemble model
Definition: tree.h:415
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
model structure for tree
Model LoadProtobufModel(const char *filename)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: protobuf.cc:331
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12
double tl_float
float type to be used internally
Definition: base.h:17
Operator
comparison operators
Definition: base.h:23