treelite
protobuf.cc
Go to the documentation of this file.
1 
8 #include <dmlc/logging.h>
9 #include <treelite/tree.h>
10 #include <queue>
11 
12 #ifdef TREELITE_PROTOBUF_SUPPORT
13 
14 #include "tree.pb.h"
15 
16 namespace {
17 
/*! \brief category assigned to a protobuf node by GetNodeType() */
enum class NodeType : int8_t {
  kLeaf,             // leaf node carrying a scalar leaf_value
  kLeafVector,       // leaf node carrying a non-empty leaf_vector
  kNumericalSplit,   // internal node whose split_type is NUMERICAL
  kCategoricalSplit  // internal node with any other split_type
};
21 
22 inline NodeType GetNodeType(const treelite_protobuf::Node& node) {
23  if (node.has_left_child()) { // node is non-leaf
24  CHECK(node.has_right_child());
25  CHECK(node.has_default_left());
26  CHECK(node.has_split_index());
27  CHECK(node.has_split_type());
28  CHECK(!node.has_leaf_value());
29  CHECK_EQ(node.leaf_vector_size(), 0);
30  const auto split_type = node.split_type();
31  if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
32  // numerical split
33  CHECK(node.has_op());
34  CHECK(node.has_threshold());
35  CHECK_EQ(node.left_categories_size(), 0);
36  return NodeType::kNumericalSplit;
37  } else { // categorical split
38  CHECK(!node.has_op());
39  CHECK(!node.has_threshold());
40  CHECK(node.has_missing_category_to_zero());
41  return NodeType::kCategoricalSplit;
42  }
43  } else { // node is leaf
44  CHECK(!node.has_right_child());
45  CHECK(!node.has_default_left());
46  CHECK(!node.has_split_index());
47  CHECK(!node.has_split_type());
48  CHECK(!node.has_op());
49  CHECK(!node.has_threshold());
50  CHECK(!node.has_gain());
51  CHECK(!node.has_missing_category_to_zero());
52  CHECK_EQ(node.left_categories_size(), 0);
53  if (node.has_leaf_value()) {
54  CHECK_EQ(node.leaf_vector_size(), 0);
55  return NodeType::kLeaf;
56  } else {
57  CHECK_GT(node.leaf_vector_size(), 0);
58  return NodeType::kLeafVector;
59  }
60  }
61 }
62 
63 } // anonymous namespace
64 
65 namespace treelite {
66 namespace frontend {
67 
68 DMLC_REGISTRY_FILE_TAG(protobuf);
69 
70 Model LoadProtobufModel(const char* filename) {
71  GOOGLE_PROTOBUF_VERIFY_VERSION;
72 
73  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
74  dmlc::istream is(fi.get());
75  treelite_protobuf::Model protomodel;
76  CHECK(protomodel.ParseFromIstream(&is)) << "Ill-formed Protocol Buffers file";
77 
78  Model model;
79  CHECK(protomodel.has_num_feature()) << "num_feature must exist";
80  const auto num_feature = protomodel.num_feature();
81  CHECK_LT(num_feature, std::numeric_limits<int>::max())
82  << "num_feature too big";
83  CHECK_GT(num_feature, 0) << "num_feature must be positive";
84  model.num_feature = static_cast<int>(protomodel.num_feature());
85 
86  CHECK(protomodel.has_num_output_group()) << "num_output_group must exist";
87  const auto num_output_group = protomodel.num_output_group();
88  CHECK_LT(num_output_group, std::numeric_limits<int>::max())
89  << "num_output_group too big";
90  CHECK_GT(num_output_group, 0) << "num_output_group must be positive";
91  model.num_output_group = static_cast<int>(protomodel.num_output_group());
92 
93  CHECK(protomodel.has_random_forest_flag())
94  << "random_forest_flag must exist";
95  model.random_forest_flag = protomodel.random_forest_flag();
96 
97  // extra parameters field
98  const auto& ep = protomodel.extra_params();
99  std::vector<std::pair<std::string, std::string>> cfg;
100  std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
101  InitParamAndCheck(&model.param, cfg);
102 
103  // flag to check consistent use of leaf vector
104  // 0: no leaf should use leaf vector
105  // 1: every leaf should use leaf vector
106  // -1: indeterminate
107  int8_t flag_leaf_vector = -1;
108 
109  const int ntree = protomodel.trees_size();
110  for (int i = 0; i < ntree; ++i) {
111  model.trees.emplace_back();
112  Tree& tree = model.trees.back();
113  tree.Init();
114 
115  CHECK(protomodel.trees(i).has_head());
116  // assign node ID's so that a breadth-wise traversal would yield
117  // the monotonic sequence 0, 1, 2, ...
118  std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
119  // (proto node, ID)
120  Q.push({protomodel.trees(i).head(), 0});
121  while (!Q.empty()) {
122  auto elem = Q.front(); Q.pop();
123  const treelite_protobuf::Node& node = elem.first;
124  int id = elem.second;
125  const NodeType node_type = GetNodeType(node);
126  if (node_type == NodeType::kLeaf) { // leaf node with a scalar output
127  CHECK(flag_leaf_vector != 1)
128  << "Inconsistent use of leaf vector: if one leaf node does not use"
129  << "a leaf vector, *no other* leaf node can use a leaf vector";
130  flag_leaf_vector = 0; // now no leaf can use leaf vector
131 
132  tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
133  } else if (node_type == NodeType::kLeafVector) {
134  // leaf node with vector output
135  CHECK(flag_leaf_vector != 0)
136  << "Inconsistent use of leaf vector: if one leaf node uses "
137  << "a leaf vector, *every* leaf node must use a leaf vector as well";
138  flag_leaf_vector = 1; // now every leaf must use leaf vector
139 
140  const int len = node.leaf_vector_size();
141  CHECK_EQ(len, model.num_output_group)
142  << "The length of leaf vector must be identical to the "
143  << "number of output groups";
144  std::vector<tl_float> leaf_vector(len);
145  for (int i = 0; i < len; ++i) {
146  leaf_vector[i] = static_cast<tl_float>(node.leaf_vector(i));
147  }
148  tree[id].set_leaf_vector(leaf_vector);
149  } else if (node_type == NodeType::kNumericalSplit) { // numerical split
150  const auto split_index = node.split_index();
151  const std::string opname = node.op();
152  CHECK_LT(split_index, model.num_feature)
153  << "split_index must be between 0 and [num_feature] - 1.";
154  CHECK_GE(split_index, 0) << "split_index must be positive.";
155  CHECK_GT(optable.count(opname), 0) << "No operator `"
156  << opname << "\" exists";
157  tree.AddChilds(id);
158  tree[id].set_numerical_split(static_cast<unsigned>(split_index),
159  static_cast<tl_float>(node.threshold()),
160  node.default_left(),
161  optable.at(opname.c_str()));
162  Q.push({node.left_child(), tree[id].cleft()});
163  Q.push({node.right_child(), tree[id].cright()});
164  } else { // categorical split
165  const auto split_index = node.split_index();
166  CHECK_LT(split_index, model.num_feature)
167  << "split_index must be between 0 and [num_feature] - 1.";
168  CHECK_GE(split_index, 0) << "split_index must be positive.";
169  const int left_categories_size = node.left_categories_size();
170  std::vector<uint32_t> left_categories;
171  for (int i = 0; i < left_categories_size; ++i) {
172  const auto cat = node.left_categories(i);
173  CHECK(cat <= std::numeric_limits<uint32_t>::max());
174  left_categories.push_back(static_cast<uint32_t>(cat));
175  }
176  tree.AddChilds(id);
177  tree[id].set_categorical_split(static_cast<unsigned>(split_index),
178  node.default_left(),
179  node.missing_category_to_zero(),
180  left_categories);
181  Q.push({node.left_child(), tree[id].cleft()});
182  Q.push({node.right_child(), tree[id].cright()});
183  }
184  /* set node statistics */
185  if (node.has_data_count()) {
186  tree[id].set_data_count(static_cast<size_t>(node.data_count()));
187  }
188  if (node.has_sum_hess()) {
189  tree[id].set_sum_hess(node.sum_hess());
190  }
191  if (node.has_gain()) {
192  tree[id].set_gain(node.gain());
193  }
194  }
195  }
196  if (flag_leaf_vector == 0) {
197  if (model.num_output_group > 1) {
198  // multiclass classification with gradient boosted trees
199  CHECK(!model.random_forest_flag)
200  << "To use a random forest for multi-class classification, each leaf "
201  << "node must output a leaf vector specifying a probability "
202  << "distribution";
203  CHECK_EQ(ntree % model.num_output_group, 0)
204  << "For multi-class classifiers with gradient boosted trees, the number "
205  << "of trees must be evenly divisible by the number of output groups";
206  }
207  } else if (flag_leaf_vector == 1) {
208  // multiclass classification with a random forest
209  CHECK(model.random_forest_flag)
210  << "In multi-class classifiers with gradient boosted trees, each leaf "
211  << "node must output a single floating-point value.";
212  } else {
213  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
214  }
215  return model;
216 }
217 
218 void ExportProtobufModel(const char* filename, const Model& model) {
219  GOOGLE_PROTOBUF_VERIFY_VERSION;
220 
221  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "w"));
222  dmlc::ostream os(fi.get());
223  treelite_protobuf::Model protomodel;
224 
225  protomodel.set_num_feature(
226  static_cast<google::protobuf::int32>(model.num_feature));
227 
228  protomodel.set_num_output_group(
229  static_cast<google::protobuf::int32>(model.num_output_group));
230 
231  protomodel.set_random_forest_flag(model.random_forest_flag);
232 
233  // extra parameters field
234  for (const auto& kv : model.param.__DICT__()) {
235  (*protomodel.mutable_extra_params())[kv.first] = kv.second;
236  }
237 
238  // flag to check consistent use of leaf vector
239  // 0: no leaf should use leaf vector
240  // 1: every leaf should use leaf vector
241  // -1: indeterminate
242  int8_t flag_leaf_vector = -1;
243 
244  const int ntree = model.trees.size();
245  for (int i = 0; i < ntree; ++i) {
246  const Tree& tree = model.trees[i];
247  treelite_protobuf::Tree* proto_tree = protomodel.add_trees();
248 
249  std::queue<std::pair<int, treelite_protobuf::Node*>> Q;
250  Q.push({0, proto_tree->mutable_head()});
251  while (!Q.empty()) {
252  auto elem = Q.front(); Q.pop();
253  const int nid = elem.first;
254  treelite_protobuf::Node* proto_node = elem.second;
255  if (tree[nid].is_leaf()) { // leaf node
256  if (tree[nid].has_leaf_vector()) { // leaf node with vector output
257  CHECK(flag_leaf_vector != 0)
258  << "Inconsistent use of leaf vector: if one leaf node uses "
259  << "a leaf vector, *every* leaf node must use a leaf vector as well";
260  flag_leaf_vector = 1; // now every leaf must use leaf vector
261 
262  const auto& leaf_vector = tree[nid].leaf_vector();
263  CHECK_EQ(leaf_vector.size(), model.num_output_group)
264  << "The length of leaf vector must be identical to the "
265  << "number of output groups";
266  for (tl_float e : leaf_vector) {
267  proto_node->add_leaf_vector(static_cast<double>(e));
268  }
269  CHECK_EQ(proto_node->leaf_vector_size(), leaf_vector.size());
270  } else { // leaf node with scalar output
271  CHECK(flag_leaf_vector != 1)
272  << "Inconsistent use of leaf vector: if one leaf node does not use"
273  << "a leaf vector, *no other* leaf node can use a leaf vector";
274  flag_leaf_vector = 0; // now no leaf can use leaf vector
275 
276  proto_node->set_leaf_value(static_cast<double>(tree[nid].leaf_value()));
277  }
278  } else if (tree[nid].split_type() == SplitFeatureType::kNumerical) {
279  // numerical split
280  const unsigned split_index = tree[nid].split_index();
281  const tl_float threshold = tree[nid].threshold();
282  const bool default_left = tree[nid].default_left();
283  const Operator op = tree[nid].comparison_op();
284 
285  proto_node->set_default_left(default_left);
286  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
287  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_NUMERICAL);
288  proto_node->set_op(OpName(op));
289  proto_node->set_threshold(static_cast<double>(threshold));
290  Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
291  Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
292  } else { // categorical split
293  const unsigned split_index = tree[nid].split_index();
294  const auto& left_categories = tree[nid].left_categories();
295  const bool default_left = tree[nid].default_left();
296  const bool missing_category_to_zero = tree[nid].missing_category_to_zero();
297 
298  proto_node->set_default_left(default_left);
299  proto_node->set_split_index(static_cast<google::protobuf::int32>(split_index));
300  proto_node->set_split_type(treelite_protobuf::Node_SplitFeatureType_CATEGORICAL);
301  proto_node->set_missing_category_to_zero(missing_category_to_zero);
302  for (auto e : left_categories) {
303  proto_node->add_left_categories(static_cast<google::protobuf::uint32>(e));
304  }
305  Q.push({tree[nid].cleft(), proto_node->mutable_left_child()});
306  Q.push({tree[nid].cright(), proto_node->mutable_right_child()});
307  }
308  /* set node statistics */
309  if (tree[nid].has_data_count()) {
310  proto_node->set_data_count(
311  static_cast<google::protobuf::uint64>(tree[nid].data_count()));
312  }
313  if (tree[nid].has_sum_hess()) {
314  proto_node->set_sum_hess(tree[nid].sum_hess());
315  }
316  if (tree[nid].has_gain()) {
317  proto_node->set_gain(tree[nid].gain());
318  }
319  }
320  }
321  CHECK(protomodel.SerializeToOstream(&os))
322  << "Failed to write Protocol Buffers file";
323  os.set_stream(nullptr);
324 }
325 
326 } // namespace frontend
327 } // namespace treelite
328 
329 #else // TREELITE_PROTOBUF_SUPPORT
330 
331 namespace treelite {
332 namespace frontend {
333 
334 DMLC_REGISTRY_FILE_TAG(protobuf);
335 
336 Model LoadProtobufModel(const char* filename) {
337  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
338  return Model(); // should not reach here
339 }
340 
341 void ExportProtobufModel(const char* filename, const Model& model) {
342  LOG(FATAL) << "Treelite was not compiled with Protobuf!";
343 }
344 
345 } // namespace frontend
346 } // namespace treelite
347 
348 #endif // TREELITE_PROTOBUF_SUPPORT
void ExportProtobufModel(const char *filename, const Model &model)
export a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and plat...
Definition: protobuf.cc:341
thin wrapper for tree ensemble model
Definition: tree.h:427
std::string OpName(Operator op)
get string representation of comparison operator
Definition: base.h:38
model structure for tree
Model LoadProtobufModel(const char *filename)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: protobuf.cc:336
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12
double tl_float
float type to be used internally
Definition: base.h:17
Operator
comparison operators
Definition: base.h:23