treelite
lightgbm.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <unordered_map>
12 #include <queue>
13 
14 namespace {
15 
16 treelite::Model ParseStream(dmlc::Stream* fi);
17 
18 } // anonymous namespace
19 
20 namespace treelite {
21 namespace frontend {
22 
23 DMLC_REGISTRY_FILE_TAG(lightgbm);
24 
25 Model LoadLightGBMModel(const char* filename) {
26  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
27  return ParseStream(fi.get());
28 }
29 
30 } // namespace frontend
31 } // namespace treelite
32 
33 /* auxiliary data structures to interpret lightgbm model file */
34 namespace {
35 
/*!
 * \brief Bit flags tested against a tree node's decision_type byte
 *        (see GetDecisionType()).
 */
enum Masks : uint8_t {
  kCategoricalMask = 0x1,  // bit 0: the split is categorical
  kDefaultLeftMask = 0x2   // bit 1: used as the default-left flag of
                           //        numerical splits
};
40 
/*!
 * \brief Missing-value mode of a split, decoded from bits 2-3 of the
 *        decision_type byte by GetMissingType().
 *
 * NOTE(review): the enumerator names suggest "no missing values",
 * "zero is missing" and "NaN is missing" -- confirm against LightGBM.
 */
enum class MissingType : uint8_t {
  kNone = 0,
  kZero = 1,
  kNaN = 2
};
46 
/*!
 * \brief Raw contents of one "Tree" section of a LightGBM model file.
 *
 * Each member holds the parsed value of the model-file key of the same name.
 * Arrays describing internal nodes are read with (num_leaves - 1) entries;
 * arrays describing leaves are read with num_leaves entries.
 */
struct LGBTree {
  int num_leaves = 0;  // zero-initialized so a fresh object is never
                       // indeterminate
  int num_cat = 0;     // number of categorical splits
  std::vector<double> leaf_value;        // output value of each leaf
  std::vector<int8_t> decision_type;     // per-node flag byte; see Masks and
                                         // GetMissingType()
  std::vector<uint64_t> cat_boundaries;  // num_cat + 1 offsets into
                                         // cat_threshold
  std::vector<uint32_t> cat_threshold;   // 32-bit bitset words for the
                                         // categorical splits
  std::vector<int> split_feature;        // feature index tested by each node
  std::vector<double> threshold;         // numerical threshold, or an index
                                         // into cat_boundaries for a
                                         // categorical split
  std::vector<int> left_child;           // child IDs; a negative ID k denotes
  std::vector<int> right_child;          // leaf number ~k
  std::vector<float> split_gain;         // gain recorded for each node
  std::vector<int> internal_count;       // data count of each internal node
  std::vector<int> leaf_count;           // data count of each leaf
};
62 
/*!
 * \brief Test a flag in a decision_type byte.
 * \param decision_type flag byte of a tree node
 * \param mask bit(s) to test, e.g. kCategoricalMask or kDefaultLeftMask
 * \return true when the bitwise AND of the two arguments is positive
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int selected = decision_type & mask;
  return selected > 0;
}
66 
67 inline MissingType GetMissingType(int8_t decision_type) {
68  return static_cast<MissingType>((decision_type >> 2) & 3);
69 }
70 
/*!
 * \brief Expand a bitset into the list of positions whose bit is set.
 *
 * Bit b of word w corresponds to position 32 * w + b.
 *
 * \param bits pointer to the first 32-bit word of the bitset
 * \param nslots number of 32-bit words to scan
 * \return positions of all set bits, in increasing order
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> positions;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t offset = 0; offset < 32; ++offset) {
      const bool is_set = ((word >> offset) & 1) != 0;
      if (is_set) {
        positions.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return positions;
}
84 
85 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
86  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
87  std::vector<char> buf(bufsize);
88 
89  std::vector<std::string> lines;
90 
91  size_t byte_read;
92 
93  std::string leftover = ""; // carry over between buffers
94  while ((byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
95  size_t i = 0;
96  size_t tok_begin = 0;
97  while (i < byte_read) {
98  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
99  if (tok_begin == 0 && leftover.length() + i > 0) {
100  // first line in buffer
101  lines.push_back(leftover + std::string(&buf[0], i));
102  leftover = "";
103  } else {
104  lines.emplace_back(&buf[tok_begin], i - tok_begin);
105  }
106  // skip all delimiters afterwards
107  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i) {}
108  tok_begin = i;
109  } else {
110  ++i;
111  }
112  }
113  // left-over string
114  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
115  }
116 
117  if (!leftover.empty()) {
118  LOG(INFO)
119  << "Warning: input file was not terminated with end-of-line character.";
120  lines.push_back(leftover);
121  }
122 
123  return lines;
124 }
125 
126 inline treelite::Model ParseStream(dmlc::Stream* fi) {
127  std::vector<LGBTree> lgb_trees_;
128  int max_feature_idx_;
129  int num_tree_per_iteration_;
130  bool average_output_;
131  std::string obj_name_;
132  std::vector<std::string> obj_param_;
133 
134  /* 1. Parse input stream */
135  std::vector<std::string> lines = LoadText(fi);
136  std::unordered_map<std::string, std::string> global_dict;
137  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
138 
139  bool in_tree = false; // is current entry part of a tree?
140  for (const auto& line : lines) {
141  std::istringstream ss(line);
142  std::string key, value, rest;
143  std::getline(ss, key, '=');
144  std::getline(ss, value, '=');
145  std::getline(ss, rest);
146  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
147  if (key == "Tree") {
148  in_tree = true;
149  tree_dict.emplace_back();
150  } else {
151  if (in_tree) {
152  tree_dict.back()[key] = value;
153  } else {
154  global_dict[key] = value;
155  }
156  }
157  }
158 
159  {
160  auto it = global_dict.find("objective");
161  CHECK(it != global_dict.end())
162  << "Ill-formed LightGBM model file: need objective";
163  auto obj_strs = treelite::common::Split(it->second, ' ');
164  obj_name_ = obj_strs[0];
165  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
166 
167  it = global_dict.find("max_feature_idx");
168  CHECK(it != global_dict.end())
169  << "Ill-formed LightGBM model file: need max_feature_idx";
170  max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
171  it = global_dict.find("num_tree_per_iteration");
172  CHECK(it != global_dict.end())
173  << "Ill-formed LightGBM model file: need num_tree_per_iteration";
174  num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
175 
176  it = global_dict.find("average_output");
177  average_output_ = (it != global_dict.end());
178  }
179 
180  for (const auto& dict : tree_dict) {
181  lgb_trees_.emplace_back();
182  LGBTree& tree = lgb_trees_.back();
183 
184  auto it = dict.find("num_leaves");
185  CHECK(it != dict.end())
186  << "Ill-formed LightGBM model file: need num_leaves";
187  tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
188 
189  it = dict.find("num_cat");
190  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
191  tree.num_cat = treelite::common::TextToNumber<int>(it->second);
192 
193  it = dict.find("leaf_value");
194  CHECK(it != dict.end() && !it->second.empty())
195  << "Ill-formed LightGBM model file: need leaf_value";
196  tree.leaf_value
197  = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
198 
199  it = dict.find("decision_type");
200  if (it == dict.end()) {
201  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
202  } else {
203  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
204  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
205  tree.decision_type
206  = treelite::common::TextToArray<int8_t>(it->second,
207  tree.num_leaves - 1);
208  }
209 
210  if (tree.num_cat > 0) {
211  it = dict.find("cat_boundaries");
212  CHECK(it != dict.end() && !it->second.empty())
213  << "Ill-formed LightGBM model file: need cat_boundaries";
214  tree.cat_boundaries
215  = treelite::common::TextToArray<uint64_t>(it->second, tree.num_cat + 1);
216  it = dict.find("cat_threshold");
217  CHECK(it != dict.end() && !it->second.empty())
218  << "Ill-formed LightGBM model file: need cat_threshold";
219  tree.cat_threshold
220  = treelite::common::TextToArray<uint32_t>(it->second,
221  tree.cat_boundaries.back());
222  }
223 
224  it = dict.find("split_feature");
225  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
226  << "Ill-formed LightGBM model file: need split_feature";
227  tree.split_feature
228  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
229 
230  it = dict.find("threshold");
231  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
232  << "Ill-formed LightGBM model file: need threshold";
233  tree.threshold
234  = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
235 
236  it = dict.find("split_gain");
237  if (it != dict.end()) {
238  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
239  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
240  tree.split_gain
241  = treelite::common::TextToArray<float>(it->second, tree.num_leaves - 1);
242  } else {
243  tree.split_gain.resize(tree.num_leaves - 1);
244  }
245 
246  it = dict.find("internal_count");
247  if (it != dict.end()) {
248  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
249  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
250  tree.internal_count
251  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
252  } else {
253  tree.internal_count.resize(tree.num_leaves - 1);
254  }
255 
256  it = dict.find("leaf_count");
257  if (it != dict.end()) {
258  CHECK(!it->second.empty())
259  << "Ill-formed LightGBM model file: leaf_count cannot be empty string";
260  tree.leaf_count
261  = treelite::common::TextToArray<int>(it->second, tree.num_leaves);
262  } else {
263  tree.leaf_count.resize(tree.num_leaves);
264  }
265 
266  it = dict.find("left_child");
267  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
268  << "Ill-formed LightGBM model file: need left_child";
269  tree.left_child
270  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
271 
272  it = dict.find("right_child");
273  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
274  << "Ill-formed LightGBM model file: need right_child";
275  tree.right_child
276  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
277  }
278 
279  /* 2. Export model */
280  treelite::Model model;
281  model.num_feature = max_feature_idx_ + 1;
282  model.num_output_group = num_tree_per_iteration_;
283  if (model.num_output_group > 1) {
284  // multiclass classification with gradient boosted trees
285  CHECK(!average_output_)
286  << "Ill-formed LightGBM model file: cannot use random forest mode "
287  << "for multi-class classification";
288  model.random_forest_flag = false;
289  } else {
290  model.random_forest_flag = average_output_;
291  }
292 
293  // set correct prediction transform function, depending on objective function
294  if (obj_name_ == "multiclass") {
295  // validate num_class parameter
296  int num_class = -1;
297  int tmp;
298  for (const auto& str : obj_param_) {
299  auto tokens = treelite::common::Split(str, ':');
300  if (tokens.size() == 2 && tokens[0] == "num_class"
301  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
302  num_class = tmp;
303  break;
304  }
305  }
306  CHECK(num_class >= 0 && num_class == model.num_output_group)
307  << "Ill-formed LightGBM model file: not a valid multiclass objective";
308 
309  model.param.pred_transform = "softmax";
310  } else if (obj_name_ == "multiclassova") {
311  // validate num_class and alpha parameters
312  int num_class = -1;
313  float alpha = -1.0f;
314  int tmp;
315  float tmp2;
316  for (const auto& str : obj_param_) {
317  auto tokens = treelite::common::Split(str, ':');
318  if (tokens.size() == 2) {
319  if (tokens[0] == "num_class"
320  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
321  num_class = tmp;
322  } else if (tokens[0] == "sigmoid"
323  && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
324  alpha = tmp2;
325  }
326  }
327  }
328  CHECK(num_class >= 0 && num_class == model.num_output_group
329  && alpha > 0.0f)
330  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
331 
332  model.param.pred_transform = "multiclass_ova";
333  model.param.sigmoid_alpha = alpha;
334  } else if (obj_name_ == "binary") {
335  // validate alpha parameter
336  float alpha = -1.0f;
337  float tmp;
338  for (const auto& str : obj_param_) {
339  auto tokens = treelite::common::Split(str, ':');
340  if (tokens.size() == 2 && tokens[0] == "sigmoid"
341  && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
342  alpha = tmp;
343  break;
344  }
345  }
346  CHECK_GT(alpha, 0.0f)
347  << "Ill-formed LightGBM model file: not a valid binary objective";
348 
349  model.param.pred_transform = "sigmoid";
350  model.param.sigmoid_alpha = alpha;
351  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
352  model.param.pred_transform = "sigmoid";
353  model.param.sigmoid_alpha = 1.0f;
354  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
355  model.param.pred_transform = "logarithm_one_plus_exp";
356  } else {
357  model.param.pred_transform = "identity";
358  }
359 
360  // traverse trees
361  for (const auto& lgb_tree : lgb_trees_) {
362  model.trees.emplace_back();
363  treelite::Tree& tree = model.trees.back();
364  tree.Init();
365 
366  // assign node ID's so that a breadth-wise traversal would yield
367  // the monotonic sequence 0, 1, 2, ...
368  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
369  Q.push({0, 0});
370  while (!Q.empty()) {
371  int old_id, new_id;
372  std::tie(old_id, new_id) = Q.front(); Q.pop();
373  if (old_id < 0) { // leaf
374  const double leaf_value = lgb_tree.leaf_value[~old_id];
375  const int data_count = lgb_tree.leaf_count[~old_id];
376  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
377  CHECK_GE(data_count, 0);
378  tree[new_id].set_data_count(static_cast<size_t>(data_count));
379  } else { // non-leaf
380  const int data_count = lgb_tree.internal_count[old_id];
381  const unsigned split_index =
382  static_cast<unsigned>(lgb_tree.split_feature[old_id]);
383 
384  tree.AddChilds(new_id);
385  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
386  // categorical
387  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
388  const std::vector<uint32_t> left_categories
389  = BitsetToList(lgb_tree.cat_threshold.data()
390  + lgb_tree.cat_boundaries[cat_idx],
391  lgb_tree.cat_boundaries[cat_idx + 1]
392  - lgb_tree.cat_boundaries[cat_idx]);
393  const auto missing_type
394  = GetMissingType(lgb_tree.decision_type[old_id]);
395  tree[new_id].set_categorical_split(split_index, false,
396  (missing_type != MissingType::kNaN),
397  left_categories);
398  } else {
399  // numerical
400  const treelite::tl_float threshold =
401  static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
402  const bool default_left
403  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
404  const treelite::Operator cmp_op = treelite::Operator::kLE;
405  tree[new_id].set_numerical_split(split_index, threshold,
406  default_left, cmp_op);
407  }
408  CHECK_GE(data_count, 0);
409  tree[new_id].set_data_count(static_cast<size_t>(data_count));
410  tree[new_id].set_gain(static_cast<double>(lgb_tree.split_gain[old_id]));
411  Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
412  Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
413  }
414  }
415  }
416  LOG(INFO) << "model.num_tree = " << model.trees.size();
417  return model;
418 }
419 
420 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:311
thin wrapper for tree ensemble model
Definition: tree.h:427
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:22
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:25
double tl_float
float type to be used internally
Definition: base.h:17
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:321
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:434
Operator
comparison operators
Definition: base.h:23