8 #include <unordered_map> 10 #include <dmlc/data.h> 23 DMLC_REGISTRY_FILE_TAG(lightgbm);
26 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
27 return ParseStream(fi.get());
36 enum Masks : uint8_t {
41 enum class MissingType : uint8_t {
50 std::vector<double> leaf_value;
51 std::vector<int8_t> decision_type;
52 std::vector<uint64_t> cat_boundaries;
53 std::vector<uint32_t> cat_threshold;
54 std::vector<int> split_feature;
55 std::vector<double> threshold;
56 std::vector<int> left_child;
57 std::vector<int> right_child;
58 std::vector<float> split_gain;
59 std::vector<int> internal_count;
60 std::vector<int> leaf_count;
63 inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
64 return (decision_type & mask) > 0;
67 inline MissingType GetMissingType(int8_t decision_type) {
68 return static_cast<MissingType
>((decision_type >> 2) & 3);
71 inline std::vector<uint32_t> BitsetToList(
const uint32_t* bits,
73 std::vector<uint32_t> result;
74 const size_t nbits = nslots * 32;
75 for (
size_t i = 0; i < nbits; ++i) {
76 const size_t i1 = i / 32;
77 const uint32_t i2 =
static_cast<uint32_t
>(i % 32);
78 if ((bits[i1] >> i2) & 1) {
79 result.push_back(static_cast<uint32_t>(i));
85 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
86 const size_t bufsize = 16 * 1024 * 1024;
87 std::vector<char> buf(bufsize);
89 std::vector<std::string> lines;
93 std::string leftover =
"";
94 while ((byte_read = fi->Read(&buf[0],
sizeof(
char) * bufsize)) > 0) {
97 while (i < byte_read) {
98 if (buf[i] ==
'\n' || buf[i] ==
'\r') {
99 if (tok_begin == 0 && leftover.length() + i > 0) {
101 lines.push_back(leftover + std::string(&buf[0], i));
104 lines.emplace_back(&buf[tok_begin], i - tok_begin);
107 for (; (buf[i] ==
'\n' || buf[i] ==
'\r') && i < byte_read; ++i) {}
114 leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
117 if (!leftover.empty()) {
119 <<
"Warning: input file was not terminated with end-of-line character.";
120 lines.push_back(leftover);
127 std::vector<LGBTree> lgb_trees_;
128 int max_feature_idx_;
129 int num_tree_per_iteration_;
130 bool average_output_;
131 std::string obj_name_;
132 std::vector<std::string> obj_param_;
135 std::vector<std::string> lines = LoadText(fi);
136 std::unordered_map<std::string, std::string> global_dict;
137 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
139 bool in_tree =
false;
140 for (
const auto& line : lines) {
141 std::istringstream ss(line);
142 std::string key, value, rest;
143 std::getline(ss, key,
'=');
144 std::getline(ss, value,
'=');
145 std::getline(ss, rest);
146 CHECK(rest.empty()) <<
"Ill-formed LightGBM model file";
149 tree_dict.emplace_back();
152 tree_dict.back()[key] = value;
154 global_dict[key] = value;
160 auto it = global_dict.find(
"objective");
161 CHECK(it != global_dict.end())
162 <<
"Ill-formed LightGBM model file: need objective";
163 auto obj_strs = treelite::common::Split(it->second,
' ');
164 obj_name_ = obj_strs[0];
165 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
167 it = global_dict.find(
"max_feature_idx");
168 CHECK(it != global_dict.end())
169 <<
"Ill-formed LightGBM model file: need max_feature_idx";
170 max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
171 it = global_dict.find(
"num_tree_per_iteration");
172 CHECK(it != global_dict.end())
173 <<
"Ill-formed LightGBM model file: need num_tree_per_iteration";
174 num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
176 it = global_dict.find(
"average_output");
177 average_output_ = (it != global_dict.end());
180 for (
const auto& dict : tree_dict) {
181 lgb_trees_.emplace_back();
182 LGBTree& tree = lgb_trees_.back();
184 auto it = dict.find(
"num_leaves");
185 CHECK(it != dict.end())
186 <<
"Ill-formed LightGBM model file: need num_leaves";
187 tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
189 it = dict.find(
"num_cat");
190 CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
191 tree.num_cat = treelite::common::TextToNumber<int>(it->second);
193 it = dict.find(
"leaf_value");
194 CHECK(it != dict.end() && !it->second.empty())
195 <<
"Ill-formed LightGBM model file: need leaf_value";
197 = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
199 it = dict.find(
"decision_type");
200 if (it == dict.end()) {
201 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
203 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
204 <<
"Ill-formed LightGBM model file: decision_type cannot be empty string";
206 = treelite::common::TextToArray<int8_t>(it->second,
207 tree.num_leaves - 1);
210 if (tree.num_cat > 0) {
211 it = dict.find(
"cat_boundaries");
212 CHECK(it != dict.end() && !it->second.empty())
213 <<
"Ill-formed LightGBM model file: need cat_boundaries";
215 = treelite::common::TextToArray<uint64_t>(it->second, tree.num_cat + 1);
216 it = dict.find(
"cat_threshold");
217 CHECK(it != dict.end() && !it->second.empty())
218 <<
"Ill-formed LightGBM model file: need cat_threshold";
220 = treelite::common::TextToArray<uint32_t>(it->second,
221 tree.cat_boundaries.back());
224 it = dict.find(
"split_feature");
225 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
226 <<
"Ill-formed LightGBM model file: need split_feature";
228 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
230 it = dict.find(
"threshold");
231 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
232 <<
"Ill-formed LightGBM model file: need threshold";
234 = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
236 it = dict.find(
"split_gain");
237 if (it != dict.end()) {
238 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
239 <<
"Ill-formed LightGBM model file: split_gain cannot be empty string";
241 = treelite::common::TextToArray<float>(it->second, tree.num_leaves - 1);
243 tree.split_gain.resize(tree.num_leaves - 1);
246 it = dict.find(
"internal_count");
247 if (it != dict.end()) {
248 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
249 <<
"Ill-formed LightGBM model file: internal_count cannot be empty string";
251 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
253 tree.internal_count.resize(tree.num_leaves - 1);
256 it = dict.find(
"leaf_count");
257 if (it != dict.end()) {
258 CHECK(!it->second.empty())
259 <<
"Ill-formed LightGBM model file: leaf_count cannot be empty string";
261 = treelite::common::TextToArray<int>(it->second, tree.num_leaves);
263 tree.leaf_count.resize(tree.num_leaves);
266 it = dict.find(
"left_child");
267 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
268 <<
"Ill-formed LightGBM model file: need left_child";
270 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
272 it = dict.find(
"right_child");
273 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
274 <<
"Ill-formed LightGBM model file: need right_child";
276 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
282 model.num_output_group = num_tree_per_iteration_;
283 if (model.num_output_group > 1) {
285 CHECK(!average_output_)
286 <<
"Ill-formed LightGBM model file: cannot use random forest mode " 287 <<
"for multi-class classification";
288 model.random_forest_flag =
false;
290 model.random_forest_flag = average_output_;
294 if (obj_name_ ==
"multiclass") {
298 for (
const auto& str : obj_param_) {
299 auto tokens = treelite::common::Split(str,
':');
300 if (tokens.size() == 2 && tokens[0] ==
"num_class" 301 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
306 CHECK(num_class >= 0 && num_class == model.num_output_group)
307 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
309 model.param.pred_transform =
"softmax";
310 }
else if (obj_name_ ==
"multiclassova") {
316 for (
const auto& str : obj_param_) {
317 auto tokens = treelite::common::Split(str,
':');
318 if (tokens.size() == 2) {
319 if (tokens[0] ==
"num_class" 320 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
322 }
else if (tokens[0] ==
"sigmoid" 323 && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
328 CHECK(num_class >= 0 && num_class == model.num_output_group
330 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
332 model.param.pred_transform =
"multiclass_ova";
333 model.param.sigmoid_alpha = alpha;
334 }
else if (obj_name_ ==
"binary") {
338 for (
const auto& str : obj_param_) {
339 auto tokens = treelite::common::Split(str,
':');
340 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 341 && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
346 CHECK_GT(alpha, 0.0f)
347 <<
"Ill-formed LightGBM model file: not a valid binary objective";
349 model.param.pred_transform =
"sigmoid";
350 model.param.sigmoid_alpha = alpha;
351 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
352 model.param.pred_transform =
"sigmoid";
353 model.param.sigmoid_alpha = 1.0f;
354 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
355 model.param.pred_transform =
"logarithm_one_plus_exp";
357 model.param.pred_transform =
"identity";
361 for (
const auto& lgb_tree : lgb_trees_) {
362 model.trees.emplace_back();
368 std::queue<std::pair<int, int>> Q;
372 std::tie(old_id, new_id) = Q.front(); Q.pop();
374 const double leaf_value = lgb_tree.leaf_value[~old_id];
375 const int data_count = lgb_tree.leaf_count[~old_id];
376 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
377 CHECK_GE(data_count, 0);
378 tree[new_id].set_data_count(static_cast<size_t>(data_count));
380 const int data_count = lgb_tree.internal_count[old_id];
381 const unsigned split_index =
382 static_cast<unsigned>(lgb_tree.split_feature[old_id]);
385 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
387 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
388 const std::vector<uint32_t> left_categories
389 = BitsetToList(lgb_tree.cat_threshold.data()
390 + lgb_tree.cat_boundaries[cat_idx],
391 lgb_tree.cat_boundaries[cat_idx + 1]
392 - lgb_tree.cat_boundaries[cat_idx]);
393 const auto missing_type
394 = GetMissingType(lgb_tree.decision_type[old_id]);
395 tree[new_id].set_categorical_split(split_index,
false,
396 (missing_type != MissingType::kNaN),
402 const bool default_left
403 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
405 tree[new_id].set_numerical_split(split_index, threshold,
406 default_left, cmp_op);
408 CHECK_GE(data_count, 0);
409 tree[new_id].set_data_count(static_cast<size_t>(data_count));
410 tree[new_id].set_gain(static_cast<double>(lgb_tree.split_gain[old_id]));
411 Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
412 Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
416 LOG(INFO) <<
"model.num_tree = " << model.trees.size();
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
in-memory representation of a decision tree
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
double tl_float
float type to be used internally
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators