11 #include <unordered_map> 17 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi);
24 DMLC_REGISTRY_FILE_TAG(lightgbm);
27 std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename,
"r"));
28 return ParseStream(fi.get());
38 inline T TextToNumber(
const std::string& str) {
39 static_assert(std::is_same<T, float>::value
40 || std::is_same<T, double>::value
41 || std::is_same<T, int>::value
42 || std::is_same<T, int8_t>::value
43 || std::is_same<T, uint32_t>::value
44 || std::is_same<T, uint64_t>::value,
45 "unsupported data type for TextToNumber; use float, double, " 46 "int, int8_t, uint32_t, or uint64_t");
50 inline float TextToNumber(
const std::string& str) {
53 float val = std::strtof(str.c_str(), &endptr);
54 if (errno == ERANGE) {
55 LOG(FATAL) <<
"Range error while converting string to double";
56 }
else if (errno != 0) {
57 LOG(FATAL) <<
"Unknown error";
58 }
else if (*endptr !=
'\0') {
59 LOG(FATAL) <<
"String does not represent a valid floating-point number";
65 inline double TextToNumber(
const std::string& str) {
68 double val = std::strtod(str.c_str(), &endptr);
69 if (errno == ERANGE) {
70 LOG(FATAL) <<
"Range error while converting string to double";
71 }
else if (errno != 0) {
72 LOG(FATAL) <<
"Unknown error";
73 }
else if (*endptr !=
'\0') {
74 LOG(FATAL) <<
"String does not represent a valid floating-point number";
80 inline int TextToNumber(
const std::string& str) {
83 auto val = std::strtol(str.c_str(), &endptr, 10);
84 if (errno == ERANGE || val < std::numeric_limits<int>::min()
85 || val > std::numeric_limits<int>::max()) {
86 LOG(FATAL) <<
"Range error while converting string to int";
87 }
else if (errno != 0) {
88 LOG(FATAL) <<
"Unknown error";
89 }
else if (*endptr !=
'\0') {
90 LOG(FATAL) <<
"String does not represent a valid integer";
92 return static_cast<int>(val);
96 inline int8_t TextToNumber(
const std::string& str) {
99 auto val = std::strtol(str.c_str(), &endptr, 10);
100 if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
101 || val > std::numeric_limits<int8_t>::max()) {
102 LOG(FATAL) <<
"Range error while converting string to int8_t";
103 }
else if (errno != 0) {
104 LOG(FATAL) <<
"Unknown error";
105 }
else if (*endptr !=
'\0') {
106 LOG(FATAL) <<
"String does not represent a valid integer";
108 return static_cast<int8_t
>(val);
112 inline uint32_t TextToNumber(
const std::string& str) {
115 auto val = std::strtoul(str.c_str(), &endptr, 10);
116 if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
117 LOG(FATAL) <<
"Range error while converting string to uint32_t";
118 }
else if (errno != 0) {
119 LOG(FATAL) <<
"Unknown error";
120 }
else if (*endptr !=
'\0') {
121 LOG(FATAL) <<
"String does not represent a valid integer";
123 return static_cast<uint32_t
>(val);
127 inline uint64_t TextToNumber(
const std::string& str) {
130 auto val = std::strtoull(str.c_str(), &endptr, 10);
131 if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
132 LOG(FATAL) <<
"Range error while converting string to uint64_t";
133 }
else if (errno != 0) {
134 LOG(FATAL) <<
"Unknown error";
135 }
else if (*endptr !=
'\0') {
136 LOG(FATAL) <<
"String does not represent a valid integer";
138 return static_cast<uint64_t
>(val);
141 inline std::vector<std::string> Split(
const std::string& text,
char delim) {
142 std::vector<std::string> array;
143 std::istringstream ss(text);
145 while (std::getline(ss, token, delim)) {
146 array.push_back(token);
151 template <
typename T>
152 inline std::vector<T> TextToArray(
const std::string& text,
int num_entry) {
153 if (text.empty() && num_entry > 0) {
154 LOG(FATAL) <<
"Cannot convert empty text into array";
156 std::vector<T> array;
157 std::istringstream ss(text);
159 for (
int i = 0; i < num_entry; ++i) {
160 std::getline(ss, token,
' ');
161 array.push_back(TextToNumber<T>(token));
166 enum Masks : uint8_t {
167 kCategoricalMask = 1,
171 enum class MissingType : uint8_t {
180 std::vector<double> leaf_value;
181 std::vector<int8_t> decision_type;
182 std::vector<uint64_t> cat_boundaries;
183 std::vector<uint32_t> cat_threshold;
184 std::vector<int> split_feature;
185 std::vector<double> threshold;
186 std::vector<int> left_child;
187 std::vector<int> right_child;
188 std::vector<float> split_gain;
189 std::vector<int> internal_count;
190 std::vector<int> leaf_count;
193 inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
194 return (decision_type & mask) > 0;
197 inline MissingType GetMissingType(int8_t decision_type) {
198 return static_cast<MissingType
>((decision_type >> 2) & 3);
201 inline std::vector<uint32_t> BitsetToList(
const uint32_t* bits,
203 std::vector<uint32_t> result;
204 const size_t nbits = nslots * 32;
205 for (
size_t i = 0; i < nbits; ++i) {
206 const size_t i1 = i / 32;
207 const uint32_t i2 =
static_cast<uint32_t
>(i % 32);
208 if ((bits[i1] >> i2) & 1) {
209 result.push_back(static_cast<uint32_t>(i));
215 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
216 const size_t bufsize = 16 * 1024 * 1024;
217 std::vector<char> buf(bufsize);
219 std::vector<std::string> lines;
223 std::string leftover =
"";
224 while ((byte_read = fi->Read(&buf[0],
sizeof(
char) * bufsize)) > 0) {
226 size_t tok_begin = 0;
227 while (i < byte_read) {
228 if (buf[i] ==
'\n' || buf[i] ==
'\r') {
229 if (tok_begin == 0 && leftover.length() + i > 0) {
231 lines.push_back(leftover + std::string(&buf[0], i));
234 lines.emplace_back(&buf[tok_begin], i - tok_begin);
237 for (; (buf[i] ==
'\n' || buf[i] ==
'\r') && i < byte_read; ++i) {}
244 leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
247 if (!leftover.empty()) {
249 <<
"Warning: input file was not terminated with end-of-line character.";
250 lines.push_back(leftover);
256 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
257 std::vector<LGBTree> lgb_trees_;
258 int max_feature_idx_;
260 bool average_output_;
261 std::string obj_name_;
262 std::vector<std::string> obj_param_;
265 std::vector<std::string> lines = LoadText(fi);
266 std::unordered_map<std::string, std::string> global_dict;
267 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
269 bool in_tree =
false;
270 for (
const auto& line : lines) {
271 std::istringstream ss(line);
272 std::string key, value, rest;
273 std::getline(ss, key,
'=');
274 std::getline(ss, value,
'=');
275 std::getline(ss, rest);
276 CHECK(rest.empty()) <<
"Ill-formed LightGBM model file";
279 tree_dict.emplace_back();
282 tree_dict.back()[key] = value;
284 global_dict[key] = value;
290 auto it = global_dict.find(
"objective");
291 if (it == global_dict.end()) {
292 obj_name_ =
"custom";
294 auto obj_strs = Split(it->second,
' ');
295 obj_name_ = obj_strs[0];
296 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
299 it = global_dict.find(
"max_feature_idx");
300 CHECK(it != global_dict.end())
301 <<
"Ill-formed LightGBM model file: need max_feature_idx";
302 max_feature_idx_ = TextToNumber<int>(it->second);
303 it = global_dict.find(
"num_class");
304 CHECK(it != global_dict.end())
305 <<
"Ill-formed LightGBM model file: need num_class";
306 num_class_ = TextToNumber<int>(it->second);
308 it = global_dict.find(
"average_output");
309 average_output_ = (it != global_dict.end());
312 for (
const auto& dict : tree_dict) {
313 lgb_trees_.emplace_back();
314 LGBTree& tree = lgb_trees_.back();
316 auto it = dict.find(
"num_leaves");
317 CHECK(it != dict.end())
318 <<
"Ill-formed LightGBM model file: need num_leaves";
319 tree.num_leaves = TextToNumber<int>(it->second);
321 it = dict.find(
"num_cat");
322 CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
323 tree.num_cat = TextToNumber<int>(it->second);
325 it = dict.find(
"leaf_value");
326 CHECK(it != dict.end() && !it->second.empty())
327 <<
"Ill-formed LightGBM model file: need leaf_value";
329 = TextToArray<double>(it->second, tree.num_leaves);
331 it = dict.find(
"decision_type");
332 if (tree.num_leaves <= 1) {
333 tree.decision_type = std::vector<int8_t>();
335 CHECK_GT(tree.num_leaves, 1);
336 if (it == dict.end()) {
337 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
339 CHECK(!it->second.empty())
340 <<
"Ill-formed LightGBM model file: decision_type cannot be empty string";
341 tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
345 if (tree.num_cat > 0) {
346 it = dict.find(
"cat_boundaries");
347 CHECK(it != dict.end() && !it->second.empty())
348 <<
"Ill-formed LightGBM model file: need cat_boundaries";
350 = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
351 it = dict.find(
"cat_threshold");
352 CHECK(it != dict.end() && !it->second.empty())
353 <<
"Ill-formed LightGBM model file: need cat_threshold";
355 = TextToArray<uint32_t>(it->second,
static_cast<uint32_t
>(tree.cat_boundaries.back()));
358 it = dict.find(
"split_feature");
359 if (tree.num_leaves <= 1) {
360 tree.split_feature = std::vector<int>();
362 CHECK_GT(tree.num_leaves, 1);
363 CHECK(it != dict.end() && !it->second.empty())
364 <<
"Ill-formed LightGBM model file: need split_feature";
365 tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
368 it = dict.find(
"threshold");
369 if (tree.num_leaves <= 1) {
370 tree.threshold = std::vector<double>();
372 CHECK_GT(tree.num_leaves, 1);
373 CHECK(it != dict.end() && !it->second.empty())
374 <<
"Ill-formed LightGBM model file: need threshold";
375 tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
378 it = dict.find(
"split_gain");
379 if (tree.num_leaves <= 1) {
380 tree.split_gain = std::vector<float>();
382 CHECK_GT(tree.num_leaves, 1);
383 if (it != dict.end()) {
384 CHECK(!it->second.empty())
385 <<
"Ill-formed LightGBM model file: split_gain cannot be empty string";
386 tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
388 tree.split_gain = std::vector<float>();
392 it = dict.find(
"internal_count");
393 if (tree.num_leaves <= 1) {
394 tree.internal_count = std::vector<int>();
396 CHECK_GT(tree.num_leaves, 1);
397 if (it != dict.end()) {
398 CHECK(!it->second.empty())
399 <<
"Ill-formed LightGBM model file: internal_count cannot be empty string";
400 tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
402 tree.internal_count = std::vector<int>();
406 it = dict.find(
"leaf_count");
407 if (tree.num_leaves == 0) {
408 tree.leaf_count = std::vector<int>();
410 CHECK_GT(tree.num_leaves, 0);
411 if (it != dict.end() && !it->second.empty()) {
412 tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
414 tree.leaf_count = std::vector<int>();
418 it = dict.find(
"left_child");
419 if (tree.num_leaves <= 1) {
420 tree.left_child = std::vector<int>();
422 CHECK_GT(tree.num_leaves, 1);
423 CHECK(it != dict.end() && !it->second.empty())
424 <<
"Ill-formed LightGBM model file: need left_child";
425 tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
428 it = dict.find(
"right_child");
429 if (tree.num_leaves <= 1) {
430 tree.right_child = std::vector<int>();
432 CHECK_GT(tree.num_leaves, 1);
433 CHECK(it != dict.end() && !it->second.empty())
434 <<
"Ill-formed LightGBM model file: need right_child";
435 tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
440 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
443 model->average_tree_output = average_output_;
444 if (num_class_ > 1) {
446 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
447 model->task_param.grove_per_class =
true;
450 model->task_type = treelite::TaskType::kBinaryClfRegr;
451 model->task_param.grove_per_class =
false;
453 model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat;
454 model->task_param.num_class = num_class_;
455 model->task_param.leaf_vector_size = 1;
458 if (obj_name_ ==
"multiclass") {
462 for (
const auto& str : obj_param_) {
463 auto tokens = Split(str,
':');
464 if (tokens.size() == 2 && tokens[0] ==
"num_class" 465 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
470 CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
471 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
473 std::strncpy(model->param.pred_transform,
"softmax",
sizeof(model->param.pred_transform));
474 }
else if (obj_name_ ==
"multiclassova") {
480 for (
const auto& str : obj_param_) {
481 auto tokens = Split(str,
':');
482 if (tokens.size() == 2) {
483 if (tokens[0] ==
"num_class" 484 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
486 }
else if (tokens[0] ==
"sigmoid" 487 && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
492 CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
494 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
496 std::strncpy(model->param.pred_transform,
"multiclass_ova",
497 sizeof(model->param.pred_transform));
498 model->param.sigmoid_alpha = alpha;
499 }
else if (obj_name_ ==
"binary") {
503 for (
const auto& str : obj_param_) {
504 auto tokens = Split(str,
':');
505 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 506 && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
511 CHECK_GT(alpha, 0.0f)
512 <<
"Ill-formed LightGBM model file: not a valid binary objective";
514 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
515 model->param.sigmoid_alpha = alpha;
516 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
517 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
518 model->param.sigmoid_alpha = 1.0f;
519 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
520 std::strncpy(model->param.pred_transform,
"logarithm_one_plus_exp",
521 sizeof(model->param.pred_transform));
522 }
else if (obj_name_ ==
"poisson" || obj_name_ ==
"gamma" || obj_name_ ==
"tweedie") {
523 std::strncpy(model->param.pred_transform,
"exponential",
524 sizeof(model->param.pred_transform));
525 }
else if (obj_name_ ==
"regression" || obj_name_ ==
"regression_l1" || obj_name_ ==
"huber" 526 || obj_name_ ==
"fair" || obj_name_ ==
"quantile" || obj_name_ ==
"mape") {
528 bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(),
"sqrt") != obj_param_.cend());
530 std::strncpy(model->param.pred_transform,
"signed_square",
531 sizeof(model->param.pred_transform));
533 std::strncpy(model->param.pred_transform,
"identity",
534 sizeof(model->param.pred_transform));
536 }
else if (obj_name_ ==
"lambdarank" || obj_name_ ==
"rank_xendcg" || obj_name_ ==
"custom") {
538 std::strncpy(model->param.pred_transform,
"identity",
539 sizeof(model->param.pred_transform));
541 LOG(FATAL) <<
"Unrecognized objective: " << obj_name_;
545 for (
const auto& lgb_tree : lgb_trees_) {
546 model->trees.emplace_back();
552 std::queue<std::pair<int, int>> Q;
553 if (lgb_tree.num_leaves == 0) {
555 }
else if (lgb_tree.num_leaves == 1) {
563 std::tie(old_id, new_id) = Q.front(); Q.pop();
565 const double leaf_value = lgb_tree.leaf_value[~old_id];
566 tree.
SetLeaf(new_id, static_cast<double>(leaf_value));
567 if (!lgb_tree.leaf_count.empty()) {
568 const int data_count = lgb_tree.leaf_count[~old_id];
569 CHECK_GE(data_count, 0);
570 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
573 const auto split_index =
static_cast<unsigned>(lgb_tree.split_feature[old_id]);
574 const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
577 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
579 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
580 const std::vector<uint32_t> left_categories
581 = BitsetToList(lgb_tree.cat_threshold.data()
582 + lgb_tree.cat_boundaries[cat_idx],
583 lgb_tree.cat_boundaries[cat_idx + 1]
584 - lgb_tree.cat_boundaries[cat_idx]);
585 const bool missing_value_to_zero = missing_type != MissingType::kNaN;
586 bool default_left =
false;
587 if (missing_value_to_zero) {
591 = (std::find(left_categories.begin(), left_categories.end(),
592 static_cast<uint32_t
>(0)) != left_categories.end());
597 const auto threshold =
static_cast<double>(lgb_tree.threshold[old_id]);
599 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
601 const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
602 if (missing_value_to_zero) {
605 default_left = 0.0 <= threshold;
609 if (!lgb_tree.internal_count.empty()) {
610 const int data_count = lgb_tree.internal_count[old_id];
611 CHECK_GE(data_count, 0);
612 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
614 if (!lgb_tree.split_gain.empty()) {
615 tree.
SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
617 Q.push({lgb_tree.left_child[old_id], tree.
LeftChild(new_id)});
618 Q.push({lgb_tree.right_child[old_id], tree.
RightChild(new_id)});
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
in-memory representation of a decision tree
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
int LeftChild(int nid) const
Getters.
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Operator
comparison operators
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node