Treelite
lightgbm.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <unordered_map>
12 #include <limits>
13 #include <queue>
14 
15 namespace {
16 
17 treelite::Model ParseStream(dmlc::Stream* fi);
18 
19 } // anonymous namespace
20 
21 namespace treelite {
22 namespace frontend {
23 
24 DMLC_REGISTRY_FILE_TAG(lightgbm);
25 
26 void LoadLightGBMModel(const char *filename, Model* out) {
27  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
28  *out = std::move(ParseStream(fi.get()));
29 }
30 
31 } // namespace frontend
32 } // namespace treelite
33 
34 /* auxiliary data structures to interpret lightgbm model file */
35 namespace {
36 
/*!
 * \brief Primary template of the string-to-number converter.
 *
 * Every supported type (float, double, int, int8_t, uint32_t, uint64_t) has
 * its own explicit specialization; instantiating this primary template with
 * any other type is rejected at compile time.
 * \param str text to convert
 * \tparam T numeric type to produce
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  constexpr bool kIsSupportedType =
      std::is_same<T, float>::value || std::is_same<T, double>::value ||
      std::is_same<T, int>::value || std::is_same<T, int8_t>::value ||
      std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value;
  static_assert(kIsSupportedType,
                "unsupported data type for TextToNumber; "
                "use float, double, int, int8_t, uint32_t, or uint64_t");
}
48 
49 template <>
50 inline float TextToNumber(const std::string& str) {
51  errno = 0;
52  char *endptr;
53  float val = std::strtof(str.c_str(), &endptr);
54  if (errno == ERANGE) {
55  LOG(FATAL) << "Range error while converting string to double";
56  } else if (errno != 0) {
57  LOG(FATAL) << "Unknown error";
58  } else if (*endptr != '\0') {
59  LOG(FATAL) << "String does not represent a valid floating-point number";
60  }
61  return val;
62 }
63 
64 template <>
65 inline double TextToNumber(const std::string& str) {
66  errno = 0;
67  char *endptr;
68  double val = std::strtod(str.c_str(), &endptr);
69  if (errno == ERANGE) {
70  LOG(FATAL) << "Range error while converting string to double";
71  } else if (errno != 0) {
72  LOG(FATAL) << "Unknown error";
73  } else if (*endptr != '\0') {
74  LOG(FATAL) << "String does not represent a valid floating-point number";
75  }
76  return val;
77 }
78 
79 template <>
80 inline int TextToNumber(const std::string& str) {
81  errno = 0;
82  char *endptr;
83  auto val = std::strtol(str.c_str(), &endptr, 10);
84  if (errno == ERANGE || val < std::numeric_limits<int>::min()
85  || val > std::numeric_limits<int>::max()) {
86  LOG(FATAL) << "Range error while converting string to int";
87  } else if (errno != 0) {
88  LOG(FATAL) << "Unknown error";
89  } else if (*endptr != '\0') {
90  LOG(FATAL) << "String does not represent a valid integer";
91  }
92  return static_cast<int>(val);
93 }
94 
95 template <>
96 inline int8_t TextToNumber(const std::string& str) {
97  errno = 0;
98  char *endptr;
99  auto val = std::strtol(str.c_str(), &endptr, 10);
100  if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
101  || val > std::numeric_limits<int8_t>::max()) {
102  LOG(FATAL) << "Range error while converting string to int8_t";
103  } else if (errno != 0) {
104  LOG(FATAL) << "Unknown error";
105  } else if (*endptr != '\0') {
106  LOG(FATAL) << "String does not represent a valid integer";
107  }
108  return static_cast<int8_t>(val);
109 }
110 
111 template <>
112 inline uint32_t TextToNumber(const std::string& str) {
113  errno = 0;
114  char *endptr;
115  auto val = std::strtoul(str.c_str(), &endptr, 10);
116  if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
117  LOG(FATAL) << "Range error while converting string to uint32_t";
118  } else if (errno != 0) {
119  LOG(FATAL) << "Unknown error";
120  } else if (*endptr != '\0') {
121  LOG(FATAL) << "String does not represent a valid integer";
122  }
123  return static_cast<uint32_t>(val);
124 }
125 
126 template <>
127 inline uint64_t TextToNumber(const std::string& str) {
128  errno = 0;
129  char *endptr;
130  auto val = std::strtoull(str.c_str(), &endptr, 10);
131  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
132  LOG(FATAL) << "Range error while converting string to uint64_t";
133  } else if (errno != 0) {
134  LOG(FATAL) << "Unknown error";
135  } else if (*endptr != '\0') {
136  LOG(FATAL) << "String does not represent a valid integer";
137  }
138  return static_cast<uint64_t>(val);
139 }
140 
/*!
 * \brief Break a string into the pieces that lie between delimiter characters.
 *
 * Adjacent delimiters yield empty pieces. A delimiter at the very end of the
 * text does not produce a trailing empty piece, and an empty text yields an
 * empty list.
 * \param text string to split
 * \param delim delimiter character
 * \return pieces in order of appearance
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type cursor = 0;
  while (cursor < text.size()) {
    const std::string::size_type hit = text.find(delim, cursor);
    if (hit == std::string::npos) {
      // last piece: everything up to the end of the text
      pieces.push_back(text.substr(cursor));
      break;
    }
    pieces.push_back(text.substr(cursor, hit - cursor));
    cursor = hit + 1;
  }
  return pieces;
}
150 
151 template <typename T>
152 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
153  if (text.empty() && num_entry > 0) {
154  LOG(FATAL) << "Cannot convert empty text into array";
155  }
156  std::vector<T> array;
157  std::istringstream ss(text);
158  std::string token;
159  for (int i = 0; i < num_entry; ++i) {
160  std::getline(ss, token, ' ');
161  array.push_back(TextToNumber<T>(token));
162  }
163  return array;
164 }
165 
/*! \brief bit flags tested against a node's decision_type value */
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // set: the split is categorical
  kDefaultLeftMask = 1 << 1   // set: the numerical split defaults to the left
};
170 
/*!
 * \brief missing-value mode of a split, as decoded from two bits of the
 *        node's decision_type by GetMissingType()
 */
