10 #include <unordered_map> 22 DMLC_REGISTRY_FILE_TAG(lightgbm);
25 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
26 return ParseStream(fi.get());
35 enum Masks : uint8_t {
43 std::vector<double> leaf_value;
44 std::vector<int8_t> decision_type;
45 std::vector<int> cat_boundaries;
46 std::vector<uint32_t> cat_threshold;
47 std::vector<int> split_feature;
48 std::vector<double> threshold;
49 std::vector<int> left_child;
50 std::vector<int> right_child;
53 inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
54 return (decision_type & mask) > 0;
57 inline std::vector<uint32_t> BitsetToList(
const uint32_t* bits,
59 std::vector<uint32_t> result;
60 const uint32_t nbits =
static_cast<uint32_t
>(nslots) * 32;
61 for (uint32_t i = 0; i < nbits; ++i) {
62 const uint32_t i1 = i / 32;
63 const uint32_t i2 = i % 32;
64 if ((bits[i1] >> i2) & 1) {
71 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
72 const size_t bufsize = 16 * 1024 * 1024;
73 std::vector<char> buf(bufsize);
75 std::vector<std::string> lines;
79 std::string leftover =
"";
80 while ( (byte_read = fi->Read(&buf[0],
sizeof(
char) * bufsize)) > 0) {
83 while (i < byte_read) {
84 if (buf[i] ==
'\n' || buf[i] ==
'\r') {
85 if (tok_begin == 0 && leftover.length() + i > 0) {
87 lines.push_back(leftover + std::string(&buf[0], i));
90 lines.emplace_back(&buf[tok_begin], i - tok_begin);
93 for (; (buf[i] ==
'\n' || buf[i] ==
'\r') && i < byte_read; ++i);
100 leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
103 if (!leftover.empty()) {
105 <<
"Warning: input file was not terminated with end-of-line character.";
106 lines.push_back(leftover);
113 std::vector<LGBTree> lgb_trees_;
114 int max_feature_idx_;
115 int num_tree_per_iteration_;
116 bool average_output_;
117 std::string obj_name_;
118 std::vector<std::string> obj_param_;
121 std::vector<std::string> lines = LoadText(fi);
122 std::unordered_map<std::string, std::string> global_dict;
123 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
125 bool in_tree =
false;
126 for (
const auto& line : lines) {
127 std::istringstream ss(line);
128 std::string key, value, rest;
129 std::getline(ss, key,
'=');
130 std::getline(ss, value,
'=');
131 std::getline(ss, rest);
132 CHECK(rest.empty()) <<
"Ill-formed LightGBM model file";
135 tree_dict.emplace_back();
138 tree_dict.back()[key] = value;
140 global_dict[key] = value;
146 auto it = global_dict.find(
"objective");
147 CHECK(it != global_dict.end())
148 <<
"Ill-formed LightGBM model file: need objective";
149 auto obj_strs = treelite::common::Split(it->second,
' ');
150 obj_name_ = obj_strs[0];
151 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
153 it = global_dict.find(
"max_feature_idx");
154 CHECK(it != global_dict.end())
155 <<
"Ill-formed LightGBM model file: need max_feature_idx";
156 max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
157 it = global_dict.find(
"num_tree_per_iteration");
158 CHECK(it != global_dict.end())
159 <<
"Ill-formed LightGBM model file: need num_tree_per_iteration";
160 num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
162 it = global_dict.find(
"average_output");
163 average_output_ = (it != global_dict.end());
166 for (
const auto& dict : tree_dict) {
167 lgb_trees_.emplace_back();
168 LGBTree& tree = lgb_trees_.back();
170 auto it = dict.find(
"num_leaves");
171 CHECK(it != dict.end())
172 <<
"Ill-formed LightGBM model file: need num_leaves";
173 tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
175 it = dict.find(
"num_cat");
176 CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
177 tree.num_cat = treelite::common::TextToNumber<int>(it->second);
179 it = dict.find(
"leaf_value");
180 CHECK(it != dict.end())
181 <<
"Ill-formed LightGBM model file: need leaf_value";
183 = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
185 it = dict.find(
"decision_type");
186 CHECK(it != dict.end())
187 <<
"Ill-formed LightGBM model file: need decision_type";
188 if (it == dict.end()) {
189 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
192 = treelite::common::TextToArray<int8_t>(it->second,
193 tree.num_leaves - 1);
196 if (tree.num_cat > 0) {
197 it = dict.find(
"cat_boundaries");
198 CHECK(it != dict.end())
199 <<
"Ill-formed LightGBM model file: need cat_boundaries";
201 = treelite::common::TextToArray<int>(it->second, tree.num_cat + 1);
202 it = dict.find(
"cat_threshold");
203 CHECK(it != dict.end())
204 <<
"Ill-formed LightGBM model file: need cat_threshold";
206 = treelite::common::TextToArray<uint32_t>(it->second,
207 tree.cat_boundaries.back());
210 it = dict.find(
"split_feature");
211 CHECK(it != dict.end())
212 <<
"Ill-formed LightGBM model file: need split_feature";
214 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
216 it = dict.find(
"threshold");
217 CHECK(it != dict.end())
218 <<
"Ill-formed LightGBM model file: need threshold";
220 = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
222 it = dict.find(
"left_child");
223 CHECK(it != dict.end())
224 <<
"Ill-formed LightGBM model file: need left_child";
226 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
228 it = dict.find(
"right_child");
229 CHECK(it != dict.end())
230 <<
"Ill-formed LightGBM model file: need right_child";
232 = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
238 model.num_output_group = num_tree_per_iteration_;
239 if (model.num_output_group > 1) {
241 CHECK(!average_output_)
242 <<
"Ill-formed LightGBM model file: cannot use random forest mode " 243 <<
"for multi-class classification";
244 model.random_forest_flag =
false;
246 model.random_forest_flag = average_output_;
250 if (obj_name_ ==
"multiclass") {
254 for (
const auto& str : obj_param_) {
255 auto tokens = treelite::common::Split(str,
':');
256 if (tokens.size() == 2 && tokens[0] ==
"num_class" 257 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
262 CHECK(num_class >= 0 && num_class == model.num_output_group)
263 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
265 model.param.pred_transform =
"softmax";
266 }
else if (obj_name_ ==
"multiclassova") {
272 for (
const auto& str : obj_param_) {
273 auto tokens = treelite::common::Split(str,
':');
274 if (tokens.size() == 2) {
275 if (tokens[0] ==
"num_class" 276 && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
278 }
else if (tokens[0] ==
"sigmoid" 279 && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
284 CHECK(num_class >= 0 && num_class == model.num_output_group
286 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
288 model.param.pred_transform =
"multiclass_ova";
289 model.param.sigmoid_alpha = alpha;
290 }
else if (obj_name_ ==
"binary") {
294 for (
const auto& str : obj_param_) {
295 auto tokens = treelite::common::Split(str,
':');
296 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 297 && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
302 CHECK_GT(alpha, 0.0f)
303 <<
"Ill-formed LightGBM model file: not a valid binary objective";
305 model.param.pred_transform =
"sigmoid";
306 model.param.sigmoid_alpha = alpha;
307 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
308 model.param.pred_transform =
"sigmoid";
309 model.param.sigmoid_alpha = 1.0f;
310 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
311 model.param.pred_transform =
"logarithm_one_plus_exp";
313 model.param.pred_transform =
"identity";
317 for (
const auto& lgb_tree : lgb_trees_) {
318 model.trees.emplace_back();
324 std::queue<std::pair<int, int>> Q;
328 std::tie(old_id, new_id) = Q.front(); Q.pop();
330 const double leaf_value = lgb_tree.leaf_value[~old_id];
331 tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
333 const unsigned split_index =
334 static_cast<unsigned>(lgb_tree.split_feature[old_id]);
335 const bool default_left
336 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
338 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
340 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
341 const std::vector<uint32_t> left_categories
342 = BitsetToList(lgb_tree.cat_threshold.data()
343 + lgb_tree.cat_boundaries[cat_idx],
344 lgb_tree.cat_boundaries[cat_idx + 1]
345 - lgb_tree.cat_boundaries[cat_idx]);
346 tree[new_id].set_categorical_split(split_index, default_left,
353 tree[new_id].set_numerical_split(split_index, threshold,
354 default_left, cmp_op);
356 Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
357 Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
in-memory representation of a decision tree
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators