treelite
lightgbm.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <unordered_map>
12 #include <queue>
13 
14 namespace {
15 
16 treelite::Model ParseStream(dmlc::Stream* fi);
17 
18 } // anonymous namespace
19 
20 namespace treelite {
21 namespace frontend {
22 
23 DMLC_REGISTRY_FILE_TAG(lightgbm);
24 
25 Model LoadLightGBMModel(const char* filename) {
26  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
27  return ParseStream(fi.get());
28 }
29 
30 } // namespace frontend
31 } // namespace treelite
32 
33 /* auxiliary data structures to interpret lightgbm model file */
34 namespace {
35 
// Bit flags tested against the entries of LGBTree::decision_type.
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // set: the split is categorical
  kDefaultLeftMask = 1 << 1   // set: the split's default direction is left
};
40 
// One tree as read from the LightGBM model text. The vectors are filled by
// ParseStream from the per-tree "key=value" entries of the same names.
struct LGBTree {
  // Scalars are zero-initialized so that a default-constructed tree never
  // carries indeterminate values.
  int num_leaves = 0;
  int num_cat = 0;  // number of categorical splits
  std::vector<double> leaf_value;      // one entry per leaf
  std::vector<int8_t> decision_type;   // per internal node; see Masks
  std::vector<int> cat_boundaries;     // word ranges into cat_threshold
  std::vector<uint32_t> cat_threshold; // packed category bitsets
  std::vector<int> split_feature;      // per internal node
  std::vector<double> threshold;       // per internal node; for categorical
                                       // splits, an index into cat_boundaries
  std::vector<int> left_child;         // negative value v denotes leaf ~v
  std::vector<int> right_child;        // negative value v denotes leaf ~v
  std::vector<float> split_gain;       // per internal node
  std::vector<int> internal_count;     // per internal node
  std::vector<int> leaf_count;         // per leaf
};
56 
/*!
 * \brief Test whether any bit selected by mask is set in decision_type.
 * \param decision_type packed flags of one split
 * \param mask bit(s) to test, e.g. kCategoricalMask or kDefaultLeftMask
 * \return true if (decision_type & mask) has at least one bit set
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  // Compare against zero rather than "> 0": both operands are signed, so a
  // mask that uses the top bit would yield a negative intermediate value.
  return (decision_type & mask) != 0;
}
60 
/*!
 * \brief Expand a packed bitset into the ascending list of set bit positions.
 * \param bits pointer to nslots 32-bit words; bit j of word i stands for
 *             position i * 32 + j
 * \param nslots number of 32-bit words to scan. Taken as size_t so that
 *               bitsets longer than 255 words are not silently truncated.
 * \return positions of all set bits, in increasing order
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> result;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    if (word == 0) {
      continue;  // nothing set in this word
    }
    for (uint32_t offset = 0; offset < 32; ++offset) {
      if ((word >> offset) & 1u) {
        result.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return result;
}
74 
75 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
76  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
77  std::vector<char> buf(bufsize);
78 
79  std::vector<std::string> lines;
80 
81  size_t byte_read;
82 
83  std::string leftover = ""; // carry over between buffers
84  while ((byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
85  size_t i = 0;
86  size_t tok_begin = 0;
87  while (i < byte_read) {
88  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
89  if (tok_begin == 0 && leftover.length() + i > 0) {
90  // first line in buffer
91  lines.push_back(leftover + std::string(&buf[0], i));
92  leftover = "";
93  } else {
94  lines.emplace_back(&buf[tok_begin], i - tok_begin);
95  }
96  // skip all delimiters afterwards
97  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i) {}
98  tok_begin = i;
99  } else {
100  ++i;
101  }
102  }
103  // left-over string
104  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
105  }
106 
107  if (!leftover.empty()) {
108  LOG(INFO)
109  << "Warning: input file was not terminated with end-of-line character.";
110  lines.push_back(leftover);
111  }
112 
113  return lines;
114 }
115 
116 inline treelite::Model ParseStream(dmlc::Stream* fi) {
117  std::vector<LGBTree> lgb_trees_;
118  int max_feature_idx_;
119  int num_tree_per_iteration_;
120  bool average_output_;
121  std::string obj_name_;
122  std::vector<std::string> obj_param_;
123 
124  /* 1. Parse input stream */
125  std::vector<std::string> lines = LoadText(fi);
126  std::unordered_map<std::string, std::string> global_dict;
127  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
128 
129  bool in_tree = false; // is current entry part of a tree?
130  for (const auto& line : lines) {
131  std::istringstream ss(line);
132  std::string key, value, rest;
133  std::getline(ss, key, '=');
134  std::getline(ss, value, '=');
135  std::getline(ss, rest);
136  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
137  if (key == "Tree") {
138  in_tree = true;
139  tree_dict.emplace_back();
140  } else {
141  if (in_tree) {
142  tree_dict.back()[key] = value;
143  } else {
144  global_dict[key] = value;
145  }
146  }
147  }
148 
149  {
150  auto it = global_dict.find("objective");
151  CHECK(it != global_dict.end())
152  << "Ill-formed LightGBM model file: need objective";
153  auto obj_strs = treelite::common::Split(it->second, ' ');
154  obj_name_ = obj_strs[0];
155  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
156 
157  it = global_dict.find("max_feature_idx");
158  CHECK(it != global_dict.end())
159  << "Ill-formed LightGBM model file: need max_feature_idx";
160  max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
161  it = global_dict.find("num_tree_per_iteration");
162  CHECK(it != global_dict.end())
163  << "Ill-formed LightGBM model file: need num_tree_per_iteration";
164  num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
165 
166  it = global_dict.find("average_output");
167  average_output_ = (it != global_dict.end());
168  }
169 
170  for (const auto& dict : tree_dict) {
171  lgb_trees_.emplace_back();
172  LGBTree& tree = lgb_trees_.back();
173 
174  auto it = dict.find("num_leaves");
175  CHECK(it != dict.end())
176  << "Ill-formed LightGBM model file: need num_leaves";
177  tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
178 
179  it = dict.find("num_cat");
180  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
181  tree.num_cat = treelite::common::TextToNumber<int>(it->second);
182 
183  it = dict.find("leaf_value");
184  CHECK(it != dict.end() && !it->second.empty())
185  << "Ill-formed LightGBM model file: need leaf_value";
186  tree.leaf_value
187  = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
188 
189  it = dict.find("decision_type");
190  if (it == dict.end()) {
191  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
192  } else {
193  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
194  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
195  tree.decision_type
196  = treelite::common::TextToArray<int8_t>(it->second,
197  tree.num_leaves - 1);
198  }
199 
200  if (tree.num_cat > 0) {
201  it = dict.find("cat_boundaries");
202  CHECK(it != dict.end() && !it->second.empty())
203  << "Ill-formed LightGBM model file: need cat_boundaries";
204  tree.cat_boundaries
205  = treelite::common::TextToArray<int>(it->second, tree.num_cat + 1);
206  it = dict.find("cat_threshold");
207  CHECK(it != dict.end() && !it->second.empty())
208  << "Ill-formed LightGBM model file: need cat_threshold";
209  tree.cat_threshold
210  = treelite::common::TextToArray<uint32_t>(it->second,
211  tree.cat_boundaries.back());
212  }
213 
214  it = dict.find("split_feature");
215  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
216  << "Ill-formed LightGBM model file: need split_feature";
217  tree.split_feature
218  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
219 
220  it = dict.find("threshold");
221  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
222  << "Ill-formed LightGBM model file: need threshold";
223  tree.threshold
224  = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
225 
226  it = dict.find("split_gain");
227  if (it != dict.end()) {
228  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
229  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
230  tree.split_gain
231  = treelite::common::TextToArray<float>(it->second, tree.num_leaves - 1);
232  } else {
233  tree.split_gain.resize(tree.num_leaves - 1);
234  }
235 
236  it = dict.find("internal_count");
237  if (it != dict.end()) {
238  CHECK(tree.num_leaves - 1 == 0 || !it->second.empty())
239  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
240  tree.internal_count
241  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
242  } else {
243  tree.internal_count.resize(tree.num_leaves - 1);
244  }
245 
246  it = dict.find("leaf_count");
247  if (it != dict.end()) {
248  CHECK(!it->second.empty())
249  << "Ill-formed LightGBM model file: leaf_count cannot be empty string";
250  tree.leaf_count
251  = treelite::common::TextToArray<int>(it->second, tree.num_leaves);
252  } else {
253  tree.leaf_count.resize(tree.num_leaves);
254  }
255 
256  it = dict.find("left_child");
257  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
258  << "Ill-formed LightGBM model file: need left_child";
259  tree.left_child
260  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
261 
262  it = dict.find("right_child");
263  CHECK(it != dict.end() && (tree.num_leaves - 1 == 0 || !it->second.empty()))
264  << "Ill-formed LightGBM model file: need right_child";
265  tree.right_child
266  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
267  }
268 
269  /* 2. Export model */
270  treelite::Model model;
271  model.num_feature = max_feature_idx_ + 1;
272  model.num_output_group = num_tree_per_iteration_;
273  if (model.num_output_group > 1) {
274  // multiclass classification with gradient boosted trees
275  CHECK(!average_output_)
276  << "Ill-formed LightGBM model file: cannot use random forest mode "
277  << "for multi-class classification";
278  model.random_forest_flag = false;
279  } else {
280  model.random_forest_flag = average_output_;
281  }
282 
283  // set correct prediction transform function, depending on objective function
284  if (obj_name_ == "multiclass") {
285  // validate num_class parameter
286  int num_class = -1;
287  int tmp;
288  for (const auto& str : obj_param_) {
289  auto tokens = treelite::common::Split(str, ':');
290  if (tokens.size() == 2 && tokens[0] == "num_class"
291  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
292  num_class = tmp;
293  break;
294  }
295  }
296  CHECK(num_class >= 0 && num_class == model.num_output_group)
297  << "Ill-formed LightGBM model file: not a valid multiclass objective";
298 
299  model.param.pred_transform = "softmax";
300  } else if (obj_name_ == "multiclassova") {
301  // validate num_class and alpha parameters
302  int num_class = -1;
303  float alpha = -1.0f;
304  int tmp;
305  float tmp2;
306  for (const auto& str : obj_param_) {
307  auto tokens = treelite::common::Split(str, ':');
308  if (tokens.size() == 2) {
309  if (tokens[0] == "num_class"
310  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
311  num_class = tmp;
312  } else if (tokens[0] == "sigmoid"
313  && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
314  alpha = tmp2;
315  }
316  }
317  }
318  CHECK(num_class >= 0 && num_class == model.num_output_group
319  && alpha > 0.0f)
320  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
321 
322  model.param.pred_transform = "multiclass_ova";
323  model.param.sigmoid_alpha = alpha;
324  } else if (obj_name_ == "binary") {
325  // validate alpha parameter
326  float alpha = -1.0f;
327  float tmp;
328  for (const auto& str : obj_param_) {
329  auto tokens = treelite::common::Split(str, ':');
330  if (tokens.size() == 2 && tokens[0] == "sigmoid"
331  && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
332  alpha = tmp;
333  break;
334  }
335  }
336  CHECK_GT(alpha, 0.0f)
337  << "Ill-formed LightGBM model file: not a valid binary objective";
338 
339  model.param.pred_transform = "sigmoid";
340  model.param.sigmoid_alpha = alpha;
341  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
342  model.param.pred_transform = "sigmoid";
343  model.param.sigmoid_alpha = 1.0f;
344  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
345  model.param.pred_transform = "logarithm_one_plus_exp";
346  } else {
347  model.param.pred_transform = "identity";
348  }
349 
350  // traverse trees
351  for (const auto& lgb_tree : lgb_trees_) {
352  model.trees.emplace_back();
353  treelite::Tree& tree = model.trees.back();
354  tree.Init();
355 
356  // assign node ID's so that a breadth-wise traversal would yield
357  // the monotonic sequence 0, 1, 2, ...
358  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
359  Q.push({0, 0});
360  while (!Q.empty()) {
361  int old_id, new_id;
362  std::tie(old_id, new_id) = Q.front(); Q.pop();
363  if (old_id < 0) { // leaf
364  const double leaf_value = lgb_tree.leaf_value[~old_id];
365  const int data_count = lgb_tree.leaf_count[~old_id];
366  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
367  CHECK_GE(data_count, 0);
368  tree[new_id].set_data_count(static_cast<size_t>(data_count));
369  } else { // non-leaf
370  const int data_count = lgb_tree.internal_count[old_id];
371  const unsigned split_index =
372  static_cast<unsigned>(lgb_tree.split_feature[old_id]);
373  const bool default_left
374  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
375  tree.AddChilds(new_id);
376  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
377  // categorical
378  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
379  const std::vector<uint32_t> left_categories
380  = BitsetToList(lgb_tree.cat_threshold.data()
381  + lgb_tree.cat_boundaries[cat_idx],
382  lgb_tree.cat_boundaries[cat_idx + 1]
383  - lgb_tree.cat_boundaries[cat_idx]);
384  tree[new_id].set_categorical_split(split_index, default_left,
385  left_categories);
386  } else {
387  // numerical
388  const treelite::tl_float threshold =
389  static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
390  const treelite::Operator cmp_op = treelite::Operator::kLE;
391  tree[new_id].set_numerical_split(split_index, threshold,
392  default_left, cmp_op);
393  }
394  CHECK_GE(data_count, 0);
395  tree[new_id].set_data_count(static_cast<size_t>(data_count));
396  tree[new_id].set_gain(static_cast<double>(lgb_tree.split_gain[old_id]));
397  Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
398  Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
399  }
400  }
401  }
402  LOG(INFO) << "model.num_tree = " << model.trees.size();
403  return model;
404 }
405 
406 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree.h:299
thin wrapper for tree ensemble model
Definition: tree.h:415
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:22
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:25
double tl_float
float type to be used internally
Definition: base.h:17
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:309
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:422
Operator
comparison operators
Definition: base.h:23