enum class MissingType : uint8_t {
  kNone = 0,
  kZero = 1,
  kNaN = 2
};
176 
/*!
 * \brief raw content of one tree as read from a LightGBM model file
 *
 * Per-leaf arrays hold num_leaves entries and per-internal-node arrays hold
 * (num_leaves - 1) entries once the tree has been parsed.
 */
struct LGBTree {
  int num_leaves = 0;
  int num_cat = 0;  // number of categorical splits
  std::vector<double> leaf_value;        // output value of each leaf
  std::vector<int8_t> decision_type;     // per internal node; see Masks and GetMissingType()
  std::vector<uint64_t> cat_boundaries;  // (num_cat + 1) offsets into cat_threshold
  std::vector<uint32_t> cat_threshold;   // bitset words listing the categories of each categorical split
  std::vector<int> split_feature;        // feature index tested by each internal node
  std::vector<double> threshold;         // numerical threshold, or index into cat_boundaries for a categorical split
  std::vector<int> left_child;           // child ID; a negative value is the bitwise complement of a leaf index
  std::vector<int> right_child;          // same encoding as left_child
  std::vector<float> split_gain;         // gain of each internal node
  std::vector<int> internal_count;       // data count of each internal node
  std::vector<int> leaf_count;           // data count of each leaf
};
192 
/*!
 * \brief Test whether the bits selected by a mask are set in a decision type.
 * \param decision_type decision_type value of a node
 * \param mask bits to test
 * \return true if the bitwise AND of both arguments is greater than zero
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int selected_bits = decision_type & mask;
  return selected_bits > 0;
}
196 
197 inline MissingType GetMissingType(int8_t decision_type) {
198  return static_cast<MissingType>((decision_type >> 2) & 3);
199 }
200 
/*!
 * \brief List the positions of all set bits in a bitset made of 32-bit words.
 *
 * Bit k of word w corresponds to position (32 * w + k); positions are
 * reported in increasing order.
 * \param bits pointer to the first word of the bitset
 * \param nslots number of 32-bit words in the bitset
 * \return positions of the set bits
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> members;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t offset = 0; offset < 32; ++offset) {
      if ((word >> offset) & 1) {
        members.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return members;
}
214 
215 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
216  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
217  std::vector<char> buf(bufsize);
218 
219  std::vector<std::string> lines;
220 
221  size_t byte_read;
222 
223  std::string leftover = ""; // carry over between buffers
224  while ((byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
225  size_t i = 0;
226  size_t tok_begin = 0;
227  while (i < byte_read) {
228  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
229  if (tok_begin == 0 && leftover.length() + i > 0) {
230  // first line in buffer
231  lines.push_back(leftover + std::string(&buf[0], i));
232  leftover = "";
233  } else {
234  lines.emplace_back(&buf[tok_begin], i - tok_begin);
235  }
236  // skip all delimiters afterwards
237  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i) {}
238  tok_begin = i;
239  } else {
240  ++i;
241  }
242  }
243  // left-over string
244  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
245  }
246 
247  if (!leftover.empty()) {
248  LOG(INFO)
249  << "Warning: input file was not terminated with end-of-line character.";
250  lines.push_back(leftover);
251  }
252 
253  return lines;
254 }
255 
256 inline treelite::Model ParseStream(dmlc::Stream* fi) {
257  std::vector<LGBTree> lgb_trees_;
258  int max_feature_idx_;
259  int num_tree_per_iteration_;
260  bool average_output_;
261  std::string obj_name_;
262  std::vector<std::string> obj_param_;
263 
264  /* 1. Parse input stream */
265  std::vector<std::string> lines = LoadText(fi);
266  std::unordered_map<std::string, std::string> global_dict;
267  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
268 
269  bool in_tree = false; // is current entry part of a tree?
270  for (const auto& line : lines) {
271  std::istringstream ss(line);
272  std::string key, value, rest;
273  std::getline(ss, key, '=');
274  std::getline(ss, value, '=');
275  std::getline(ss, rest);
276  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
277  if (key == "Tree") {
278  in_tree = true;
279  tree_dict.emplace_back();
280  } else {
281  if (in_tree) {
282  tree_dict.back()[key] = value;
283  } else {
284  global_dict[key] = value;
285  }
286  }
287  }
288 
289  {
290  auto it = global_dict.find("objective");
291  CHECK(it != global_dict.end())
292  << "Ill-formed LightGBM model file: need objective";
293  auto obj_strs = Split(it->second, ' ');
294  obj_name_ = obj_strs[0];
295  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
296 
297  it = global_dict.find("max_feature_idx");
298  CHECK(it != global_dict.end())
299  << "Ill-formed LightGBM model file: need max_feature_idx";
300  max_feature_idx_ = TextToNumber<int>(it->second);
301  it = global_dict.find("num_tree_per_iteration");
302  CHECK(it != global_dict.end())
303  << "Ill-formed LightGBM model file: need num_tree_per_iteration";
304  num_tree_per_iteration_ = TextToNumber<int>(it->second);
305 
306  it = global_dict.find("average_output");
307  average_output_ = (it != global_dict.end());
308  }
309 
310  for (const auto& dict : tree_dict) {
311  lgb_trees_.emplace_back();
312  LGBTree& tree = lgb_trees_.back();
313 
314  auto it = dict.find("num_leaves");
315  CHECK(it != dict.end())
316  << "Ill-formed LightGBM model file: need num_leaves";
317  tree.num_leaves = TextToNumber<int>(it->second);
318 
319  it = dict.find("num_cat");
320  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
321  tree.num_cat = TextToNumber<int>(it->second);
322 
323  it = dict.find("leaf_value");
324  CHECK(it != dict.end() && !it->second.empty())
325  << "Ill-formed LightGBM model file: need leaf_value";
326  tree.leaf_value
327  = TextToArray<double>(it->second, tree.num_leaves);
328 
329  it = dict.find("decision_type");
330  if (it == dict.end()) {
331  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
332  } else {
333  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
334  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
335  tree.decision_type
336  = TextToArray<int8_t>(it->second,
337  tree.num_leaves - 1);
338  }
339 
340  if (tree.num_cat > 0) {
341  it = dict.find("cat_boundaries");
342  CHECK(it != dict.end() && !it->second.empty())
343  << "Ill-formed LightGBM model file: need cat_boundaries";
344  tree.cat_boundaries
345  = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
346  it = dict.find("cat_threshold");
347  CHECK(it != dict.end() && !it->second.empty())
348  << "Ill-formed LightGBM model file: need cat_threshold";
349  tree.cat_threshold
350  = TextToArray<uint32_t>(it->second,
351  tree.cat_boundaries.back());
352  }
353 
354  it = dict.find("split_feature");
355  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
356  << "Ill-formed LightGBM model file: need split_feature";
357  tree.split_feature
358  = TextToArray<int>(it->second, tree.num_leaves - 1);
359 
360  it = dict.find("threshold");
361  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
362  << "Ill-formed LightGBM model file: need threshold";
363  tree.threshold
364  = TextToArray<double>(it->second, tree.num_leaves - 1);
365 
366  it = dict.find("split_gain");
367  if (it != dict.end()) {
368  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
369  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
370  tree.split_gain
371  = TextToArray<float>(it->second, tree.num_leaves - 1);
372  } else {
373  tree.split_gain.resize(tree.num_leaves - 1);
374  }
375 
376  it = dict.find("internal_count");
377  if (it != dict.end()) {
378  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
379  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
380  tree.internal_count
381  = TextToArray<int>(it->second, tree.num_leaves - 1);
382  } else {
383  tree.internal_count.resize(tree.num_leaves - 1);
384  }
385 
386  it = dict.find("leaf_count");
387  if (it != dict.end()) {
388  CHECK(!it->second.empty())
389  << "Ill-formed LightGBM model file: leaf_count cannot be empty string";
390  tree.leaf_count
391  = TextToArray<int>(it->second, tree.num_leaves);
392  } else {
393  tree.leaf_count.resize(tree.num_leaves);
394  }
395 
396  it = dict.find("left_child");
397  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
398  << "Ill-formed LightGBM model file: need left_child";
399  tree.left_child
400  = TextToArray<int>(it->second, tree.num_leaves - 1);
401 
402  it = dict.find("right_child");
403  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
404  << "Ill-formed LightGBM model file: need right_child";
405  tree.right_child
406  = TextToArray<int>(it->second, tree.num_leaves - 1);
407  }
408 
409  /* 2. Export model */
410  treelite::Model model;
411  model.num_feature = max_feature_idx_ + 1;
412  model.num_output_group = num_tree_per_iteration_;
413  if (model.num_output_group > 1) {
414  // multiclass classification with gradient boosted trees
415  CHECK(!average_output_)
416  << "Ill-formed LightGBM model file: cannot use random forest mode "
417  << "for multi-class classification";
418  model.random_forest_flag = false;
419  } else {
420  model.random_forest_flag = average_output_;
421  }
422 
423  // set correct prediction transform function, depending on objective function
424  if (obj_name_ == "multiclass") {
425  // validate num_class parameter
426  int num_class = -1;
427  int tmp;
428  for (const auto& str : obj_param_) {
429  auto tokens = Split(str, ':');
430  if (tokens.size() == 2 && tokens[0] == "num_class"
431  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
432  num_class = tmp;
433  break;
434  }
435  }
436  CHECK(num_class >= 0 && num_class == model.num_output_group)
437  << "Ill-formed LightGBM model file: not a valid multiclass objective";
438 
439  std::strncpy(model.param.pred_transform, "softmax", sizeof(model.param.pred_transform));
440  } else if (obj_name_ == "multiclassova") {
441  // validate num_class and alpha parameters
442  int num_class = -1;
443  float alpha = -1.0f;
444  int tmp;
445  float tmp2;
446  for (const auto& str : obj_param_) {
447  auto tokens = Split(str, ':');
448  if (tokens.size() == 2) {
449  if (tokens[0] == "num_class"
450  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
451  num_class = tmp;
452  } else if (tokens[0] == "sigmoid"
453  && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
454  alpha = tmp2;
455  }
456  }
457  }
458  CHECK(num_class >= 0 && num_class == model.num_output_group
459  && alpha > 0.0f)
460  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
461 
462  std::strncpy(model.param.pred_transform, "multiclass_ova", sizeof(model.param.pred_transform));
463  model.param.sigmoid_alpha = alpha;
464  } else if (obj_name_ == "binary") {
465  // validate alpha parameter
466  float alpha = -1.0f;
467  float tmp;
468  for (const auto& str : obj_param_) {
469  auto tokens = Split(str, ':');
470  if (tokens.size() == 2 && tokens[0] == "sigmoid"
471  && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
472  alpha = tmp;
473  break;
474  }
475  }
476  CHECK_GT(alpha, 0.0f)
477  << "Ill-formed LightGBM model file: not a valid binary objective";
478 
479  std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform));
480  model.param.sigmoid_alpha = alpha;
481  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
482  std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform));
483  model.param.sigmoid_alpha = 1.0f;
484  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
485  std::strncpy(model.param.pred_transform, "logarithm_one_plus_exp",
486  sizeof(model.param.pred_transform));
487  } else {
488  std::strncpy(model.param.pred_transform, "identity", sizeof(model.param.pred_transform));
489  }
490 
491  // traverse trees
492  for (const auto& lgb_tree : lgb_trees_) {
493  model.trees.emplace_back();
494  treelite::Tree& tree = model.trees.back();
495  tree.Init();
496 
497  // assign node ID's so that a breadth-wise traversal would yield
498  // the monotonic sequence 0, 1, 2, ...
499  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
500  Q.push({0, 0});
501  while (!Q.empty()) {
502  int old_id, new_id;
503  std::tie(old_id, new_id) = Q.front(); Q.pop();
504  if (old_id < 0) { // leaf
505  const double leaf_value = lgb_tree.leaf_value[~old_id];
506  const int data_count = lgb_tree.leaf_count[~old_id];
507  tree.SetLeaf(new_id, static_cast<treelite::tl_float>(leaf_value));
508  CHECK_GE(data_count, 0);
509  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
510  } else { // non-leaf
511  const int data_count = lgb_tree.internal_count[old_id];
512  const auto split_index =
513  static_cast<unsigned>(lgb_tree.split_feature[old_id]);
514 
515  tree.AddChilds(new_id);
516  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
517  // categorical
518  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
519  const std::vector<uint32_t> left_categories
520  = BitsetToList(lgb_tree.cat_threshold.data()
521  + lgb_tree.cat_boundaries[cat_idx],
522  lgb_tree.cat_boundaries[cat_idx + 1]
523  - lgb_tree.cat_boundaries[cat_idx]);
524  const auto missing_type
525  = GetMissingType(lgb_tree.decision_type[old_id]);
526  tree.SetCategoricalSplit(new_id, split_index, false, (missing_type != MissingType::kNaN),
527  left_categories);
528  } else {
529  // numerical
530  const auto threshold = static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
531  const bool default_left
532  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
533  const treelite::Operator cmp_op = treelite::Operator::kLE;
534  tree.SetNumericalSplit(new_id, split_index, threshold, default_left, cmp_op);
535  }
536  CHECK_GE(data_count, 0);
537  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
538  tree.SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
539  Q.push({lgb_tree.left_child[old_id], tree.LeftChild(new_id)});
540  Q.push({lgb_tree.right_child[old_id], tree.RightChild(new_id)});
541  }
542  }
543  }
544  LOG(INFO) << "model.num_tree = " << model.trees.size();
545  return model;
546 }
547 
548 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:476
thin wrapper for tree ensemble model
Definition: tree.h:409
float tl_float
float type to be used internally
Definition: base.h:18
model structure for tree ensemble
in-memory representation of a decision tree
Definition: tree.h:80
void LoadLightGBMModel(const char *filename, Model *out)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:26
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
Definition: tree_impl.h:649
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree_impl.h:730
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree_impl.h:723
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
Definition: tree_impl.h:682
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:635
int LeftChild(int nid) const
Getters.
Definition: tree_impl.h:524
int RightChild(int nid) const
index of the node's right child
Definition: tree_impl.h:529
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:488
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:416
Operator
comparison operators
Definition: base.h:24