11 #include <unordered_map> 24 DMLC_REGISTRY_FILE_TAG(lightgbm);
// Open `filename` for reading as a dmlc::Stream (owned by the unique_ptr, so it
// is released when this scope ends) and store the parsed model in *out.
27 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
// NOTE(review): the std::move below is redundant if ParseStream returns its
// model by value; ParseStream's signature is not visible here — confirm first.
28 *out = std::move(ParseStream(fi.get()));
// Primary template of TextToNumber. It only carries a compile-time guard:
// instantiating it for any type other than float, double, int, int8_t,
// uint32_t or uint64_t fails the static_assert. The actual conversions are
// provided by the per-type versions of TextToNumber in this file.
// NOTE(review): the template header and closing brace are not visible here.
38 inline T TextToNumber(
const std::string& str) {
39 static_assert(std::is_same<T, float>::value
40 || std::is_same<T, double>::value
41 || std::is_same<T, int>::value
42 || std::is_same<T, int8_t>::value
43 || std::is_same<T, uint32_t>::value
44 || std::is_same<T, uint64_t>::value,
45 "unsupported data type for TextToNumber; use float, double, " 46 "int, int8_t, uint32_t, or uint64_t");
50 inline float TextToNumber(
const std::string& str) {
53 float val = std::strtof(str.c_str(), &endptr);
54 if (errno == ERANGE) {
55 LOG(FATAL) <<
"Range error while converting string to double";
56 }
else if (errno != 0) {
57 LOG(FATAL) <<
"Unknown error";
58 }
else if (*endptr !=
'\0') {
59 LOG(FATAL) <<
"String does not represent a valid floating-point number";
65 inline double TextToNumber(
const std::string& str) {
68 double val = std::strtod(str.c_str(), &endptr);
69 if (errno == ERANGE) {
70 LOG(FATAL) <<
"Range error while converting string to double";
71 }
else if (errno != 0) {
72 LOG(FATAL) <<
"Unknown error";
73 }
else if (*endptr !=
'\0') {
74 LOG(FATAL) <<
"String does not represent a valid floating-point number";
80 inline int TextToNumber(
const std::string& str) {
83 auto val = std::strtol(str.c_str(), &endptr, 10);
84 if (errno == ERANGE || val < std::numeric_limits<int>::min()
85 || val > std::numeric_limits<int>::max()) {
86 LOG(FATAL) <<
"Range error while converting string to int";
87 }
else if (errno != 0) {
88 LOG(FATAL) <<
"Unknown error";
89 }
else if (*endptr !=
'\0') {
90 LOG(FATAL) <<
"String does not represent a valid integer";
92 return static_cast<int>(val);
96 inline int8_t TextToNumber(
const std::string& str) {
99 auto val = std::strtol(str.c_str(), &endptr, 10);
100 if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
101 || val > std::numeric_limits<int8_t>::max()) {
102 LOG(FATAL) <<
"Range error while converting string to int8_t";
103 }
else if (errno != 0) {
104 LOG(FATAL) <<
"Unknown error";
105 }
else if (*endptr !=
'\0') {
106 LOG(FATAL) <<
"String does not represent a valid integer";
108 return static_cast<int8_t
>(val);
112 inline uint32_t TextToNumber(
const std::string& str) {
115 auto val = std::strtoul(str.c_str(), &endptr, 10);
116 if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
117 LOG(FATAL) <<
"Range error while converting string to uint32_t";
118 }
else if (errno != 0) {
119 LOG(FATAL) <<
"Unknown error";
120 }
else if (*endptr !=
'\0') {
121 LOG(FATAL) <<
"String does not represent a valid integer";
123 return static_cast<uint32_t
>(val);
127 inline uint64_t TextToNumber(
const std::string& str) {
130 auto val = std::strtoull(str.c_str(), &endptr, 10);
131 if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
132 LOG(FATAL) <<
"Range error while converting string to uint64_t";
133 }
else if (errno != 0) {
134 LOG(FATAL) <<
"Unknown error";
135 }
else if (*endptr !=
'\0') {
136 LOG(FATAL) <<
"String does not represent a valid integer";
138 return static_cast<uint64_t
>(val);
/*!
 * \brief Break `text` into the pieces separated by `delim`.
 *
 * Adjacent delimiters yield empty pieces. A single trailing delimiter does
 * not add an empty final piece, and an empty `text` yields no pieces.
 * \param text string to split
 * \param delim separator character
 * \return the pieces of `text`, in order
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  while (start < text.size()) {
    const std::string::size_type stop = text.find(delim, start);
    if (stop == std::string::npos) {
      // Last piece: runs to the end of the text.
      pieces.emplace_back(text, start, std::string::npos);
      break;
    }
    pieces.emplace_back(text, start, stop - start);
    start = stop + 1;
  }
  return pieces;
}
151 template <
typename T>
152 inline std::vector<T> TextToArray(
const std::string& text,
int num_entry) {
153 if (text.empty() && num_entry > 0) {
154 LOG(FATAL) <<
"Cannot convert empty text into array";
156 std::vector<T> array;
157 std::istringstream ss(text);
159 for (
int i = 0; i < num_entry; ++i) {
160 std::getline(ss, token,
' ');
161 array.push_back(TextToNumber<T>(token));
// Bit flags stored in a node's decision_type entry; tested with GetDecisionType.
// NOTE(review): only the first enumerator is visible in this view.
166 enum Masks : uint8_t {
// Set when the node performs a categorical split.
167 kCategoricalMask = 1,
// Missing-value handling mode of a node, decoded from bits 2-3 of its
// decision_type by GetMissingType (enumerators not visible in this view).
171 enum class MissingType : uint8_t {
// Per-tree arrays parsed from a LightGBM model (struct header not visible here).
// Output value of each leaf; num_leaves entries.
180 std::vector<double> leaf_value;
// Per internal node: flag bits (e.g. kCategoricalMask) and missing-value mode.
181 std::vector<int8_t> decision_type;
// num_cat + 1 offsets into cat_threshold delimiting each categorical bitset.
182 std::vector<uint64_t> cat_boundaries;
// Concatenated 32-bit category bitsets of all categorical splits.
183 std::vector<uint32_t> cat_threshold;
// Feature index tested by each internal node.
184 std::vector<int> split_feature;
// Numerical threshold, or for categorical splits an index into cat_boundaries.
185 std::vector<double> threshold;
// Child ids of each internal node. Leaves appear to be referenced as
// ~leaf_index (inferred from the ~old_id indexing during conversion — confirm).
186 std::vector<int> left_child;
187 std::vector<int> right_child;
// Split gain of each internal node.
188 std::vector<float> split_gain;
// Data counts per internal node and per leaf, respectively.
189 std::vector<int> internal_count;
190 std::vector<int> leaf_count;
/*!
 * \brief Report whether `decision_type` has any of the bits in `mask` set.
 *
 * The test is a strictly positive bitwise AND of the two values.
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int masked_bits = decision_type & mask;
  return masked_bits > 0;
}
197 inline MissingType GetMissingType(int8_t decision_type) {
198 return static_cast<MissingType
>((decision_type >> 2) & 3);
// Return, in ascending order, the positions of all set bits in the first
// `nslots` 32-bit words of `bits` (bit j of word k is position 32 * k + j).
// NOTE(review): the declaration of `nslots` is not visible here. Callers pass a
// difference of two uint64_t cat_boundaries entries, so confirm that the
// parameter type does not narrow that value.
201 inline std::vector<uint32_t> BitsetToList(
const uint32_t* bits,
203 std::vector<uint32_t> result;
204 const size_t nbits = nslots * 32;
205 for (
size_t i = 0; i < nbits; ++i) {
// i1 selects the 32-bit word, i2 the bit within that word.
206 const size_t i1 = i / 32;
207 const uint32_t i2 =
static_cast<uint32_t
>(i % 32);
208 if ((bits[i1] >> i2) & 1) {
209 result.push_back(static_cast<uint32_t>(i));
215 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
216 const size_t bufsize = 16 * 1024 * 1024;
217 std::vector<char> buf(bufsize);
219 std::vector<std::string> lines;
223 std::string leftover =
"";
224 while ((byte_read = fi->Read(&buf[0],
sizeof(
char) * bufsize)) > 0) {
226 size_t tok_begin = 0;
227 while (i < byte_read) {
228 if (buf[i] ==
'\n' || buf[i] ==
'\r') {
229 if (tok_begin == 0 && leftover.length() + i > 0) {
231 lines.push_back(leftover + std::string(&buf[0], i));
234 lines.emplace_back(&buf[tok_begin], i - tok_begin);
237 for (; (buf[i] ==
'\n' || buf[i] ==
'\r') && i < byte_read; ++i) {}
244 leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
247 if (!leftover.empty()) {
249 <<
"Warning: input file was not terminated with end-of-line character.";
250 lines.push_back(leftover);
// Parse a LightGBM text model read from `fi` and convert it into a treelite Model.
// NOTE(review): the enclosing function header and several interior lines are
// not visible in this view; the comments below describe only visible statements.
// Model-level values collected from the global (header) section of the file.
257 std::vector<LGBTree> lgb_trees_;
258 int max_feature_idx_;
259 int num_tree_per_iteration_;
260 bool average_output_;
261 std::string obj_name_;
262 std::vector<std::string> obj_param_;
// 1. Read every line and split it at '=' into key and value. Pairs go either to
//    the dictionary of the current tree section or to the global dictionary
//    (the conditions choosing between them are not visible here).
265 std::vector<std::string> lines = LoadText(fi);
266 std::unordered_map<std::string, std::string> global_dict;
267 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
269 bool in_tree =
false;
270 for (
const auto& line : lines) {
271 std::istringstream ss(line);
272 std::string key, value, rest;
273 std::getline(ss, key,
'=');
274 std::getline(ss, value,
'=');
275 std::getline(ss, rest);
// A second '=' leaves text in `rest`, and such a line is rejected.
276 CHECK(rest.empty()) <<
"Ill-formed LightGBM model file";
// Open a new per-tree dictionary.
279 tree_dict.emplace_back();
282 tree_dict.back()[key] = value;
284 global_dict[key] = value;
// 2. Required global fields. `objective` holds the objective name followed by
//    space-separated parameters (read later as "key:value" tokens).
// NOTE(review): an empty objective value makes Split() return an empty vector,
// so obj_strs[0] and obj_strs.begin() + 1 below would be out of bounds.
290 auto it = global_dict.find(
"objective");
291 CHECK(it != global_dict.end())
292 <<
"Ill-formed LightGBM model file: need objective";
293 auto obj_strs = Split(it->second,
' ');
294 obj_name_ = obj_strs[0];
295 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
297 it = global_dict.find(
"max_feature_idx");
298 CHECK(it != global_dict.end())
299 <<
"Ill-formed LightGBM model file: need max_feature_idx";
300 max_feature_idx_ = TextToNumber<int>(it->second);
301 it = global_dict.find(
"num_tree_per_iteration");
302 CHECK(it != global_dict.end())
303 <<
"Ill-formed LightGBM model file: need num_tree_per_iteration";
304 num_tree_per_iteration_ = TextToNumber<int>(it->second);
// average_output is a presence-only flag; its value is not read.
306 it = global_dict.find(
"average_output");
307 average_output_ = (it != global_dict.end());
// 3. Convert each per-tree dictionary into an LGBTree. Internal-node arrays have
//    num_leaves - 1 entries and per-leaf arrays have num_leaves entries.
// NOTE(review): num_leaves is not validated to be >= 1 here. With num_leaves == 0,
// the `num_leaves - 1` sizes below become -1 and convert to a huge size_t.
310 for (
const auto& dict : tree_dict) {
311 lgb_trees_.emplace_back();
312 LGBTree& tree = lgb_trees_.back();
314 auto it = dict.find(
"num_leaves");
315 CHECK(it != dict.end())
316 <<
"Ill-formed LightGBM model file: need num_leaves";
317 tree.num_leaves = TextToNumber<int>(it->second);
319 it = dict.find(
"num_cat");
320 CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
321 tree.num_cat = TextToNumber<int>(it->second);
323 it = dict.find(
"leaf_value");
324 CHECK(it != dict.end() && !it->second.empty())
325 <<
"Ill-formed LightGBM model file: need leaf_value";
327 = TextToArray<double>(it->second, tree.num_leaves);
// decision_type is optional; when absent every internal node gets 0.
329 it = dict.find(
"decision_type");
330 if (it == dict.end()) {
331 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
333 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
334 <<
"Ill-formed LightGBM model file: decision_type cannot be empty string";
336 = TextToArray<int8_t>(it->second,
337 tree.num_leaves - 1);
// Categorical splits: cat_boundaries holds num_cat + 1 offsets into
// cat_threshold, whose total length is cat_boundaries.back().
340 if (tree.num_cat > 0) {
341 it = dict.find(
"cat_boundaries");
342 CHECK(it != dict.end() && !it->second.empty())
343 <<
"Ill-formed LightGBM model file: need cat_boundaries";
345 = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
346 it = dict.find(
"cat_threshold");
347 CHECK(it != dict.end() && !it->second.empty())
348 <<
"Ill-formed LightGBM model file: need cat_threshold";
350 = TextToArray<uint32_t>(it->second,
351 tree.cat_boundaries.back());
354 it = dict.find(
"split_feature");
355 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
356 <<
"Ill-formed LightGBM model file: need split_feature";
358 = TextToArray<int>(it->second, tree.num_leaves - 1);
360 it = dict.find(
"threshold");
361 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
362 <<
"Ill-formed LightGBM model file: need threshold";
364 = TextToArray<double>(it->second, tree.num_leaves - 1);
// split_gain, internal_count and leaf_count are optional; when absent they are
// zero-filled to the expected length via resize().
366 it = dict.find(
"split_gain");
367 if (it != dict.end()) {
368 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
369 <<
"Ill-formed LightGBM model file: split_gain cannot be empty string";
371 = TextToArray<float>(it->second, tree.num_leaves - 1);
373 tree.split_gain.resize(tree.num_leaves - 1);
376 it = dict.find(
"internal_count");
377 if (it != dict.end()) {
378 CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
379 <<
"Ill-formed LightGBM model file: internal_count cannot be empty string";
381 = TextToArray<int>(it->second, tree.num_leaves - 1);
383 tree.internal_count.resize(tree.num_leaves - 1);
386 it = dict.find(
"leaf_count");
387 if (it != dict.end()) {
388 CHECK(!it->second.empty())
389 <<
"Ill-formed LightGBM model file: leaf_count cannot be empty string";
391 = TextToArray<int>(it->second, tree.num_leaves);
393 tree.leaf_count.resize(tree.num_leaves);
396 it = dict.find(
"left_child");
397 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
398 <<
"Ill-formed LightGBM model file: need left_child";
400 = TextToArray<int>(it->second, tree.num_leaves - 1);
402 it = dict.find(
"right_child");
403 CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
404 <<
"Ill-formed LightGBM model file: need right_child";
406 = TextToArray<int>(it->second, tree.num_leaves - 1);
// 4. Model-level settings. average_output (random-forest mode) is rejected when
//    there is more than one output group.
412 model.num_output_group = num_tree_per_iteration_;
413 if (model.num_output_group > 1) {
415 CHECK(!average_output_)
416 <<
"Ill-formed LightGBM model file: cannot use random forest mode " 417 <<
"for multi-class classification";
418 model.random_forest_flag =
false;
420 model.random_forest_flag = average_output_;
// 5. Pick the prediction transform from the objective name. Parameters come
//    from "key:value" tokens in obj_param_ (num_class, sigmoid).
424 if (obj_name_ ==
"multiclass") {
428 for (
const auto& str : obj_param_) {
429 auto tokens = Split(str,
':');
430 if (tokens.size() == 2 && tokens[0] ==
"num_class" 431 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
436 CHECK(num_class >= 0 && num_class == model.num_output_group)
437 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
439 std::strncpy(model.param.pred_transform,
"softmax",
sizeof(model.param.pred_transform));
440 }
else if (obj_name_ ==
"multiclassova") {
446 for (
const auto& str : obj_param_) {
447 auto tokens = Split(str,
':');
448 if (tokens.size() == 2) {
449 if (tokens[0] ==
"num_class" 450 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
452 }
else if (tokens[0] ==
"sigmoid" 453 && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
458 CHECK(num_class >= 0 && num_class == model.num_output_group
460 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
462 std::strncpy(model.param.pred_transform,
"multiclass_ova",
sizeof(model.param.pred_transform));
463 model.param.sigmoid_alpha = alpha;
464 }
else if (obj_name_ ==
"binary") {
468 for (
const auto& str : obj_param_) {
469 auto tokens = Split(str,
':');
470 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 471 && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
476 CHECK_GT(alpha, 0.0f)
477 <<
"Ill-formed LightGBM model file: not a valid binary objective";
479 std::strncpy(model.param.pred_transform,
"sigmoid",
sizeof(model.param.pred_transform));
480 model.param.sigmoid_alpha = alpha;
481 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
482 std::strncpy(model.param.pred_transform,
"sigmoid",
sizeof(model.param.pred_transform));
483 model.param.sigmoid_alpha = 1.0f;
484 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
485 std::strncpy(model.param.pred_transform,
"logarithm_one_plus_exp",
486 sizeof(model.param.pred_transform));
// Any other objective falls back to the "identity" transform.
488 std::strncpy(model.param.pred_transform,
"identity",
sizeof(model.param.pred_transform));
// 6. Convert every LightGBM tree. Nodes are visited breadth-first using a queue
//    of (LightGBM node id, treelite node id) pairs.
// NOTE(review): confirm how a tree with a single leaf (no internal nodes) is
// handled; the root setup and the leaf/internal test are not visible here.
492 for (
const auto& lgb_tree : lgb_trees_) {
493 model.trees.emplace_back();
499 std::queue<std::pair<int, int>> Q;
503 std::tie(old_id, new_id) = Q.front(); Q.pop();
// Leaf node: LightGBM per-leaf arrays are indexed by ~old_id.
505 const double leaf_value = lgb_tree.leaf_value[~old_id];
506 const int data_count = lgb_tree.leaf_count[~old_id];
507 tree.
SetLeaf(new_id, static_cast<treelite::tl_float>(leaf_value));
508 CHECK_GE(data_count, 0);
509 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
// Internal node: copy split feature, split condition, data count and gain.
511 const int data_count = lgb_tree.internal_count[old_id];
512 const auto split_index =
513 static_cast<unsigned>(lgb_tree.split_feature[old_id]);
// Categorical split: threshold stores an index into cat_boundaries, and the
// left-going categories are the set bits of that slice of cat_threshold.
// NOTE(review): cat_idx is not range-checked against cat_boundaries.size() - 1.
516 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
518 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
519 const std::vector<uint32_t> left_categories
520 = BitsetToList(lgb_tree.cat_threshold.data()
521 + lgb_tree.cat_boundaries[cat_idx],
522 lgb_tree.cat_boundaries[cat_idx + 1]
523 - lgb_tree.cat_boundaries[cat_idx]);
524 const auto missing_type
525 = GetMissingType(lgb_tree.decision_type[old_id]);
531 const bool default_left
532 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
536 CHECK_GE(data_count, 0);
537 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
538 tree.
SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
// Enqueue both children paired with the ids of their new treelite nodes.
539 Q.push({lgb_tree.left_child[old_id], tree.
LeftChild(new_id)});
540 Q.push({lgb_tree.right_child[old_id], tree.
RightChild(new_id)});
// Report how many trees were converted.
544 LOG(INFO) <<
"model.num_tree = " << model.trees.size();
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
model structure for tree ensemble
in-memory representation of a decision tree
void LoadLightGBMModel(const char *filename, Model *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
create a numerical split
int LeftChild(int nid) const
index of the node's left child
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Operator
comparison operators