11 #include <unordered_map> 20 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
29 std::ifstream fi(filename, std::ios::in);
30 return ParseStream(fi);
34 std::istringstream is(model_str);
35 return ParseStream(is);
/*!
 * \brief Generic text-to-number converter; only a fixed set of target types is allowed.
 *
 * Instantiating this primary template with any type other than float, double, int,
 * int8_t, uint32_t, or uint64_t fails at compile time through the static_assert below.
 *
 * \tparam T requested numeric type
 * \param str text to convert
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  constexpr bool kIsFloatingPoint
    = std::is_same<T, float>::value || std::is_same<T, double>::value;
  constexpr bool kIsSignedInteger
    = std::is_same<T, int>::value || std::is_same<T, int8_t>::value;
  constexpr bool kIsUnsignedInteger
    = std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value;
  static_assert(kIsFloatingPoint || kIsSignedInteger || kIsUnsignedInteger,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t");
}
57 inline float TextToNumber(
const std::string& str) {
60 float val = std::strtof(str.c_str(), &endptr);
61 if (errno == ERANGE) {
62 TREELITE_LOG(FATAL) <<
"Range error while converting string to double";
63 }
else if (errno != 0) {
64 TREELITE_LOG(FATAL) <<
"Unknown error";
65 }
else if (*endptr !=
'\0') {
66 TREELITE_LOG(FATAL) <<
"String does not represent a valid floating-point number";
72 inline double TextToNumber(
const std::string& str) {
75 double val = std::strtod(str.c_str(), &endptr);
76 if (errno == ERANGE) {
77 TREELITE_LOG(FATAL) <<
"Range error while converting string to double";
78 }
else if (errno != 0) {
79 TREELITE_LOG(FATAL) <<
"Unknown error";
80 }
else if (*endptr !=
'\0') {
81 TREELITE_LOG(FATAL) <<
"String does not represent a valid floating-point number";
87 inline int TextToNumber(
const std::string& str) {
90 auto val = std::strtol(str.c_str(), &endptr, 10);
91 if (errno == ERANGE || val < std::numeric_limits<int>::min()
92 || val > std::numeric_limits<int>::max()) {
93 TREELITE_LOG(FATAL) <<
"Range error while converting string to int";
94 }
else if (errno != 0) {
95 TREELITE_LOG(FATAL) <<
"Unknown error";
96 }
else if (*endptr !=
'\0') {
97 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
99 return static_cast<int>(val);
103 inline int8_t TextToNumber(
const std::string& str) {
106 auto val = std::strtol(str.c_str(), &endptr, 10);
107 if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
108 || val > std::numeric_limits<int8_t>::max()) {
109 TREELITE_LOG(FATAL) <<
"Range error while converting string to int8_t";
110 }
else if (errno != 0) {
111 TREELITE_LOG(FATAL) <<
"Unknown error";
112 }
else if (*endptr !=
'\0') {
113 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
115 return static_cast<int8_t
>(val);
119 inline uint32_t TextToNumber(
const std::string& str) {
122 auto val = std::strtoul(str.c_str(), &endptr, 10);
123 if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
124 TREELITE_LOG(FATAL) <<
"Range error while converting string to uint32_t";
125 }
else if (errno != 0) {
126 TREELITE_LOG(FATAL) <<
"Unknown error";
127 }
else if (*endptr !=
'\0') {
128 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
130 return static_cast<uint32_t
>(val);
134 inline uint64_t TextToNumber(
const std::string& str) {
137 auto val = std::strtoull(str.c_str(), &endptr, 10);
138 if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
139 TREELITE_LOG(FATAL) <<
"Range error while converting string to uint64_t";
140 }
else if (errno != 0) {
141 TREELITE_LOG(FATAL) <<
"Unknown error";
142 }
else if (*endptr !=
'\0') {
143 TREELITE_LOG(FATAL) <<
"String does not represent a valid integer";
145 return static_cast<uint64_t
>(val);
/*!
 * \brief Break a string into the pieces separated by a delimiter character.
 *
 * Empty pieces between consecutive delimiters are kept; a delimiter at the very end of
 * the text does not produce a trailing empty piece, and empty text yields no pieces.
 *
 * \param text string to split
 * \param delim separator character
 * \return pieces in order of appearance
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type begin = 0;
  while (begin < text.size()) {
    const std::string::size_type end = text.find(delim, begin);
    if (end == std::string::npos) {
      // last piece: everything up to the end of the text
      pieces.emplace_back(text, begin, std::string::npos);
      break;
    }
    pieces.emplace_back(text, begin, end - begin);
    begin = end + 1;
  }
  return pieces;
}
158 template <
typename T>
159 inline std::vector<T> TextToArray(
const std::string& text,
int num_entry) {
160 if (text.empty() && num_entry > 0) {
161 TREELITE_LOG(FATAL) <<
"Cannot convert empty text into array";
163 std::vector<T> array;
164 std::istringstream ss(text);
166 for (
int i = 0; i < num_entry; ++i) {
167 std::getline(ss, token,
' ');
168 array.push_back(TextToNumber<T>(token));
// Bit masks applied to a decision_type entry through GetDecisionType().
// kCategoricalMask marks a split as categorical (see its use in ParseStream).
// NOTE(review): ParseStream also uses kDefaultLeftMask; its enumerator is not visible in this listing.
173 enum Masks : uint8_t {
174 kCategoricalMask = 1,
// Missing-value handling mode of a split; GetMissingType() decodes it from bits 2-3 of a
// decision_type entry. ParseStream compares against MissingType::kNaN.
// NOTE(review): the enumerator list is not visible in this listing.
178 enum class MissingType : uint8_t {
// Per-tree arrays filled by ParseStream from the key=value lines of one tree section.
// Arrays indexed by internal node hold num_leaves - 1 entries and stay empty when
// num_leaves <= 1.
// Output value of each leaf; read from "leaf_value" with num_leaves entries.
187 std::vector<double> leaf_value;
// Per-split flag byte decoded with GetDecisionType() / GetMissingType(); all zeros when
// the "decision_type" key is absent.
188 std::vector<int8_t> decision_type;
// Offsets into cat_threshold, num_cat + 1 entries; only read when num_cat > 0.
189 std::vector<uint64_t> cat_boundaries;
// 32-bit bitset words describing category sets (consumed by BitsetToList); only read
// when num_cat > 0.
190 std::vector<uint32_t> cat_threshold;
// Feature index tested by each split.
191 std::vector<int> split_feature;
// Split threshold; for categorical splits the value is cast to int and used as an index
// into cat_boundaries.
192 std::vector<double> threshold;
// Child ids of each split. NOTE(review): leaf children appear to be stored as the
// bitwise complement of the leaf index (ParseStream reads leaf_value[~old_id]) — confirm.
193 std::vector<int> left_child;
194 std::vector<int> right_child;
// Optional per-split gain; empty when "split_gain" is absent.
195 std::vector<float> split_gain;
// Optional per-split data count; empty when "internal_count" is absent.
196 std::vector<int> internal_count;
// Optional per-leaf data count; empty when "leaf_count" is absent or empty.
197 std::vector<int> leaf_count;
/*!
 * \brief Test whether any bit selected by a mask is set in a decision_type byte.
 *
 * The comparison uses != 0 rather than > 0: both operands are promoted to int before
 * the AND, so a mask that includes bit 7 produces a negative result when that bit is set.
 *
 * \param decision_type packed flag byte of a split
 * \param mask bit mask to test
 * \return true if (decision_type & mask) has at least one bit set
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  return (decision_type & mask) != 0;
}
204 inline MissingType GetMissingType(int8_t decision_type) {
205 return static_cast<MissingType
>((decision_type >> 2) & 3);
/*!
 * \brief List the positions of all set bits in a bitset stored as 32-bit words.
 *
 * Bit k of word w corresponds to position w * 32 + k; positions are returned in
 * ascending order.
 *
 * \param bits pointer to the first word; must reference at least nslots words
 * \param nslots number of 32-bit words to scan
 * \return positions of the set bits
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits, size_t nslots) {
  std::vector<uint32_t> positions;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t offset = 0; offset < 32; ++offset) {
      if ((word >> offset) & 1) {
        positions.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return positions;
}
/*!
 * \brief Read every remaining line of a stream into memory.
 * \param fi input stream; consumed until std::getline fails
 * \return the lines in order, without their line terminators
 */
