10 #include <unordered_map> 22 DMLC_REGISTRY_FILE_TAG(lightgbm);
25 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
26 return ParseStream(fi.get());
35 enum Masks : uint8_t {
43 std::vector<double> leaf_value;
44 std::vector<int8_t> decision_type;
45 std::vector<int> cat_boundaries;
46 std::vector<uint32_t> cat_threshold;
47 std::vector<int> split_feature;
48 std::vector<double> threshold;
49 std::vector<int> left_child;
50 std::vector<int> right_child;
53 inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
54 return (decision_type & mask) > 0;
57 inline std::vector<uint8_t> BitsetToList(
const uint32_t* bits,
59 std::vector<uint8_t> result;
60 CHECK(nslots == 1 || nslots == 2);
61 const uint8_t nbits = nslots * 32;
62 for (uint8_t i = 0; i < nbits; ++i) {
63 const uint8_t i1 = i / 32;
64 const uint8_t i2 = i % 32;
65 if ((bits[i1] >> i2) & 1) {
72 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
73 const size_t bufsize = 16 * 1024 * 1024;
74 std::vector<char> buf(bufsize);
76 std::vector<std::string> lines;
80 std::string leftover =
"";
81 while ( (byte_read = fi->Read(&buf[0],
sizeof(
char) * bufsize)) > 0) {
84 while (i < byte_read) {
85 if (buf[i] ==
'\n' || buf[i] ==
'\r') {
86 if (tok_begin == 0 && leftover.length() + i > 0) {
88 lines.push_back(leftover + std::string(&buf[0], i));
91 lines.emplace_back(&buf[tok_begin], i - tok_begin);
94 for (; (buf[i] ==
'\n' || buf[i] ==
'\r') && i < byte_read; ++i);
101 leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
104 if (!leftover.empty()) {
106 <<
"Warning: input file was not terminated with end-of-line character.";
107 lines.push_back(leftover);
114 std::vector<LGBTree> lgb_trees_;
115 int max_feature_idx_;
116 int num_tree_per_iteration_;
117 bool average_output_;
118 std::string obj_name_;
119 std::vector<std::string> obj_param_;
122 std::vector<std::string> lines = LoadText(fi);
123 std::unordered_map<std::string, std::string> global_dict;
124 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
126 bool in_tree =
false;
127 for (
const auto& line : lines) {
128 std::istringstream ss(line);
129 std::string key, value, rest;
130 std::getline(ss, key,
'=');
131 std::getline(ss, value,
'=');
132 std::getline(ss, rest);
133 CHECK(rest.empty()) <<
"Ill-formed LightGBM model file";
136 tree_dict.emplace_back();
139 tree_dict.back()[key] = value;
141 global_dict[key] = value;
147 auto it = global_dict.find(
"objective");
148 CHECK(it != global_dict.end())
149 <<
"Ill-formed LightGBM model file: need objective";
150 auto obj_strs = treelite::common::Split(it->second,
' ');
151 obj_name_ = obj_strs[0];
152 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
154 it = global_dict.find(
"max_feature_idx");
155 CHECK(it != global_dict.end())
156 <<
"Ill-formed LightGBM model file: need max_feature_idx";
157 max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
158 it = global_dict.find(
"num_tree_per_iteration");
159 CHECK(it != global_dict.end())
160 <<
"Ill-formed LightGBM model file: need num_tree_per_iteration";
161 num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
163 it = global_dict.find(
"average_output");
164 average_output_ = (it != global_dict.end());
167 for (
const auto& dict : tree_dict) {
168 lgb_trees_.emplace_back();
169 LGBTree& tree = lgb_trees_.back();
171 auto it = dict.find(
"num_leaves");
172 CHECK(it != dict.end())
173 <<
"Ill-formed LightGBM model file: need num_leaves";
174 tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
176 it = dict.find(
"num_cat");
177 CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
178 tree.num_cat = treelite::common::TextToNumber<int>(it->second);
180 it = dict.find(
"leaf_value");
181 CHECK(it != dict.end())
182 <<
"Ill-formed LightGBM model file: need leaf_value";
184 = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
186 it = dict.find(
"decision_type");
187 CHECK(it != dict.end())
188 <<
"Ill-formed LightGBM model file: need decision_type";
189 if (it == dict.end()) {
190 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
193 = treelite::common::TextToArray<int8_t>(it->second,
194 tree.num_leaves - 1);
197 if (tree.num_cat > 0) {
198 it = dict.find(
"cat_boundaries");
199 CHECK(it != dict.end())
200 <<
"Ill-formed LightGBM model file: need cat_boundaries";
202 = treelite::common::TextToArray<int>(it->second, tree.num_cat + 1);
203 it = dict.find(
"cat_threshold");
204 CHECK(it != dict.end())
205 <<
"Ill-formed LightGBM model file: need cat_threshold";
207 = treelite::common::TextToArray<uint32_t>(it->second,
208 tree.cat_boundaries.back());
211 it = dict.find(
"split_feature");
212 CHECK(it != dict.end())
213 <<
"Ill-formed LightGBM model file: need split_feature";
215 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
217 it = dict.find(
"threshold");
218 CHECK(it != dict.end())
219 <<
"Ill-formed LightGBM model file: need threshold";
221 = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
223 it = dict.find(
"left_child");
224 CHECK(it != dict.end())
225 <<
"Ill-formed LightGBM model file: need left_child";
227 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
229 it = dict.find(
"right_child");
230 CHECK(it != dict.end())
231 <<
"Ill-formed LightGBM model file: need right_child";
233 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
239 model.num_output_group = num_tree_per_iteration_;
240 if (model.num_output_group > 1) {
242 CHECK(!average_output_)
243 <<
"Ill-formed LightGBM model file: cannot use random forest mode " 244 <<
"for multi-class classification";
245 model.random_forest_flag =
false;
247 model.random_forest_flag = average_output_;
251 if (obj_name_ ==
"multiclass") {
255 for (
const auto& str : obj_param_) {
256 auto tokens = treelite::common::Split(str,
':');
257 if (tokens.size() == 2 && tokens[0] ==
"num_class" 258 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
263 CHECK(num_class >= 0 && num_class == model.num_output_group)
264 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
266 model.param.pred_transform =
"softmax";
267 }
else if (obj_name_ ==
"multiclassova") {
273 for (
const auto& str : obj_param_) {
274 auto tokens = treelite::common::Split(str,
':');
275 if (tokens.size() == 2) {
276 if (tokens[0] ==
"num_class" 277 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
279 }
else if (tokens[0] ==
"sigmoid" 280 && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
285 CHECK(num_class >= 0 && num_class == model.num_output_group
287 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
289 model.param.pred_transform =
"multiclass_ova";
290 model.param.sigmoid_alpha = alpha;
291 }
else if (obj_name_ ==
"binary") {
295 for (
const auto& str : obj_param_) {
296 auto tokens = treelite::common::Split(str,
':');
297 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 298 && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
303 CHECK_GT(alpha, 0.0f)
304 <<
"Ill-formed LightGBM model file: not a valid binary objective";
306 model.param.pred_transform =
"sigmoid";
307 model.param.sigmoid_alpha = alpha;
308 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
309 model.param.pred_transform =
"sigmoid";
310 model.param.sigmoid_alpha = 1.0f;
311 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
312 model.param.pred_transform =
"logarithm_one_plus_exp";
314 model.param.pred_transform =
"identity";
318 for (
const auto& lgb_tree : lgb_trees_) {
319 model.trees.emplace_back();
325 std::queue<std::pair<int, int>> Q;
329 std::tie(old_id, new_id) = Q.front(); Q.pop();
331 const double leaf_value = lgb_tree.leaf_value[~old_id];
332 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
334 const unsigned split_index =
335 static_cast<unsigned>(lgb_tree.split_feature[old_id]);
336 const bool default_left
337 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
339 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
341 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
342 CHECK_LE(lgb_tree.cat_boundaries[cat_idx + 1]
343 - lgb_tree.cat_boundaries[cat_idx], 2)
344 <<
"Categorical features must have 64 categories or fewer.";
345 const std::vector<uint8_t> left_categories
346 = BitsetToList(lgb_tree.cat_threshold.data()
347 + lgb_tree.cat_boundaries[cat_idx],
348 lgb_tree.cat_boundaries[cat_idx + 1]
349 - lgb_tree.cat_boundaries[cat_idx]);
350 tree[new_id].set_categorical_split(split_index, default_left,
357 tree[new_id].set_numerical_split(split_index, threshold,
358 default_left, cmp_op);
360 Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
361 Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
in-memory representation of a decision tree
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators