treelite
lightgbm.cc
Go to the documentation of this file.
1 
#include <algorithm>
#include <cstdint>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include <dmlc/data.h>
#include <treelite/tree.h>
12 
13 namespace {
14 
15 treelite::Model ParseStream(dmlc::Stream* fi);
16 
17 } // namespace anonymous
18 
19 namespace treelite {
20 namespace frontend {
21 
22 DMLC_REGISTRY_FILE_TAG(lightgbm);
23 
24 Model LoadLightGBMModel(const char* filename) {
25  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
26  return ParseStream(fi.get());
27 }
28 
29 } // namespace frontend
30 } // namespace treelite
31 
32 /* auxiliary data structures to interpret lightgbm model file */
33 namespace {
34 
/*! \brief bit flags packed into each entry of LGBTree::decision_type */
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // set: categorical split; clear: numerical split
  kDefaultLeftMask = 1 << 1   // set: the split's default direction is left
};
39 
/*!
 * \brief one decision tree as described by a "Tree=" section of a LightGBM
 *        model file
 *
 * A tree with num_leaves leaves has num_leaves - 1 split nodes; every
 * per-split vector below has that many entries.
 */
struct LGBTree {
  int num_leaves = 0;  // number of leaves
  int num_cat = 0;     // number of categorical splits
  std::vector<double> leaf_value;      // output value of each leaf, by leaf ID
  std::vector<int8_t> decision_type;   // per split: bit flags (see Masks)
  // For categorical split k, cat_threshold[cat_boundaries[k],
  // cat_boundaries[k + 1]) holds the 32-bit words of its category bitset.
  std::vector<int> cat_boundaries;
  std::vector<uint32_t> cat_threshold;  // bitset words of all categorical splits
  std::vector<int> split_feature;       // per split: feature index tested
  // per split: numerical threshold (compared with kLE), or, for a categorical
  // split, the index k into cat_boundaries
  std::vector<double> threshold;
  // per split: child links; a value c >= 0 is a split node ID and a value
  // c < 0 refers to leaf ~c
  std::vector<int> left_child;
  std::vector<int> right_child;
};
52 
/*!
 * \brief test a flag of a LightGBM decision_type entry
 * \param decision_type the packed flag byte of one split node
 * \param mask flag(s) to test, normally a Masks value
 * \return true iff (decision_type & mask) is strictly positive
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int selected_bits = decision_type & mask;
  return 0 < selected_bits;
}
56 
57 inline std::vector<uint8_t> BitsetToList(const uint32_t* bits,
58  uint8_t nslots) {
59  std::vector<uint8_t> result;
60  CHECK(nslots == 1 || nslots == 2);
61  const uint8_t nbits = nslots * 32;
62  for (uint8_t i = 0; i < nbits; ++i) {
63  const uint8_t i1 = i / 32;
64  const uint8_t i2 = i % 32;
65  if ((bits[i1] >> i2) & 1) {
66  result.push_back(i);
67  }
68  }
69  return result;
70 }
71 
72 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
73  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
74  std::vector<char> buf(bufsize);
75 
76  std::vector<std::string> lines;
77 
78  size_t byte_read;
79 
80  std::string leftover = ""; // carry over between buffers
81  while ( (byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
82  size_t i = 0;
83  size_t tok_begin = 0;
84  while (i < byte_read) {
85  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
86  if (tok_begin == 0 && leftover.length() + i > 0) {
87  // first line in buffer
88  lines.push_back(leftover + std::string(&buf[0], i));
89  leftover = "";
90  } else {
91  lines.emplace_back(&buf[tok_begin], i - tok_begin);
92  }
93  // skip all delimiters afterwards
94  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i);
95  tok_begin = i;
96  } else {
97  ++i;
98  }
99  }
100  // left-over string
101  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
102  }
103 
104  if (!leftover.empty()) {
105  LOG(INFO)
106  << "Warning: input file was not terminated with end-of-line character.";
107  lines.push_back(leftover);
108  }
109 
110  return lines;
111 }
112 
113 inline treelite::Model ParseStream(dmlc::Stream* fi) {
114  std::vector<LGBTree> lgb_trees_;
115  int max_feature_idx_;
116  int num_tree_per_iteration_;
117  bool average_output_;
118  std::string obj_name_;
119  std::vector<std::string> obj_param_;
120 
121  /* 1. Parse input stream */
122  std::vector<std::string> lines = LoadText(fi);
123  std::unordered_map<std::string, std::string> global_dict;
124  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
125 
126  bool in_tree = false; // is current entry part of a tree?
127  for (const auto& line : lines) {
128  std::istringstream ss(line);
129  std::string key, value, rest;
130  std::getline(ss, key, '=');
131  std::getline(ss, value, '=');
132  std::getline(ss, rest);
133  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
134  if (key == "Tree") {
135  in_tree = true;
136  tree_dict.emplace_back();
137  } else {
138  if (in_tree) {
139  tree_dict.back()[key] = value;
140  } else {
141  global_dict[key] = value;
142  }
143  }
144  }
145 
146  {
147  auto it = global_dict.find("objective");
148  CHECK(it != global_dict.end())
149  << "Ill-formed LightGBM model file: need objective";
150  auto obj_strs = treelite::common::Split(it->second, ' ');
151  obj_name_ = obj_strs[0];
152  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
153 
154  it = global_dict.find("max_feature_idx");
155  CHECK(it != global_dict.end())
156  << "Ill-formed LightGBM model file: need max_feature_idx";
157  max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
158  it = global_dict.find("num_tree_per_iteration");
159  CHECK(it != global_dict.end())
160  << "Ill-formed LightGBM model file: need num_tree_per_iteration";
161  num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
162 
163  it = global_dict.find("average_output");
164  average_output_ = (it != global_dict.end());
165  }
166 
167  for (const auto& dict : tree_dict) {
168  lgb_trees_.emplace_back();
169  LGBTree& tree = lgb_trees_.back();
170 
171  auto it = dict.find("num_leaves");
172  CHECK(it != dict.end())
173  << "Ill-formed LightGBM model file: need num_leaves";
174  tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
175 
176  it = dict.find("num_cat");
177  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
178  tree.num_cat = treelite::common::TextToNumber<int>(it->second);
179 
180  it = dict.find("leaf_value");
181  CHECK(it != dict.end())
182  << "Ill-formed LightGBM model file: need leaf_value";
183  tree.leaf_value
184  = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
185 
186  it = dict.find("decision_type");
187  CHECK(it != dict.end())
188  << "Ill-formed LightGBM model file: need decision_type";
189  if (it == dict.end()) {
190  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
191  } else {
192  tree.decision_type
193  = treelite::common::TextToArray<int8_t>(it->second,
194  tree.num_leaves - 1);
195  }
196 
197  if (tree.num_cat > 0) {
198  it = dict.find("cat_boundaries");
199  CHECK(it != dict.end())
200  << "Ill-formed LightGBM model file: need cat_boundaries";
201  tree.cat_boundaries
202  = treelite::common::TextToArray<int>(it->second, tree.num_cat + 1);
203  it = dict.find("cat_threshold");
204  CHECK(it != dict.end())
205  << "Ill-formed LightGBM model file: need cat_threshold";
206  tree.cat_threshold
207  = treelite::common::TextToArray<uint32_t>(it->second,
208  tree.cat_boundaries.back());
209  }
210 
211  it = dict.find("split_feature");
212  CHECK(it != dict.end())
213  << "Ill-formed LightGBM model file: need split_feature";
214  tree.split_feature
215  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
216 
217  it = dict.find("threshold");
218  CHECK(it != dict.end())
219  << "Ill-formed LightGBM model file: need threshold";
220  tree.threshold
221  = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
222 
223  it = dict.find("left_child");
224  CHECK(it != dict.end())
225  << "Ill-formed LightGBM model file: need left_child";
226  tree.left_child
227  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
228 
229  it = dict.find("right_child");
230  CHECK(it != dict.end())
231  << "Ill-formed LightGBM model file: need right_child";
232  tree.right_child
233  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
234  }
235 
236  /* 2. Export model */
237  treelite::Model model;
238  model.num_feature = max_feature_idx_ + 1;
239  model.num_output_group = num_tree_per_iteration_;
240  if (model.num_output_group > 1) {
241  // multiclass classification with gradient boosted trees
242  CHECK(!average_output_)
243  << "Ill-formed LightGBM model file: cannot use random forest mode "
244  << "for multi-class classification";
245  model.random_forest_flag = false;
246  } else {
247  model.random_forest_flag = average_output_;
248  }
249 
250  // set correct prediction transform function, depending on objective function
251  if (obj_name_ == "multiclass") {
252  // validate num_class parameter
253  int num_class = -1;
254  int tmp;
255  for (const auto& str : obj_param_) {
256  auto tokens = treelite::common::Split(str, ':');
257  if (tokens.size() == 2 && tokens[0] == "num_class"
258  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
259  num_class = tmp;
260  break;
261  }
262  }
263  CHECK(num_class >= 0 && num_class == model.num_output_group)
264  << "Ill-formed LightGBM model file: not a valid multiclass objective";
265 
266  model.param.pred_transform = "softmax";
267  } else if (obj_name_ == "multiclassova") {
268  // validate num_class and alpha parameters
269  int num_class = -1;
270  float alpha = -1.0f;
271  int tmp;
272  float tmp2;
273  for (const auto& str : obj_param_) {
274  auto tokens = treelite::common::Split(str, ':');
275  if (tokens.size() == 2) {
276  if (tokens[0] == "num_class"
277  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
278  num_class = tmp;
279  } else if (tokens[0] == "sigmoid"
280  && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
281  alpha = tmp2;
282  }
283  }
284  }
285  CHECK(num_class >= 0 && num_class == model.num_output_group
286  && alpha > 0.0f)
287  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
288 
289  model.param.pred_transform = "multiclass_ova";
290  model.param.sigmoid_alpha = alpha;
291  } else if (obj_name_ == "binary") {
292  // validate alpha parameter
293  float alpha = -1.0f;
294  float tmp;
295  for (const auto& str : obj_param_) {
296  auto tokens = treelite::common::Split(str, ':');
297  if (tokens.size() == 2 && tokens[0] == "sigmoid"
298  && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
299  alpha = tmp;
300  break;
301  }
302  }
303  CHECK_GT(alpha, 0.0f)
304  << "Ill-formed LightGBM model file: not a valid binary objective";
305 
306  model.param.pred_transform = "sigmoid";
307  model.param.sigmoid_alpha = alpha;
308  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
309  model.param.pred_transform = "sigmoid";
310  model.param.sigmoid_alpha = 1.0f;
311  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
312  model.param.pred_transform = "logarithm_one_plus_exp";
313  } else {
314  model.param.pred_transform = "identity";
315  }
316 
317  // traverse trees
318  for (const auto& lgb_tree : lgb_trees_) {
319  model.trees.emplace_back();
320  treelite::Tree& tree = model.trees.back();
321  tree.Init();
322 
323  // assign node ID's so that a breadth-wise traversal would yield
324  // the monotonic sequence 0, 1, 2, ...
325  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
326  Q.push({0, 0});
327  while (!Q.empty()) {
328  int old_id, new_id;
329  std::tie(old_id, new_id) = Q.front(); Q.pop();
330  if (old_id < 0) { // leaf
331  const double leaf_value = lgb_tree.leaf_value[~old_id];
332  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
333  } else { // non-leaf
334  const unsigned split_index =
335  static_cast<unsigned>(lgb_tree.split_feature[old_id]);
336  const bool default_left
337  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
338  tree.AddChilds(new_id);
339  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
340  // categorical
341  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
342  CHECK_LE(lgb_tree.cat_boundaries[cat_idx + 1]
343  - lgb_tree.cat_boundaries[cat_idx], 2)
344  << "Categorical features must have 64 categories or fewer.";
345  const std::vector<uint8_t> left_categories
346  = BitsetToList(lgb_tree.cat_threshold.data()
347  + lgb_tree.cat_boundaries[cat_idx],
348  lgb_tree.cat_boundaries[cat_idx + 1]
349  - lgb_tree.cat_boundaries[cat_idx]);
350  tree[new_id].set_categorical_split(split_index, default_left,
351  left_categories);
352  } else {
353  // numerical
354  const treelite::tl_float threshold =
355  static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
356  const treelite::Operator cmp_op = treelite::Operator::kLE;
357  tree[new_id].set_numerical_split(split_index, threshold,
358  default_left, cmp_op);
359  }
360  Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
361  Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
362  }
363  }
364  }
365  return model;
366 }
367 
368 } // namespace anonymous
void Init()
initialize the model with a single root node
Definition: tree.h:234
thin wrapper for tree ensemble model
Definition: tree.h:350
float tl_float
float type to be used internally
Definition: base.h:17
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:19
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:24
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:244
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:357
Operator
comparison operators
Definition: base.h:23