11 #include <unordered_map> 19 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
28 std::ifstream fi(filename, std::ios::in);
29 return ParseStream(fi);
/*!
 * \brief Primary template of the text-to-number converter.
 *
 * Only float, double, int, int8_t, uint32_t and uint64_t are accepted; any
 * other type is rejected at compile time. The actual parsing lives in the
 * explicit specializations for those types.
 * \tparam T requested numeric type
 * \param str text to convert
 * \return a value-initialized T (the specializations return the parsed value)
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  static_assert(std::is_same<T, float>::value
                || std::is_same<T, double>::value
                || std::is_same<T, int>::value
                || std::is_same<T, int8_t>::value
                || std::is_same<T, uint32_t>::value
                || std::is_same<T, uint64_t>::value,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t");
  // A non-void function must not run off its end; return a defined value.
  return T();
}
51 inline float TextToNumber(
const std::string& str) {
54 float val = std::strtof(str.c_str(), &endptr);
55 if (errno == ERANGE) {
56 TREELITE_LOG(FATAL) <<
"Range error while converting string to double";
57 }
else if (errno != 0) {
58 TREELITE_LOG(FATAL) <<
"Unknown error";
59 }
else if (*endptr !=
'\0') {
60 TREELITE_LOG(FATAL) <<
"String does not represent a valid floating-point number";
66 inline double TextToNumber(
const std::string& str) {
69 double val = std::strtod(str.c_str(), &endptr);
70 if (errno == ERANGE) {
71 TREELITE_LOG(FATAL) <<
"Range error while converting string to double";
72 }
else if (errno != 0) {
73 TREELITE_LOG(FATAL) <<
"Unknown error";
74 }
else if (*endptr !=
'\0') {
75 TREELITE_LOG(FATAL) <<
"String does not represent a valid floating-point number";
81 inline int TextToNumber(
const std::string& str) {
84 auto val = std::strtol(str.c_str(), &endptr, 10);
85 if (errno == ERANGE || val < std::numeric_limits<int>::min()
86 || val > std::numeric_limits<int>::max()) {
87 TREELITE_LOG(FATAL) <<
"Range error while converting string to int";
88 }
else if (errno != 0) {
89 TREELITE_LOG(FATAL) <<
"Unknown error";
90 }
else if (*endptr !=
'\0') {
91 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
93 return static_cast<int>(val);
97 inline int8_t TextToNumber(
const std::string& str) {
100 auto val = std::strtol(str.c_str(), &endptr, 10);
101 if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
102 || val > std::numeric_limits<int8_t>::max()) {
103 TREELITE_LOG(FATAL) <<
"Range error while converting string to int8_t";
104 }
else if (errno != 0) {
105 TREELITE_LOG(FATAL) <<
"Unknown error";
106 }
else if (*endptr !=
'\0') {
107 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
109 return static_cast<int8_t
>(val);
113 inline uint32_t TextToNumber(
const std::string& str) {
116 auto val = std::strtoul(str.c_str(), &endptr, 10);
117 if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
118 TREELITE_LOG(FATAL) <<
"Range error while converting string to uint32_t";
119 }
else if (errno != 0) {
120 TREELITE_LOG(FATAL) <<
"Unknown error";
121 }
else if (*endptr !=
'\0') {
122 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
124 return static_cast<uint32_t
>(val);
128 inline uint64_t TextToNumber(
const std::string& str) {
131 auto val = std::strtoull(str.c_str(), &endptr, 10);
132 if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
133 TREELITE_LOG(FATAL) <<
"Range error while converting string to uint64_t";
134 }
else if (errno != 0) {
135 TREELITE_LOG(FATAL) <<
"Unknown error";
136 }
else if (*endptr !=
'\0') {
137 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
139 return static_cast<uint64_t
>(val);
/*!
 * \brief Split a string on a single-character delimiter.
 *
 * Empty fields between consecutive delimiters are kept; a trailing delimiter
 * does not produce a trailing empty field, and empty input yields an empty
 * vector (this follows from std::getline's behavior).
 * \param text string to split
 * \param delim delimiter character
 * \return the fields in order of appearance
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> array;
  std::istringstream ss(text);
  std::string token;
  while (std::getline(ss, token, delim)) {
    array.push_back(token);
  }
  return array;
}
152 template <
typename T>
153 inline std::vector<T> TextToArray(
const std::string& text,
int num_entry) {
154 if (text.empty() && num_entry > 0) {
155 TREELITE_LOG(FATAL) <<
"Cannot convert empty text into array";
157 std::vector<T> array;
158 std::istringstream ss(text);
160 for (
int i = 0; i < num_entry; ++i) {
161 std::getline(ss, token,
' ');
162 array.push_back(TextToNumber<T>(token));
/*!
 * \brief Bit masks applied to a node's decision_type byte.
 *
 * Bit 0 marks a categorical split and bit 1 marks "default left"; the two
 * bits above them are read separately by GetMissingType().
 */
