treelite
lightgbm.cc
Go to the documentation of this file.
1 
8 #include <dmlc/data.h>
9 #include <treelite/tree.h>
10 #include <unordered_map>
11 #include <queue>
12 
13 namespace {
14 
// Forward declaration: builds a treelite::Model from the contents of a stream.
// The definition appears further down in this file.
15 treelite::Model ParseStream(dmlc::Stream* fi);
16 
17 } // namespace anonymous
18 
19 namespace treelite {
20 namespace frontend {
21 
// NOTE(review): dmlc registry macro tagging this file as "lightgbm"; presumably
// used so the linker keeps this translation unit -- confirm against dmlc/registry.h.
22 DMLC_REGISTRY_FILE_TAG(lightgbm);
23 
24 Model LoadLightGBMModel(const char* filename) {
25  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
26  return ParseStream(fi.get());
27 }
28 
29 } // namespace frontend
30 } // namespace treelite
31 
32 /* auxiliary data structures to interpret lightgbm model file */
33 namespace {
34 
/*! \brief bit masks tested against an entry of LGBTree::decision_type */
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // bit 0: the split is categorical
  kDefaultLeftMask = 1 << 1   // bit 1: the default direction is left
};
39 
/*!
 * \brief raw contents of one "Tree" section of a LightGBM model file
 *
 * The per-node arrays (decision_type, split_feature, threshold, left_child,
 * right_child) are parsed with num_leaves - 1 entries; leaf_value is parsed
 * with num_leaves entries.
 */