inline std::vector<std::string> LoadText(std::istream& fi) {
  std::vector<std::string> content;
  for (std::string row; std::getline(fi, row);) {
    content.push_back(row);
  }
  return content;
}
232 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
233 std::vector<LGBTree> lgb_trees_;
234 int max_feature_idx_;
236 bool average_output_;
237 std::string obj_name_;
238 std::vector<std::string> obj_param_;
241 std::vector<std::string> lines = LoadText(fi);
242 std::unordered_map<std::string, std::string> global_dict;
243 std::vector<std::unordered_map<std::string, std::string>> tree_dict;
245 bool in_tree =
false;
246 for (
const auto& line : lines) {
247 std::istringstream ss(line);
248 std::string key, value, rest;
249 std::getline(ss, key,
'=');
250 std::getline(ss, value,
'=');
251 std::getline(ss, rest);
258 tree_dict.emplace_back();
261 tree_dict.back()[key] = value;
263 global_dict[key] = value;
269 auto it = global_dict.find(
"objective");
270 if (it == global_dict.end()) {
271 obj_name_ =
"custom";
273 auto obj_strs = Split(it->second,
' ');
274 obj_name_ = obj_strs[0];
275 obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
278 it = global_dict.find(
"max_feature_idx");
279 TREELITE_CHECK(it != global_dict.end())
280 <<
"Ill-formed LightGBM model file: need max_feature_idx";
281 max_feature_idx_ = TextToNumber<int>(it->second);
282 it = global_dict.find(
"num_class");
283 TREELITE_CHECK(it != global_dict.end())
284 <<
"Ill-formed LightGBM model file: need num_class";
285 num_class_ = TextToNumber<int>(it->second);
287 it = global_dict.find(
"average_output");
288 average_output_ = (it != global_dict.end());
291 for (
const auto& dict : tree_dict) {
292 lgb_trees_.emplace_back();
293 LGBTree& tree = lgb_trees_.back();
295 auto it = dict.find(
"num_leaves");
296 TREELITE_CHECK(it != dict.end())
297 <<
"Ill-formed LightGBM model file: need num_leaves";
298 tree.num_leaves = TextToNumber<int>(it->second);
300 it = dict.find(
"num_cat");
301 TREELITE_CHECK(it != dict.end()) <<
"Ill-formed LightGBM model file: need num_cat";
302 tree.num_cat = TextToNumber<int>(it->second);
304 it = dict.find(
"leaf_value");
305 TREELITE_CHECK(it != dict.end() && !it->second.empty())
306 <<
"Ill-formed LightGBM model file: need leaf_value";
308 = TextToArray<double>(it->second, tree.num_leaves);
310 it = dict.find(
"decision_type");
311 if (tree.num_leaves <= 1) {
312 tree.decision_type = std::vector<int8_t>();
314 TREELITE_CHECK_GT(tree.num_leaves, 1);
315 if (it == dict.end()) {
316 tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
318 TREELITE_CHECK(!it->second.empty())
319 <<
"Ill-formed LightGBM model file: decision_type cannot be empty string";
320 tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
324 if (tree.num_cat > 0) {
325 it = dict.find(
"cat_boundaries");
326 TREELITE_CHECK(it != dict.end() && !it->second.empty())
327 <<
"Ill-formed LightGBM model file: need cat_boundaries";
329 = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
330 it = dict.find(
"cat_threshold");
331 TREELITE_CHECK(it != dict.end() && !it->second.empty())
332 <<
"Ill-formed LightGBM model file: need cat_threshold";
334 = TextToArray<uint32_t>(it->second,
static_cast<uint32_t
>(tree.cat_boundaries.back()));
337 it = dict.find(
"split_feature");
338 if (tree.num_leaves <= 1) {
339 tree.split_feature = std::vector<int>();
341 TREELITE_CHECK_GT(tree.num_leaves, 1);
342 TREELITE_CHECK(it != dict.end() && !it->second.empty())
343 <<
"Ill-formed LightGBM model file: need split_feature";
344 tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
347 it = dict.find(
"threshold");
348 if (tree.num_leaves <= 1) {
349 tree.threshold = std::vector<double>();
351 TREELITE_CHECK_GT(tree.num_leaves, 1);
352 TREELITE_CHECK(it != dict.end() && !it->second.empty())
353 <<
"Ill-formed LightGBM model file: need threshold";
354 tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
357 it = dict.find(
"split_gain");
358 if (tree.num_leaves <= 1) {
359 tree.split_gain = std::vector<float>();
361 TREELITE_CHECK_GT(tree.num_leaves, 1);
362 if (it != dict.end()) {
363 TREELITE_CHECK(!it->second.empty())
364 <<
"Ill-formed LightGBM model file: split_gain cannot be empty string";
365 tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
367 tree.split_gain = std::vector<float>();
371 it = dict.find(
"internal_count");
372 if (tree.num_leaves <= 1) {
373 tree.internal_count = std::vector<int>();
375 TREELITE_CHECK_GT(tree.num_leaves, 1);
376 if (it != dict.end()) {
377 TREELITE_CHECK(!it->second.empty())
378 <<
"Ill-formed LightGBM model file: internal_count cannot be empty string";
379 tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
381 tree.internal_count = std::vector<int>();
385 it = dict.find(
"leaf_count");
386 if (tree.num_leaves == 0) {
387 tree.leaf_count = std::vector<int>();
389 TREELITE_CHECK_GT(tree.num_leaves, 0);
390 if (it != dict.end() && !it->second.empty()) {
391 tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
393 tree.leaf_count = std::vector<int>();
397 it = dict.find(
"left_child");
398 if (tree.num_leaves <= 1) {
399 tree.left_child = std::vector<int>();
401 TREELITE_CHECK_GT(tree.num_leaves, 1);
402 TREELITE_CHECK(it != dict.end() && !it->second.empty())
403 <<
"Ill-formed LightGBM model file: need left_child";
404 tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
407 it = dict.find(
"right_child");
408 if (tree.num_leaves <= 1) {
409 tree.right_child = std::vector<int>();
411 TREELITE_CHECK_GT(tree.num_leaves, 1);
412 TREELITE_CHECK(it != dict.end() && !it->second.empty())
413 <<
"Ill-formed LightGBM model file: need right_child";
414 tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
419 std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
422 model->average_tree_output = average_output_;
423 if (num_class_ > 1) {
425 model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
426 model->task_param.grove_per_class =
true;
429 model->task_type = treelite::TaskType::kBinaryClfRegr;
430 model->task_param.grove_per_class =
false;
432 model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
433 model->task_param.num_class = num_class_;
434 model->task_param.leaf_vector_size = 1;
437 if (obj_name_ ==
"multiclass") {
441 for (
const auto& str : obj_param_) {
442 auto tokens = Split(str,
':');
443 if (tokens.size() == 2 && tokens[0] ==
"num_class" 444 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
449 TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
450 <<
"Ill-formed LightGBM model file: not a valid multiclass objective";
452 std::strncpy(model->param.pred_transform,
"softmax",
sizeof(model->param.pred_transform));
453 }
else if (obj_name_ ==
"multiclassova") {
459 for (
const auto& str : obj_param_) {
460 auto tokens = Split(str,
':');
461 if (tokens.size() == 2) {
462 if (tokens[0] ==
"num_class" 463 && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
465 }
else if (tokens[0] ==
"sigmoid" 466 && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
471 TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
473 <<
"Ill-formed LightGBM model file: not a valid multiclassova objective";
475 std::strncpy(model->param.pred_transform,
"multiclass_ova",
476 sizeof(model->param.pred_transform));
477 model->param.sigmoid_alpha = alpha;
478 }
else if (obj_name_ ==
"binary") {
482 for (
const auto& str : obj_param_) {
483 auto tokens = Split(str,
':');
484 if (tokens.size() == 2 && tokens[0] ==
"sigmoid" 485 && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
490 TREELITE_CHECK_GT(alpha, 0.0f)
491 <<
"Ill-formed LightGBM model file: not a valid binary objective";
493 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
494 model->param.sigmoid_alpha = alpha;
495 }
else if (obj_name_ ==
"xentropy" || obj_name_ ==
"cross_entropy") {
496 std::strncpy(model->param.pred_transform,
"sigmoid",
sizeof(model->param.pred_transform));
497 model->param.sigmoid_alpha = 1.0f;
498 }
else if (obj_name_ ==
"xentlambda" || obj_name_ ==
"cross_entropy_lambda") {
499 std::strncpy(model->param.pred_transform,
"logarithm_one_plus_exp",
500 sizeof(model->param.pred_transform));
501 }
else if (obj_name_ ==
"poisson" || obj_name_ ==
"gamma" || obj_name_ ==
"tweedie") {
502 std::strncpy(model->param.pred_transform,
"exponential",
503 sizeof(model->param.pred_transform));
504 }
else if (obj_name_ ==
"regression" || obj_name_ ==
"regression_l1" || obj_name_ ==
"huber" 505 || obj_name_ ==
"fair" || obj_name_ ==
"quantile" || obj_name_ ==
"mape") {
507 bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(),
"sqrt") != obj_param_.cend());
509 std::strncpy(model->param.pred_transform,
"signed_square",
510 sizeof(model->param.pred_transform));
512 std::strncpy(model->param.pred_transform,
"identity",
513 sizeof(model->param.pred_transform));
515 }
else if (obj_name_ ==
"lambdarank" || obj_name_ ==
"rank_xendcg" || obj_name_ ==
"custom") {
517 std::strncpy(model->param.pred_transform,
"identity",
518 sizeof(model->param.pred_transform));
520 TREELITE_LOG(FATAL) <<
"Unrecognized objective: " << obj_name_;
524 for (
const auto& lgb_tree : lgb_trees_) {
525 model->trees.emplace_back();
531 std::queue<std::pair<int, int>> Q;
532 if (lgb_tree.num_leaves == 0) {
534 }
else if (lgb_tree.num_leaves == 1) {
542 std::tie(old_id, new_id) = Q.front(); Q.pop();
544 const double leaf_value = lgb_tree.leaf_value[~old_id];
545 tree.
SetLeaf(new_id, static_cast<double>(leaf_value));
546 if (!lgb_tree.leaf_count.empty()) {
547 const int data_count = lgb_tree.leaf_count[~old_id];
548 TREELITE_CHECK_GE(data_count, 0);
549 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
552 const auto split_index =
static_cast<unsigned>(lgb_tree.split_feature[old_id]);
553 const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
556 if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
558 const int cat_idx =
static_cast<int>(lgb_tree.threshold[old_id]);
559 const std::vector<uint32_t> left_categories
560 = BitsetToList(lgb_tree.cat_threshold.data()
561 + lgb_tree.cat_boundaries[cat_idx],
562 lgb_tree.cat_boundaries[cat_idx + 1]
563 - lgb_tree.cat_boundaries[cat_idx]);
566 bool default_left =
false;
570 const auto threshold =
static_cast<double>(lgb_tree.threshold[old_id]);
572 = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
574 const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
575 if (missing_value_to_zero) {
578 default_left = 0.0 <= threshold;
582 if (!lgb_tree.internal_count.empty()) {
583 const int data_count = lgb_tree.internal_count[old_id];
584 TREELITE_CHECK_GE(data_count, 0);
585 tree.
SetDataCount(new_id, static_cast<size_t>(data_count));
587 if (!lgb_tree.split_gain.empty()) {
588 tree.
SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
590 Q.push({lgb_tree.left_child[old_id], tree.
LeftChild(new_id)});
591 Q.push({lgb_tree.right_child[old_id], tree.
RightChild(new_id)});
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
in-memory representation of a decision tree
logging facility for Treelite
std::unique_ptr< treelite::Model > LoadLightGBMModelFromString(const char *model_str)
Load a LightGBM model from a string. The string should be created with the model_to_string() method i...
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
int LeftChild(int nid) const
Getters: index of the node's left child
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters: create a numerical split
Operator
comparison operators
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node