enum Masks : uint8_t {
  kCategoricalMask = 1,
  kDefaultLeftMask = 2
};
/*!
 * \brief How a split treats missing values, as stored in bits 2-3 of
 *        decision_type (see GetMissingType()).
 *
 * NOTE(review): the enumerator values follow LightGBM's missing-type encoding
 * (none = 0, zero = 1, NaN = 2) - confirm against the LightGBM version targeted.
 */
enum class MissingType : uint8_t {
  kNone = 0,
  kZero = 1,
  kNaN = 2
};
/*!
 * \brief Raw per-tree fields read from a LightGBM model file.
 *
 * Each vector holds the parsed contents of the model-file key of the same
 * name; ParseStream() fills them in.
 */
struct LGBTree {
  int num_leaves = 0;  // value of the "num_leaves" key
  int num_cat = 0;     // value of the "num_cat" key
  std::vector<double> leaf_value;
  std::vector<int8_t> decision_type;
  std::vector<uint64_t> cat_boundaries;
  std::vector<uint32_t> cat_threshold;
  std::vector<int> split_feature;
  std::vector<double> threshold;
  std::vector<int> left_child;
  std::vector<int> right_child;
  std::vector<float> split_gain;
  std::vector<int> internal_count;
  std::vector<int> leaf_count;
};
/*!
 * \brief Test whether any bit selected by mask is set in decision_type.
 * \param decision_type decision_type byte of a node
 * \param mask bit mask to test, e.g. kCategoricalMask
 * \return true if (decision_type & mask) is positive
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  return (decision_type & mask) > 0;
}
198 inline MissingType GetMissingType(int8_t decision_type) {
199 return static_cast<MissingType
>((decision_type >> 2) & 3);
/*!
 * \brief Expand a bitset stored in 32-bit words into the list of set bit indices.
 *
 * Bit k of word w corresponds to index 32 * w + k.
 * \param bits pointer to the first word; only read when nslots > 0
 * \param nslots number of 32-bit words to scan
 * \return indices of all set bits, in ascending order
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits, size_t nslots) {
  std::vector<uint32_t> result;
  const size_t nbits = nslots * 32;
  for (size_t i = 0; i < nbits; ++i) {
    const size_t i1 = i / 32;                            // word holding bit i
    const uint32_t i2 = static_cast<uint32_t>(i % 32);   // position inside that word
    if ((bits[i1] >> i2) & 1) {
      result.push_back(static_cast<uint32_t>(i));
    }
  }
  return result;
}
/*!
 * \brief Read every line of a text stream into memory.
 *
 * A trailing carriage return is removed from each line so that files with
 * CRLF line endings parse the same as LF files; otherwise the '\r' would end
 * up inside values and make the numeric conversions fail.
 * \param fi input stream, read until std::getline fails
 * \return the lines in order, without line terminators
 */
