11 #ifdef PROTOBUF_SUPPORT 15 enum class NodeType : int8_t {
16 kLeaf, kLeafVector, kNumericalSplit, kCategoricalSplit
19 inline NodeType GetNodeType(
const treelite_protobuf::Node& node) {
20 if (node.has_left_child()) {
21 CHECK(node.has_right_child());
22 CHECK(node.has_default_left());
23 CHECK(node.has_split_index());
24 CHECK(node.has_split_type());
25 CHECK(!node.has_leaf_value());
26 CHECK_EQ(node.leaf_vector_size(), 0);
27 const auto split_type = node.split_type();
28 if (split_type == treelite_protobuf::Node_SplitFeatureType_NUMERICAL) {
31 CHECK(node.has_threshold());
32 CHECK_EQ(node.left_categories_size(), 0);
33 return NodeType::kNumericalSplit;
35 CHECK(!node.has_op());
36 CHECK(!node.has_threshold());
37 return NodeType::kCategoricalSplit;
40 CHECK(!node.has_right_child());
41 CHECK(!node.has_default_left());
42 CHECK(!node.has_split_index());
43 CHECK(!node.has_split_type());
44 CHECK(!node.has_op());
45 CHECK(!node.has_threshold());
46 CHECK_EQ(node.left_categories_size(), 0);
47 if (node.has_leaf_value()) {
48 CHECK_EQ(node.leaf_vector_size(), 0);
49 return NodeType::kLeaf;
51 CHECK_GT(node.leaf_vector_size(), 0);
52 return NodeType::kLeafVector;
62 DMLC_REGISTRY_FILE_TAG(protobuf);
65 GOOGLE_PROTOBUF_VERIFY_VERSION;
67 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
68 dmlc::istream is(fi.get());
69 treelite_protobuf::Model protomodel;
70 CHECK(protomodel.ParseFromIstream(&is)) <<
"Ill-formed Protocol Buffers file";
73 CHECK(protomodel.has_num_feature()) <<
"num_feature must exist";
74 const auto num_feature = protomodel.num_feature();
75 CHECK_LT(num_feature, std::numeric_limits<int>::max())
76 <<
"num_feature too big";
77 CHECK_GT(num_feature, 0) <<
"num_feature must be positive";
78 model.num_feature =
static_cast<int>(protomodel.num_feature());
80 CHECK(protomodel.has_num_output_group()) <<
"num_output_group must exist";
81 const auto num_output_group = protomodel.num_output_group();
82 CHECK_LT(num_output_group, std::numeric_limits<int>::max())
83 <<
"num_output_group too big";
84 CHECK_GT(num_output_group, 0) <<
"num_output_group must be positive";
85 model.num_output_group =
static_cast<int>(protomodel.num_output_group());
87 CHECK(protomodel.has_random_forest_flag())
88 <<
"random_forest_flag must exist";
89 model.random_forest_flag = protomodel.random_forest_flag();
92 const auto& ep = protomodel.extra_params();
93 std::vector<std::pair<std::string, std::string>> cfg;
94 std::copy(ep.begin(), ep.end(), std::back_inserter(cfg));
95 InitParamAndCheck(&model.param, cfg);
101 int8_t flag_leaf_vector = -1;
103 const int ntree = protomodel.trees_size();
104 for (
int i = 0; i < ntree; ++i) {
105 model.trees.emplace_back();
106 Tree& tree = model.trees.back();
109 CHECK(protomodel.trees(i).has_head());
112 std::queue<std::pair<const treelite_protobuf::Node&, int>> Q;
114 Q.push({protomodel.trees(i).head(), 0});
116 auto elem = Q.front(); Q.pop();
117 const treelite_protobuf::Node& node = elem.first;
118 int id = elem.second;
119 const NodeType node_type = GetNodeType(node);
120 if (node_type == NodeType::kLeaf) {
121 CHECK(flag_leaf_vector != 1)
122 <<
"Inconsistent use of leaf vector: if one leaf node does not use" 123 <<
"a leaf vector, *no other* leaf node can use a leaf vector";
124 flag_leaf_vector = 0;
126 tree[id].set_leaf(static_cast<tl_float>(node.leaf_value()));
127 }
else if (node_type == NodeType::kLeafVector) {
128 CHECK(flag_leaf_vector != 0)
129 <<
"Inconsistent use of leaf vector: if one leaf node uses " 130 <<
"a leaf vector, *every* leaf node must use a leaf vector as well";
131 flag_leaf_vector = 1;
133 const int len = node.leaf_vector_size();
134 CHECK_EQ(len, model.num_output_group)
135 <<
"The length of leaf vector must be identical to the " 136 <<
"number of output groups";
137 std::vector<tl_float> leaf_vector(len);
138 for (
int i = 0; i < len; ++i) {
139 leaf_vector[i] =
static_cast<tl_float>(node.leaf_vector(i));
141 tree[id].set_leaf_vector(leaf_vector);
142 }
else if (node_type == NodeType::kNumericalSplit) {
143 const auto split_index = node.split_index();
144 const std::string opname = node.op();
145 CHECK_LT(split_index, model.num_feature)
146 <<
"split_index must be between 0 and [num_feature] - 1.";
147 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
148 CHECK_GT(
optable.count(opname), 0) <<
"No operator `" 149 << opname <<
"\" exists";
151 tree[id].set_numerical_split(static_cast<unsigned>(split_index),
152 static_cast<tl_float>(node.threshold()),
155 Q.push({node.left_child(), tree[id].cleft()});
156 Q.push({node.right_child(), tree[id].cright()});
158 const auto split_index = node.split_index();
159 CHECK_LT(split_index, model.num_feature)
160 <<
"split_index must be between 0 and [num_feature] - 1.";
161 CHECK_GE(split_index, 0) <<
"split_index must be positive.";
162 const int left_categories_size = node.left_categories_size();
163 std::vector<uint8_t> left_categories;
164 for (
int i = 0; i < left_categories_size; ++i) {
165 const auto cat = node.left_categories(i);
166 CHECK(cat <= std::numeric_limits<uint8_t>::max());
167 left_categories.push_back(static_cast<uint8_t>(cat));
170 tree[id].set_categorical_split(static_cast<unsigned>(split_index),
173 Q.push({node.left_child(), tree[id].cleft()});
174 Q.push({node.right_child(), tree[id].cright()});
178 if (flag_leaf_vector == 0) {
179 if (model.num_output_group > 1) {
181 CHECK(!model.random_forest_flag)
182 <<
"To use a random forest for multi-class classification, each leaf " 183 <<
"node must output a leaf vector specifying a probability " 185 CHECK_EQ(ntree % model.num_output_group, 0)
186 <<
"For multi-class classifiers with gradient boosted trees, the number " 187 <<
"of trees must be evenly divisible by the number of output groups";
189 }
else if (flag_leaf_vector == 1) {
191 CHECK(model.random_forest_flag)
192 <<
"In multi-class classifiers with gradient boosted trees, each leaf " 193 <<
"node must output a single floating-point value.";
195 LOG(FATAL) <<
"Impossible thing happened: model has no leaf node!";
208 DMLC_REGISTRY_FILE_TAG(protobuf);
211 LOG(FATAL) <<
"Protobuf library not linked";
#endif  // PROTOBUF_SUPPORT

// Reference notes for symbols used in this file (from the generated docs):
//   (symbol name missing from the extract; presumably Model)
//       thin wrapper for tree ensemble model
//   float tl_float
//       float type to be used internally
//   Model LoadProtobufModel(const char *filename)
//       load a model in Protocol Buffers format. Protocol Buffers
//       (google/protobuf) is a language- and platfo... [truncated in extract]
//   const std::unordered_map<std::string, Operator> optable
//       conversion table from string to operator, defined in optable.cc