struct LGBTree {
  int num_leaves = 0;  // initialized so a default-constructed tree is well defined
  int num_cat = 0;     // number of categorical splits
  std::vector<double> leaf_value;
  std::vector<int8_t> decision_type;    // bit flags, tested with Masks
  std::vector<int> cat_boundaries;      // num_cat + 1 offsets into cat_threshold
  std::vector<uint32_t> cat_threshold;  // packed category bitsets, 32 bits per word
  std::vector<int> split_feature;
  std::vector<double> threshold;        // for categorical splits: index into cat_boundaries
  std::vector<int> left_child;          // negative entries denote leaves (leaf index = ~id)
  std::vector<int> right_child;         // negative entries denote leaves (leaf index = ~id)
};
52 
/*!
 * \brief test whether any bit selected by mask is set in decision_type
 * \param decision_type bit-flag byte of a split node
 * \param mask bit(s) to test
 * \return true if (decision_type & mask) has at least one bit set
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  // Compare against zero rather than "> 0": both operands are signed, so a
  // mask that includes the sign bit would yield a negative intersection.
  return (decision_type & mask) != 0;
}
56 
/*!
 * \brief list the positions of all set bits in a packed bitset
 * \param bits pointer to nslots 32-bit words; bit k of word w is position 32*w + k
 * \param nslots number of 32-bit words to scan
 * \return set bit positions in ascending order
 *
 * nslots is a size_t so that bitsets longer than 255 words are scanned in
 * full instead of having their length truncated to 8 bits.
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> result;
  for (size_t slot = 0; slot < nslots; ++slot) {
    uint32_t word = bits[slot];
    if (word == 0) {
      continue;  // nothing set in this word
    }
    const uint32_t base = static_cast<uint32_t>(slot) * 32;
    // shift the word right one bit at a time; stop once no set bits remain
    for (uint32_t bit = 0; word != 0; ++bit, word >>= 1) {
      if (word & 1u) {
        result.push_back(base + bit);
      }
    }
  }
  return result;
}
70 
71 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
72  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
73  std::vector<char> buf(bufsize);
74 
75  std::vector<std::string> lines;
76 
77  size_t byte_read;
78 
79  std::string leftover = ""; // carry over between buffers
80  while ( (byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
81  size_t i = 0;
82  size_t tok_begin = 0;
83  while (i < byte_read) {
84  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
85  if (tok_begin == 0 && leftover.length() + i > 0) {
86  // first line in buffer
87  lines.push_back(leftover + std::string(&buf[0], i));
88  leftover = "";
89  } else {
90  lines.emplace_back(&buf[tok_begin], i - tok_begin);
91  }
92  // skip all delimiters afterwards
93  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i);
94  tok_begin = i;
95  } else {
96  ++i;
97  }
98  }
99  // left-over string
100  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
101  }
102 
103  if (!leftover.empty()) {
104  LOG(INFO)
105  << "Warning: input file was not terminated with end-of-line character.";
106  lines.push_back(leftover);
107  }
108 
109  return lines;
110 }
111 
112 inline treelite::Model ParseStream(dmlc::Stream* fi) {
113  std::vector<LGBTree> lgb_trees_;
114  int max_feature_idx_;
115  int num_tree_per_iteration_;
116  bool average_output_;
117  std::string obj_name_;
118  std::vector<std::string> obj_param_;
119 
120  /* 1. Parse input stream */
121  std::vector<std::string> lines = LoadText(fi);
122  std::unordered_map<std::string, std::string> global_dict;
123  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
124 
125  bool in_tree = false; // is current entry part of a tree?
126  for (const auto& line : lines) {
127  std::istringstream ss(line);
128  std::string key, value, rest;
129  std::getline(ss, key, '=');
130  std::getline(ss, value, '=');
131  std::getline(ss, rest);
132  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
133  if (key == "Tree") {
134  in_tree = true;
135  tree_dict.emplace_back();
136  } else {
137  if (in_tree) {
138  tree_dict.back()[key] = value;
139  } else {
140  global_dict[key] = value;
141  }
142  }
143  }
144 
145  {
146  auto it = global_dict.find("objective");
147  CHECK(it != global_dict.end())
148  << "Ill-formed LightGBM model file: need objective";
149  auto obj_strs = treelite::common::Split(it->second, ' ');
150  obj_name_ = obj_strs[0];
151  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
152 
153  it = global_dict.find("max_feature_idx");
154  CHECK(it != global_dict.end())
155  << "Ill-formed LightGBM model file: need max_feature_idx";
156  max_feature_idx_ = treelite::common::TextToNumber<int>(it->second);
157  it = global_dict.find("num_tree_per_iteration");
158  CHECK(it != global_dict.end())
159  << "Ill-formed LightGBM model file: need num_tree_per_iteration";
160  num_tree_per_iteration_ = treelite::common::TextToNumber<int>(it->second);
161 
162  it = global_dict.find("average_output");
163  average_output_ = (it != global_dict.end());
164  }
165 
166  for (const auto& dict : tree_dict) {
167  lgb_trees_.emplace_back();
168  LGBTree& tree = lgb_trees_.back();
169 
170  auto it = dict.find("num_leaves");
171  CHECK(it != dict.end())
172  << "Ill-formed LightGBM model file: need num_leaves";
173  tree.num_leaves = treelite::common::TextToNumber<int>(it->second);
174 
175  it = dict.find("num_cat");
176  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
177  tree.num_cat = treelite::common::TextToNumber<int>(it->second);
178 
179  it = dict.find("leaf_value");
180  CHECK(it != dict.end())
181  << "Ill-formed LightGBM model file: need leaf_value";
182  tree.leaf_value
183  = treelite::common::TextToArray<double>(it->second, tree.num_leaves);
184 
185  it = dict.find("decision_type");
186  CHECK(it != dict.end())
187  << "Ill-formed LightGBM model file: need decision_type";
188  if (it == dict.end()) {
189  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
190  } else {
191  tree.decision_type
192  = treelite::common::TextToArray<int8_t>(it->second,
193  tree.num_leaves - 1);
194  }
195 
196  if (tree.num_cat > 0) {
197  it = dict.find("cat_boundaries");
198  CHECK(it != dict.end())
199  << "Ill-formed LightGBM model file: need cat_boundaries";
200  tree.cat_boundaries
201  = treelite::common::TextToArray<int>(it->second, tree.num_cat + 1);
202  it = dict.find("cat_threshold");
203  CHECK(it != dict.end())
204  << "Ill-formed LightGBM model file: need cat_threshold";
205  tree.cat_threshold
206  = treelite::common::TextToArray<uint32_t>(it->second,
207  tree.cat_boundaries.back());
208  }
209 
210  it = dict.find("split_feature");
211  CHECK(it != dict.end())
212  << "Ill-formed LightGBM model file: need split_feature";
213  tree.split_feature
214  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
215 
216  it = dict.find("threshold");
217  CHECK(it != dict.end())
218  << "Ill-formed LightGBM model file: need threshold";
219  tree.threshold
220  = treelite::common::TextToArray<double>(it->second, tree.num_leaves - 1);
221 
222  it = dict.find("left_child");
223  CHECK(it != dict.end())
224  << "Ill-formed LightGBM model file: need left_child";
225  tree.left_child
226  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
227 
228  it = dict.find("right_child");
229  CHECK(it != dict.end())
230  << "Ill-formed LightGBM model file: need right_child";
231  tree.right_child
232  = treelite::common::TextToArray<int>(it->second, tree.num_leaves - 1);
233  }
234 
235  /* 2. Export model */
236  treelite::Model model;
237  model.num_feature = max_feature_idx_ + 1;
238  model.num_output_group = num_tree_per_iteration_;
239  if (model.num_output_group > 1) {
240  // multiclass classification with gradient boosted trees
241  CHECK(!average_output_)
242  << "Ill-formed LightGBM model file: cannot use random forest mode "
243  << "for multi-class classification";
244  model.random_forest_flag = false;
245  } else {
246  model.random_forest_flag = average_output_;
247  }
248 
249  // set correct prediction transform function, depending on objective function
250  if (obj_name_ == "multiclass") {
251  // validate num_class parameter
252  int num_class = -1;
253  int tmp;
254  for (const auto& str : obj_param_) {
255  auto tokens = treelite::common::Split(str, ':');
256  if (tokens.size() == 2 && tokens[0] == "num_class"
257  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
258  num_class = tmp;
259  break;
260  }
261  }
262  CHECK(num_class >= 0 && num_class == model.num_output_group)
263  << "Ill-formed LightGBM model file: not a valid multiclass objective";
264 
265  model.param.pred_transform = "softmax";
266  } else if (obj_name_ == "multiclassova") {
267  // validate num_class and alpha parameters
268  int num_class = -1;
269  float alpha = -1.0f;
270  int tmp;
271  float tmp2;
272  for (const auto& str : obj_param_) {
273  auto tokens = treelite::common::Split(str, ':');
274  if (tokens.size() == 2) {
275  if (tokens[0] == "num_class"
276  && (tmp = treelite::common::TextToNumber<int>(tokens[1])) >= 0) {
277  num_class = tmp;
278  } else if (tokens[0] == "sigmoid"
279  && (tmp2 = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
280  alpha = tmp2;
281  }
282  }
283  }
284  CHECK(num_class >= 0 && num_class == model.num_output_group
285  && alpha > 0.0f)
286  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
287 
288  model.param.pred_transform = "multiclass_ova";
289  model.param.sigmoid_alpha = alpha;
290  } else if (obj_name_ == "binary") {
291  // validate alpha parameter
292  float alpha = -1.0f;
293  float tmp;
294  for (const auto& str : obj_param_) {
295  auto tokens = treelite::common::Split(str, ':');
296  if (tokens.size() == 2 && tokens[0] == "sigmoid"
297  && (tmp = treelite::common::TextToNumber<float>(tokens[1])) > 0.0f) {
298  alpha = tmp;
299  break;
300  }
301  }
302  CHECK_GT(alpha, 0.0f)
303  << "Ill-formed LightGBM model file: not a valid binary objective";
304 
305  model.param.pred_transform = "sigmoid";
306  model.param.sigmoid_alpha = alpha;
307  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
308  model.param.pred_transform = "sigmoid";
309  model.param.sigmoid_alpha = 1.0f;
310  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
311  model.param.pred_transform = "logarithm_one_plus_exp";
312  } else {
313  model.param.pred_transform = "identity";
314  }
315 
316  // traverse trees
317  for (const auto& lgb_tree : lgb_trees_) {
318  model.trees.emplace_back();
319  treelite::Tree& tree = model.trees.back();
320  tree.Init();
321 
322  // assign node ID's so that a breadth-wise traversal would yield
323  // the monotonic sequence 0, 1, 2, ...
324  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
325  Q.push({0, 0});
326  while (!Q.empty()) {
327  int old_id, new_id;
328  std::tie(old_id, new_id) = Q.front(); Q.pop();
329  if (old_id < 0) { // leaf
330  const double leaf_value = lgb_tree.leaf_value[~old_id];
331  tree[new_id].set_leaf(static_cast<treelite::tl_float>(leaf_value));
332  } else { // non-leaf
333  const unsigned split_index =
334  static_cast<unsigned>(lgb_tree.split_feature[old_id]);
335  const bool default_left
336  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
337  tree.AddChilds(new_id);
338  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
339  // categorical
340  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
341  const std::vector<uint32_t> left_categories
342  = BitsetToList(lgb_tree.cat_threshold.data()
343  + lgb_tree.cat_boundaries[cat_idx],
344  lgb_tree.cat_boundaries[cat_idx + 1]
345  - lgb_tree.cat_boundaries[cat_idx]);
346  tree[new_id].set_categorical_split(split_index, default_left,
347  left_categories);
348  } else {
349  // numerical
350  const treelite::tl_float threshold =
351  static_cast<treelite::tl_float>(lgb_tree.threshold[old_id]);
352  const treelite::Operator cmp_op = treelite::Operator::kLE;
353  tree[new_id].set_numerical_split(split_index, threshold,
354  default_left, cmp_op);
355  }
356  Q.push({lgb_tree.left_child[old_id], tree[new_id].cleft()});
357  Q.push({lgb_tree.right_child[old_id], tree[new_id].cright()});
358  }
359  }
360  }
361  return model;
362 }
363 
364 } // namespace anonymous
void Init()
initialize the model with a single root node
Definition: tree.h:235
thin wrapper for tree ensemble model
Definition: tree.h:351
float tl_float
float type to be used internally
Definition: base.h:17
model structure for tree
in-memory representation of a decision tree
Definition: tree.h:19
Model LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:24
void AddChilds(int nid)
add child nodes to node
Definition: tree.h:245
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:358
Operator
comparison operators
Definition: base.h:23