inline std::vector<std::string> LoadText(std::istream& fi) {
  std::vector<std::string> lines;
  std::string line;
  while (std::getline(fi, line)) {
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    lines.push_back(line);
  }
  return lines;
}
226 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
227 std::vector<LGBTree> lgb_trees_;
228 int max_feature_idx_;
230 bool average_output_;
231 std::string obj_name_;
232 std::vector<std::string> obj_param_;
235 std::vector<std::string> lines = LoadText(fi);
236 std::unordered_map<std::string, std::string> global_dict;
237 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
239 bool in_tree =
false;
240 for (
const auto& line : lines) {
241 std::istringstream ss(line);
242 std::string key, value, rest;
243 std::getline(ss, key,
'=');
244 std::getline(ss, value,
'=');
245 std::getline(ss, rest);
252 tree_dict.emplace_back();
255 tree_dict.back()[key] = value;
257 global_dict[key] = value;
263 auto it = global_dict.find(
"objective");
264 if (it == global_dict.end()) {
265 obj_name_ =
"custom";
267 auto obj_strs = Split(it->second,
' ');
268 obj_name_ = obj_strs[0];
269 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
272 it = global_dict.find(
"max_feature_idx");
273 TREELITE_CHECK(it != global_dict.end())
274 <<
"Ill-formed LightGBM model file: need max_feature_idx";
275 max_feature_idx_ = TextToNumber<int>(it->second);
276 it = global_dict.find(
"num_class");
277 TREELITE_CHECK(it != global_dict.end())
278 <<
"Ill-formed LightGBM model file: need num_class";
279 num_class_ = TextToNumber<int>(it->second);
281 it = global_dict.find(
"average_output");
282 average_output_ = (it != global_dict.end());
285 for (
const auto& dict : tree_dict) {
286 lgb_trees_.emplace_back();
287 LGBTree& tree = lgb_trees_.back();
289 auto it = dict.find(
"num_leaves");
290 TREELITE_CHECK(it != dict.end())
291 <<
"Ill-formed LightGBM model file: need num_leaves";
292 tree.num_leaves = TextToNumber<int>(it->second);
294 it = dict.find(
"num_cat");
295 TREELITE_CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
296 tree.num_cat = TextToNumber<int>(it->second);
298 it = dict.find(
"leaf_value");
299 TREELITE_CHECK(it != dict.end() && !it->second.empty())
300 <<
"Ill-formed LightGBM model file: need leaf_value";
302 = TextToArray<double>(it->second, tree.num_leaves);
304 it = dict.find(
"decision_type");
305 if (tree.num_leaves <= 1) {
306 tree.decision_type = std::vector<int8_t>();
308 TREELITE_CHECK_GT(tree.num_leaves, 1);
309 if (it == dict.end()) {
310 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
312 TREELITE_CHECK(!it->second.empty())
313 <<
"Ill-formed LightGBM model file: decision_type cannot be empty string";
314 tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
318 if (tree.num_cat > 0) {
319 it = dict.find(
"cat_boundaries");
320 TREELITE_CHECK(it != dict.end() && !it->second.empty())
321 <<
"Ill-formed LightGBM model file: need cat_boundaries";
323 = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
324 it = dict.find(
"cat_threshold");
325 TREELITE_CHECK(it != dict.end() && !it->second.empty())
326 <<
"Ill-formed LightGBM model file: need cat_threshold";
328 = TextToArray<uint32_t>(it->second,
static_cast<uint32_t
>(tree.cat_boundaries.back()));
331 it = dict.find(
"split_feature");
332 if (tree.num_leaves <= 1) {
333 tree.split_feature = std::vector<int>();
335 TREELITE_CHECK_GT(tree.num_leaves, 1);
336 TREELITE_CHECK(it != dict.end() && !it->second.empty())
337 <<
"Ill-formed LightGBM model file: need split_feature";
338 tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
341 it = dict.find(
"threshold");
342 if (tree.num_leaves <= 1) {
343 tree.threshold = std::vector<double>();
345 TREELITE_CHECK_GT(tree.num_leaves, 1);
346 TREELITE_CHECK(it != dict.end() && !it->second.empty())
347 <<
"Ill-formed LightGBM model file: need threshold";
348 tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
351 it = dict.find(
"split_gain");
352 if (tree.num_leaves <= 1) {
353 tree.split_gain = std::vector<float>();
355 TREELITE_CHECK_GT(tree.num_leaves, 1);
356 if (it != dict.end()) {
357 TREELITE_CHECK(!it->second.empty())
358 <<
"Ill-formed LightGBM model file: split_gain cannot be empty string";
359 tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
361 tree.split_gain = std::vector<float>();
365 it = dict.find(
"internal_count");
366 if (tree.num_leaves <= 1) {
367 tree.internal_count = std::vector<int>();
369 TREELITE_CHECK_GT(tree.num_leaves, 1);
370 if (it != dict.end()) {
371 TREELITE_CHECK(!it->second.empty())
372 <<
"Ill-formed LightGBM model file: internal_count cannot be empty string";
373 tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
375 tree.internal_count = std::vector<int>();
379 it = dict.find(
"leaf_count");
380 if (tree.num_leaves == 0) {
381 tree.leaf_count = std::vector<int>();
383 TREELITE_CHECK_GT(tree.num_leaves, 0);
384 if (it != dict.end() && !it->second.empty()) {
385 tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
387 tree.leaf_count = std::vector<int>();
391 it = dict.find(
"left_child");
392 if (tree.num_leaves <= 1) {
393 tree.left_child = std::vector<int>();
395 TREELITE_CHECK_GT(tree.num_leaves, 1);
396 TREELITE_CHECK(it != dict.end() && !it->second.empty())
397 <<
"Ill-formed LightGBM model file: need left_child";
398 tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
401 it = dict.find(
"right_child");
402 if (tree.num_leaves <= 1) {
403 tree.right_child = std::vector<int>();
405 TREELITE_CHECK_GT(tree.num_leaves, 1);
406 TREELITE_CHECK(it != dict.end() && !it->second.empty())
407 <<
"Ill-formed LightGBM model file: need right_child";
408 tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
413 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
416 model->average_tree_output = average_output_;
417 if (num_class_ > 1) {
419 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
420 model->task_param.grove_per_class =
true;
423 model->task_type = treelite::TaskType::kBinaryClfRegr;
424 model->task_param.grove_per_class =
false;
426 model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
427 model->task_param.num_class = num_class_;
428 model->task_param.leaf_vector_size = 1;
431 if (obj_name_ ==
"multiclass") {
435 for (
const auto& str : obj_param_) {
436 auto tokens = Split(str,
':');
437 if (tokens.size() == 2 && tokens[0] ==
"num_class" 438 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
443 TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
444 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
446 std::strncpy(model->param.pred_transform,
"softmax",
sizeof(model->param.pred_transform));
447 }
else if (obj_name_ ==
"multiclassova") {
453 for (
const auto& str : obj_param_) {
454 auto tokens = Split(str,
':');
455 if (tokens.size() == 2) {
456 if (tokens[0] ==
"num_class" 457 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
459 }
else if (tokens[0] ==
"sigmoid" 460 && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
465 TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
467 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
469 std::strncpy(model->param.pred_transform,
"multiclass_ova",
470 sizeof(model->param.pred_transform));
471 model->param.sigmoid_alpha = alpha;
472 }
else if (obj_name_ ==
"binary") {
476 for (
const auto& str : obj_param_) {
477 auto tokens = Split(str,
':');
478 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 479 && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
484 TREELITE_CHECK_GT(alpha, 0.0f)
485 <<
"Ill-formed LightGBM model file: not a valid binary objective";
487 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
488 model->param.sigmoid_alpha = alpha;
489 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
490 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
491 model->param.sigmoid_alpha = 1.0f;
492 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
493 std::strncpy(model->param.pred_transform,
"logarithm_one_plus_exp",
494 sizeof(model->param.pred_transform));
495 }
else if (obj_name_ ==
"poisson" || obj_name_ ==
"gamma" || obj_name_ ==
"tweedie") {
496 std::strncpy(model->param.pred_transform,
"exponential",
497 sizeof(model->param.pred_transform));
498 }
else if (obj_name_ ==
"regression" || obj_name_ ==
"regression_l1" || obj_name_ ==
"huber" 499 || obj_name_ ==
"fair" || obj_name_ ==
"quantile" || obj_name_ ==
"mape") {
501 bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(),
"sqrt") != obj_param_.cend());
503 std::strncpy(model->param.pred_transform,
"signed_square",
504 sizeof(model->param.pred_transform));
506 std::strncpy(model->param.pred_transform,
"identity",
507 sizeof(model->param.pred_transform));
509 }
else if (obj_name_ ==
"lambdarank" || obj_name_ ==
"rank_xendcg" || obj_name_ ==
"custom") {
511 std::strncpy(model->param.pred_transform,
"identity",
512 sizeof(model->param.pred_transform));
514 TREELITE_LOG(FATAL) <<
"Unrecognized objective: " << obj_name_;
518 for (
const auto& lgb_tree : lgb_trees_) {
519 model->trees.emplace_back();
525 std::queue<std::pair<int, int>> Q;
526 if (lgb_tree.num_leaves == 0) {
528 }
else if (lgb_tree.num_leaves == 1) {
536 std::tie(old_id, new_id) = Q.front(); Q.pop();
538 const double leaf_value = lgb_tree.leaf_value[~old_id];
539 tree.
SetLeaf(new_id, static_cast<double>(leaf_value));
540 if (!lgb_tree.leaf_count.empty()) {
541 const int data_count = lgb_tree.leaf_count[~old_id];
542 TREELITE_CHECK_GE(data_count, 0);
543 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
546 const auto split_index =
static_cast<unsigned>(lgb_tree.split_feature[old_id]);
547 const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
550 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
552 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
553 const std::vector<uint32_t> left_categories
554 = BitsetToList(lgb_tree.cat_threshold.data()
555 + lgb_tree.cat_boundaries[cat_idx],
556 lgb_tree.cat_boundaries[cat_idx + 1]
557 - lgb_tree.cat_boundaries[cat_idx]);
558 const bool missing_value_to_zero = missing_type != MissingType::kNaN;
559 bool default_left =
false;
560 if (missing_value_to_zero) {
564 = (std::find(left_categories.begin(), left_categories.end(),
565 static_cast<uint32_t
>(0)) != left_categories.end());
570 const auto threshold =
static_cast<double>(lgb_tree.threshold[old_id]);
572 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
574 const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
575 if (missing_value_to_zero) {
578 default_left = 0.0 <= threshold;
582 if (!lgb_tree.internal_count.empty()) {
583 const int data_count = lgb_tree.internal_count[old_id];
584 TREELITE_CHECK_GE(data_count, 0);
585 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
587 if (!lgb_tree.split_gain.empty()) {
588 tree.
SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
590 Q.push({lgb_tree.left_child[old_id], tree.
LeftChild(new_id)});
591 Q.push({lgb_tree.right_child[old_id], tree.
RightChild(new_id)});
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
in-memory representation of a decision tree
logging facility for Treelite
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
int LeftChild(int nid) const
index of the node's left child
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
create a numerical split
Operator
comparison operators
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node