treelite
protobuf.cc
Go to the documentation of this file.
1 
8 #include <treelite/tree.h>
9 #include <queue>
10 
11 #ifdef PROTOBUF_SUPPORT
12 #include "tree.pb.h"
13 namespace {
14 
// Kinds of node that a protobuf-encoded tree can contain.
enum class NodeType : int8_t {
  kLeaf = 0,              // leaf node holding a single leaf value
  kLeafVector = 1,        // leaf node holding a non-empty leaf vector
  kNumericalSplit = 2,    // internal node with a numerical split
  kCategoricalSplit = 3   // internal node with a categorical split
};
18 
19 inline NodeType GetNodeType(const treelite_protobuf::Node& node) {
20  if (node.has_left_child()) { // node is non-leaf
21  CHECK(node.has_right_child());
22  CHECK(node.has_default_left());
23  CHECK(node.has_split_index());
24  CHECK(node.has_split_type());
25  CHECK(!node.has_leaf_value());
26  CHECK_EQ(node.leaf_vector_size(), 0);
27  const auto split_type = node.split_type();
28  if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
29  // numerical split
30  CHECK(node.has_op());
31  CHECK(node.has_threshold());
32  CHECK_EQ(node.left_categories_size(), 0);
33  return NodeType::kNumericalSplit;
34  } else { // categorical split
35  CHECK(!node.has_op());
36  CHECK(!node.has_threshold());
37  return NodeType::kCategoricalSplit;
38  }
39  } else { // node is leaf
40  CHECK(!node.has_right_child());
41  CHECK(!node.has_default_left());
42  CHECK(!node.has_split_index());
43  CHECK(!node.has_split_type());
44  CHECK(!node.has_op());
45  CHECK(!node.has_threshold());
46  CHECK_EQ(node.left_categories_size(), 0);
47  if (node.has_leaf_value()) {
48  CHECK_EQ(node.leaf_vector_size(), 0);
49  return NodeType::kLeaf;
50  } else {
51  CHECK_GT(node.leaf_vector_size(), 0);
52  return NodeType::kLeafVector;
53  }
54  }
55 }
56 
57 } // namespace anonymous
58 
59 namespace treelite {
60 namespace frontend {
61 
62 DMLC_REGISTRY_FILE_TAG(protobuf);
63 
64 Model LoadProtobufModel(const char* filename) {
65  GOOGLE_PROTOBUF_VERIFY_VERSION;
66 
67  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
68  dmlc::istream is(fi.get());
69  treelite_protobuf::Model protomodel;
70  CHECK(protomodel.ParseFromIstream(&is)) << "Ill-formed Protocol Buffers file";
71 
72  Model model;
73  CHECK(protomodel.has_num_feature()) << "num_feature must exist";
74  const auto num_feature = protomodel.num_feature();
75  CHECK_LT(num_feature, std::numeric_limits<int>::max())
76  << "num_feature too big";
77  CHECK_GT(num_feature, 0) << "num_feature must be positive";
78  model.num_feature = static_cast<int>(protomodel.num_feature());
79 
80  CHECK(protomodel.has_num_output_group()) << "num_output_group must exist";
81  const auto num_output_group = protomodel.num_output_group();
82  CHECK_LT(num_output_group, std::numeric_limits<int>::max())
83  << "num_output_group too big";
84  CHECK_GT(num_output_group, 0) << "num_output_group must be positive";
85  model.num_output_group = static_cast<int>(protomodel.num_output_group());
86 
87  CHECK(protomodel.has_random_forest_flag())
88  << "random_forest_flag must exist";
89  model.random_forest_flag = protomodel.random_forest_flag();
90 
91  // extra parameters field
92  const auto& ep = protomodel.extra_params();
93  std::vector<std::pair<std::string, std::string>> cfg;
94  std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
95  InitParamAndCheck(&model.param, cfg);
96 
97  // flag to check consistent use of leaf vector
98  // 0: no leaf should use leaf vector
99  // 1: every leaf should use leaf vector
100  // -1: indeterminate
101  int8_t flag_leaf_vector = -1;
102 
103  const int ntree = protomodel.trees_size();
104  for (int i = 0; i < ntree; ++i) {
105  model.trees.emplace_back();
106  Tree& tree = model.trees.back();
107  tree.Init();
108 
109  CHECK(protomodel.trees(i).has_head());
110  // assign node ID's so that a breadth-wise traversal would yield
111  // the monotonic sequence 0, 1, 2, ...
112  std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
113  // (proto node, ID)
114  Q.push({protomodel.trees(i).head(), 0});
115  while (!Q.empty()) {
116  auto elem = Q.front(); Q.pop();
117  const treelite_protobuf::Node& node = elem.first;
118  int id = elem.second;
119  const NodeType node_type = GetNodeType(node);
120  if (node_type == NodeType::kLeaf) {
121  CHECK(flag_leaf_vector != 1)
122  << "Inconsistent use of leaf vector: if one leaf node does not use"
123  << "a leaf vector, *no other* leaf node can use a leaf vector";
124  flag_leaf_vector = 0; // now no leaf can use leaf vector
125 
126  tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
127  } else if (node_type == NodeType::kLeafVector) {
128  CHECK(flag_leaf_vector != 0)
129  << "Inconsistent use of leaf vector: if one leaf node uses "
130  << "a leaf vector, *every* leaf node must use a leaf vector as well";
131  flag_leaf_vector = 1; // now every leaf must use leaf vector
132 
133  const int len = node.leaf_vector_size();
134  CHECK_EQ(len, model.num_output_group)
135  << "The length of leaf vector must be identical to the "
136  << "number of output groups";
137  std::vector<tl_float> leaf_vector(len);
138  for (int i = 0; i < len; ++i) {
139  leaf_vector[i] = static_cast<tl_float>(node.leaf_vector(i));
140  }
141  tree[id].set_leaf_vector(leaf_vector);
142  } else if (node_type == NodeType::kNumericalSplit) {
143  const auto split_index = node.split_index();
144  const std::string opname = node.op();
145  CHECK_LT(split_index, model.num_feature)
146  << "split_index must be between 0 and [num_feature] - 1.";
147  CHECK_GE(split_index, 0) << "split_index must be positive.";
148  CHECK_GT(optable.count(opname), 0) << "No operator `"
149  << opname << "\" exists";
150  tree.AddChilds(id);
151  tree[id].set_numerical_split(static_cast<unsigned>(split_index),
152  static_cast<tl_float>(node.threshold()),
153  node.default_left(),
154  optable.at(opname.c_str()));
155  Q.push({node.left_child(), tree[id].cleft()});
156  Q.push({node.right_child(), tree[id].cright()});
157  } else { // categorical split
158  const auto split_index = node.split_index();
159  CHECK_LT(split_index, model.num_feature)
160  << "split_index must be between 0 and [num_feature] - 1.";
161  CHECK_GE(split_index, 0) << "split_index must be positive.";
162  const int left_categories_size = node.left_categories_size();
163  std::vector<uint8_t> left_categories;
164  for (int i = 0; i < left_categories_size; ++i) {
165  const auto cat = node.left_categories(i);
166  CHECK(cat <= std::numeric_limits<uint8_t>::max());
167  left_categories.push_back(static_cast<uint8_t>(cat));
168  }
169  tree.AddChilds(id);
170  tree[id].set_categorical_split(static_cast<unsigned>(split_index),
171  node.default_left(),
172  left_categories);
173  Q.push({node.left_child(), tree[id].cleft()});
174  Q.push({node.right_child(), tree[id].cright()});
175  }
176  }
177  }
178  if (flag_leaf_vector == 0) {
179  if (model.num_output_group > 1) {
180  // multiclass classification with gradient boosted trees
181  CHECK(!model.random_forest_flag)
182  << "To use a random forest for multi-class classification, each leaf "
183  << "node must output a leaf vector specifying a probability "
184  << "distribution";
185  CHECK_EQ(ntree % model.num_output_group, 0)
186  << "For multi-class classifiers with gradient boosted trees, the number "
187  << "of trees must be evenly divisible by the number of output groups";
188  }
189  } else if (flag_leaf_vector == 1) {
190  // multiclass classification with a random forest
191  CHECK(model.random_forest_flag)
192  << "In multi-class classifiers with gradient boosted trees, each leaf "
193  << "node must output a single floating-point value.";
194  } else {
195  LOG(FATAL) << "Impossible thing happened: model has no leaf node!";
196  }
197  return model;
198 }
199 
200 } // namespace frontend
201 } // namespace treelite
202 
203 #else
204 
205 namespace treelite {
206 namespace frontend {
207 
208 DMLC_REGISTRY_FILE_TAG(protobuf);
209 
210 Model LoadProtobufModel(const char* filename) {
211  LOG(FATAL) << "Protobuf library not linked";
212  return Model();
213 }
214 
215 } // namespace frontend
216 } // namespace treelite
217 
218 #endif // PROTOBUF_SUPPORT
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
model structure for tree
Model LoadProtobufModel(const char *filename)
load a model in Protocol Buffers format. Protocol Buffers (google/protobuf) is a language- and platfo...
Definition: protobuf.cc:210
const std::unordered_map< std::string, Operator > optable
conversion table from string to operator, defined in optable.cc
Definition: optable.cc:12