Treelite
lightgbm.cc
Go to the documentation of this file.
1 
#include <dmlc/data.h>
#include <treelite/frontend.h>
#include <treelite/tree.h>

#include <algorithm>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
14 
15 namespace {
16 
17 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi);
18 
19 } // anonymous namespace
20 
21 namespace treelite {
22 namespace frontend {
23 
24 DMLC_REGISTRY_FILE_TAG(lightgbm);
25 
26 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char *filename) {
27  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(filename, "r"));
28  return ParseStream(fi.get());
29 }
30 
31 } // namespace frontend
32 } // namespace treelite
33 
34 /* auxiliary data structures to interpret lightgbm model file */
35 namespace {
36 
/*!
 * \brief Convert a string into a number of type T.
 *
 * The real conversions live in the explicit specializations for float, double, int, int8_t,
 * uint32_t and uint64_t. This primary template exists only so that instantiating it with any
 * other type fails loudly at compile time via the static_assert.
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  static_assert(std::is_same<T, float>::value
                || std::is_same<T, double>::value
                || std::is_same<T, int>::value
                || std::is_same<T, int8_t>::value
                || std::is_same<T, uint32_t>::value
                || std::is_same<T, uint64_t>::value,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t");
  (void)str;   // unused: supported types are handled by the specializations
  return T{};  // keeps the non-void function well-formed (the original fell off the end)
}
48 
49 template <>
50 inline float TextToNumber(const std::string& str) {
51  errno = 0;
52  char *endptr;
53  float val = std::strtof(str.c_str(), &endptr);
54  if (errno == ERANGE) {
55  LOG(FATAL) << "Range error while converting string to double";
56  } else if (errno != 0) {
57  LOG(FATAL) << "Unknown error";
58  } else if (*endptr != '\0') {
59  LOG(FATAL) << "String does not represent a valid floating-point number";
60  }
61  return val;
62 }
63 
64 template <>
65 inline double TextToNumber(const std::string& str) {
66  errno = 0;
67  char *endptr;
68  double val = std::strtod(str.c_str(), &endptr);
69  if (errno == ERANGE) {
70  LOG(FATAL) << "Range error while converting string to double";
71  } else if (errno != 0) {
72  LOG(FATAL) << "Unknown error";
73  } else if (*endptr != '\0') {
74  LOG(FATAL) << "String does not represent a valid floating-point number";
75  }
76  return val;
77 }
78 
79 template <>
80 inline int TextToNumber(const std::string& str) {
81  errno = 0;
82  char *endptr;
83  auto val = std::strtol(str.c_str(), &endptr, 10);
84  if (errno == ERANGE || val < std::numeric_limits<int>::min()
85  || val > std::numeric_limits<int>::max()) {
86  LOG(FATAL) << "Range error while converting string to int";
87  } else if (errno != 0) {
88  LOG(FATAL) << "Unknown error";
89  } else if (*endptr != '\0') {
90  LOG(FATAL) << "String does not represent a valid integer";
91  }
92  return static_cast<int>(val);
93 }
94 
95 template <>
96 inline int8_t TextToNumber(const std::string& str) {
97  errno = 0;
98  char *endptr;
99  auto val = std::strtol(str.c_str(), &endptr, 10);
100  if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
101  || val > std::numeric_limits<int8_t>::max()) {
102  LOG(FATAL) << "Range error while converting string to int8_t";
103  } else if (errno != 0) {
104  LOG(FATAL) << "Unknown error";
105  } else if (*endptr != '\0') {
106  LOG(FATAL) << "String does not represent a valid integer";
107  }
108  return static_cast<int8_t>(val);
109 }
110 
111 template <>
112 inline uint32_t TextToNumber(const std::string& str) {
113  errno = 0;
114  char *endptr;
115  auto val = std::strtoul(str.c_str(), &endptr, 10);
116  if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
117  LOG(FATAL) << "Range error while converting string to uint32_t";
118  } else if (errno != 0) {
119  LOG(FATAL) << "Unknown error";
120  } else if (*endptr != '\0') {
121  LOG(FATAL) << "String does not represent a valid integer";
122  }
123  return static_cast<uint32_t>(val);
124 }
125 
126 template <>
127 inline uint64_t TextToNumber(const std::string& str) {
128  errno = 0;
129  char *endptr;
130  auto val = std::strtoull(str.c_str(), &endptr, 10);
131  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
132  LOG(FATAL) << "Range error while converting string to uint64_t";
133  } else if (errno != 0) {
134  LOG(FATAL) << "Unknown error";
135  } else if (*endptr != '\0') {
136  LOG(FATAL) << "String does not represent a valid integer";
137  }
138  return static_cast<uint64_t>(val);
139 }
140 
/*!
 * \brief Split `text` on every occurrence of `delim`.
 *
 * Adjacent delimiters yield an empty token between them and a leading delimiter yields an
 * empty first token. A single trailing delimiter does not add an empty last token, and empty
 * text yields an empty vector.
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::size_t start = 0;
  while (start < text.size()) {
    const std::size_t stop = text.find(delim, start);
    if (stop == std::string::npos) {
      pieces.emplace_back(text, start);  // last piece runs to the end of the text
      break;
    }
    pieces.emplace_back(text, start, stop - start);
    start = stop + 1;
  }
  return pieces;
}
150 
151 template <typename T>
152 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
153  if (text.empty() && num_entry > 0) {
154  LOG(FATAL) << "Cannot convert empty text into array";
155  }
156  std::vector<T> array;
157  std::istringstream ss(text);
158  std::string token;
159  for (int i = 0; i < num_entry; ++i) {
160  std::getline(ss, token, ' ');
161  array.push_back(TextToNumber<T>(token));
162  }
163  return array;
164 }
165 
// Single-bit flags tested against a node's `decision_type` byte via GetDecisionType().
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // set: categorical split; clear: numerical split
  kDefaultLeftMask = 1 << 1   // set: default_left flag for numerical splits
};
170 
// Missing-value policy of a split, decoded from bits 2-3 of `decision_type` by GetMissingType().
enum class MissingType : uint8_t {
  kNone = 0,
  kZero = 1,
  kNaN = 2  // only this policy keeps NaN distinct from 0.0 when converting splits
};
176 
// Raw arrays of one tree as read from a LightGBM text model. Per-internal-node arrays hold
// num_leaves - 1 entries and per-leaf arrays hold num_leaves entries when present.
struct LGBTree {
  int num_leaves{0};
  int num_cat{0};                        // number of categorical splits
  std::vector<double> leaf_value;        // one value per leaf
  std::vector<int8_t> decision_type;     // per internal node: Masks bits + MissingType bits
  std::vector<uint64_t> cat_boundaries;  // num_cat + 1 offsets into cat_threshold
  std::vector<uint32_t> cat_threshold;   // concatenated category bitsets (32 bits per word)
  std::vector<int> split_feature;
  std::vector<double> threshold;         // numerical threshold, or index into cat_boundaries
  std::vector<int> left_child;           // >= 0: internal node id; < 0: ~leaf id
  std::vector<int> right_child;          // same encoding as left_child
  std::vector<float> split_gain;         // optional; empty when absent from the file
  std::vector<int> internal_count;       // optional; empty when absent from the file
  std::vector<int> leaf_count;           // optional; empty when absent from the file
};
192 
// True when `decision_type & mask` is positive, i.e. when the (non-negative) flag bit from
// Masks is set in the node's decision_type byte.
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int masked_bits = decision_type & mask;
  return masked_bits > 0;
}
196 
197 inline MissingType GetMissingType(int8_t decision_type) {
198  return static_cast<MissingType>((decision_type >> 2) & 3);
199 }
200 
/*!
 * \brief Expand a bitset into the ascending list of its set bit positions.
 * \param bits pointer to `nslots` 32-bit words; bit j of word i is position 32 * i + j
 * \param nslots number of words to scan
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> members;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t bit = 0; bit < 32; ++bit) {
      if ((word >> bit) & 1U) {
        members.push_back(static_cast<uint32_t>(slot * 32 + bit));
      }
    }
  }
  return members;
}
214 
215 inline std::vector<std::string> LoadText(dmlc::Stream* fi) {
216  const size_t bufsize = 16 * 1024 * 1024; // 16 MB
217  std::vector<char> buf(bufsize);
218 
219  std::vector<std::string> lines;
220 
221  size_t byte_read;
222 
223  std::string leftover = ""; // carry over between buffers
224  while ((byte_read = fi->Read(&buf[0], sizeof(char) * bufsize)) > 0) {
225  size_t i = 0;
226  size_t tok_begin = 0;
227  while (i < byte_read) {
228  if (buf[i] == '\n' || buf[i] == '\r') { // delimiter for lines
229  if (tok_begin == 0 && leftover.length() + i > 0) {
230  // first line in buffer
231  lines.push_back(leftover + std::string(&buf[0], i));
232  leftover = "";
233  } else {
234  lines.emplace_back(&buf[tok_begin], i - tok_begin);
235  }
236  // skip all delimiters afterwards
237  for (; (buf[i] == '\n' || buf[i] == '\r') && i < byte_read; ++i) {}
238  tok_begin = i;
239  } else {
240  ++i;
241  }
242  }
243  // left-over string
244  leftover += std::string(&buf[tok_begin], byte_read - tok_begin);
245  }
246 
247  if (!leftover.empty()) {
248  LOG(INFO)
249  << "Warning: input file was not terminated with end-of-line character.";
250  lines.push_back(leftover);
251  }
252 
253  return lines;
254 }
255 
256 inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
257  std::vector<LGBTree> lgb_trees_;
258  int max_feature_idx_;
259  int num_class_;
260  bool average_output_;
261  std::string obj_name_;
262  std::vector<std::string> obj_param_;
263 
264  /* 1. Parse input stream */
265  std::vector<std::string> lines = LoadText(fi);
266  std::unordered_map<std::string, std::string> global_dict;
267  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
268 
269  bool in_tree = false; // is current entry part of a tree?
270  for (const auto& line : lines) {
271  std::istringstream ss(line);
272  std::string key, value, rest;
273  std::getline(ss, key, '=');
274  std::getline(ss, value, '=');
275  std::getline(ss, rest);
276  CHECK(rest.empty()) << "Ill-formed LightGBM model file";
277  if (key == "Tree") {
278  in_tree = true;
279  tree_dict.emplace_back();
280  } else {
281  if (in_tree) {
282  tree_dict.back()[key] = value;
283  } else {
284  global_dict[key] = value;
285  }
286  }
287  }
288 
289  {
290  auto it = global_dict.find("objective");
291  if (it == global_dict.end()) { // custom objective (fobj)
292  obj_name_ = "custom";
293  } else {
294  auto obj_strs = Split(it->second, ' ');
295  obj_name_ = obj_strs[0];
296  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
297  }
298 
299  it = global_dict.find("max_feature_idx");
300  CHECK(it != global_dict.end())
301  << "Ill-formed LightGBM model file: need max_feature_idx";
302  max_feature_idx_ = TextToNumber<int>(it->second);
303  it = global_dict.find("num_class");
304  CHECK(it != global_dict.end())
305  << "Ill-formed LightGBM model file: need num_class";
306  num_class_ = TextToNumber<int>(it->second);
307 
308  it = global_dict.find("average_output");
309  average_output_ = (it != global_dict.end());
310  }
311 
312  for (const auto& dict : tree_dict) {
313  lgb_trees_.emplace_back();
314  LGBTree& tree = lgb_trees_.back();
315 
316  auto it = dict.find("num_leaves");
317  CHECK(it != dict.end())
318  << "Ill-formed LightGBM model file: need num_leaves";
319  tree.num_leaves = TextToNumber<int>(it->second);
320 
321  it = dict.find("num_cat");
322  CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
323  tree.num_cat = TextToNumber<int>(it->second);
324 
325  it = dict.find("leaf_value");
326  CHECK(it != dict.end() && !it->second.empty())
327  << "Ill-formed LightGBM model file: need leaf_value";
328  tree.leaf_value
329  = TextToArray<double>(it->second, tree.num_leaves);
330 
331  it = dict.find("decision_type");
332  if (tree.num_leaves <= 1) {
333  tree.decision_type = std::vector<int8_t>();
334  } else {
335  CHECK_GT(tree.num_leaves, 1);
336  if (it == dict.end()) {
337  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
338  } else {
339  CHECK(!it->second.empty())
340  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
341  tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
342  }
343  }
344 
345  if (tree.num_cat > 0) {
346  it = dict.find("cat_boundaries");
347  CHECK(it != dict.end() && !it->second.empty())
348  << "Ill-formed LightGBM model file: need cat_boundaries";
349  tree.cat_boundaries
350  = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
351  it = dict.find("cat_threshold");
352  CHECK(it != dict.end() && !it->second.empty())
353  << "Ill-formed LightGBM model file: need cat_threshold";
354  tree.cat_threshold
355  = TextToArray<uint32_t>(it->second, static_cast<uint32_t>(tree.cat_boundaries.back()));
356  }
357 
358  it = dict.find("split_feature");
359  if (tree.num_leaves <= 1) {
360  tree.split_feature = std::vector<int>();
361  } else {
362  CHECK_GT(tree.num_leaves, 1);
363  CHECK(it != dict.end() && !it->second.empty())
364  << "Ill-formed LightGBM model file: need split_feature";
365  tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
366  }
367 
368  it = dict.find("threshold");
369  if (tree.num_leaves <= 1) {
370  tree.threshold = std::vector<double>();
371  } else {
372  CHECK_GT(tree.num_leaves, 1);
373  CHECK(it != dict.end() && !it->second.empty())
374  << "Ill-formed LightGBM model file: need threshold";
375  tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
376  }
377 
378  it = dict.find("split_gain");
379  if (tree.num_leaves <= 1) {
380  tree.split_gain = std::vector<float>();
381  } else {
382  CHECK_GT(tree.num_leaves, 1);
383  if (it != dict.end()) {
384  CHECK(!it->second.empty())
385  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
386  tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
387  } else {
388  tree.split_gain = std::vector<float>();
389  }
390  }
391 
392  it = dict.find("internal_count");
393  if (tree.num_leaves <= 1) {
394  tree.internal_count = std::vector<int>();
395  } else {
396  CHECK_GT(tree.num_leaves, 1);
397  if (it != dict.end()) {
398  CHECK(!it->second.empty())
399  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
400  tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
401  } else {
402  tree.internal_count = std::vector<int>();
403  }
404  }
405 
406  it = dict.find("leaf_count");
407  if (tree.num_leaves == 0) {
408  tree.leaf_count = std::vector<int>();
409  } else {
410  CHECK_GT(tree.num_leaves, 0);
411  if (it != dict.end() && !it->second.empty()) {
412  tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
413  } else {
414  tree.leaf_count = std::vector<int>();
415  }
416  }
417 
418  it = dict.find("left_child");
419  if (tree.num_leaves <= 1) {
420  tree.left_child = std::vector<int>();
421  } else {
422  CHECK_GT(tree.num_leaves, 1);
423  CHECK(it != dict.end() && !it->second.empty())
424  << "Ill-formed LightGBM model file: need left_child";
425  tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
426  }
427 
428  it = dict.find("right_child");
429  if (tree.num_leaves <= 1) {
430  tree.right_child = std::vector<int>();
431  } else {
432  CHECK_GT(tree.num_leaves, 1);
433  CHECK(it != dict.end() && !it->second.empty())
434  << "Ill-formed LightGBM model file: need right_child";
435  tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
436  }
437  }
438 
439  /* 2. Export model */
440  std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
441  auto* model = dynamic_cast<treelite::ModelImpl<double, double>*>(model_ptr.get());
442  model->num_feature = max_feature_idx_ + 1;
443  model->average_tree_output = average_output_;
444  if (num_class_ > 1) {
445  // multi-class classifier
446  model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
447  model->task_param.grove_per_class = true;
448  } else {
449  // binary classifier or regressor
450  model->task_type = treelite::TaskType::kBinaryClfRegr;
451  model->task_param.grove_per_class = false;
452  }
453  model->task_param.output_type = treelite::TaskParameter::OutputType::kFloat;
454  model->task_param.num_class = num_class_;
455  model->task_param.leaf_vector_size = 1;
456 
457  // set correct prediction transform function, depending on objective function
458  if (obj_name_ == "multiclass") {
459  // validate num_class parameter
460  int num_class = -1;
461  int tmp;
462  for (const auto& str : obj_param_) {
463  auto tokens = Split(str, ':');
464  if (tokens.size() == 2 && tokens[0] == "num_class"
465  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
466  num_class = tmp;
467  break;
468  }
469  }
470  CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
471  << "Ill-formed LightGBM model file: not a valid multiclass objective";
472 
473  std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform));
474  } else if (obj_name_ == "multiclassova") {
475  // validate num_class and alpha parameters
476  int num_class = -1;
477  float alpha = -1.0f;
478  int tmp;
479  float tmp2;
480  for (const auto& str : obj_param_) {
481  auto tokens = Split(str, ':');
482  if (tokens.size() == 2) {
483  if (tokens[0] == "num_class"
484  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
485  num_class = tmp;
486  } else if (tokens[0] == "sigmoid"
487  && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
488  alpha = tmp2;
489  }
490  }
491  }
492  CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
493  && alpha > 0.0f)
494  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
495 
496  std::strncpy(model->param.pred_transform, "multiclass_ova",
497  sizeof(model->param.pred_transform));
498  model->param.sigmoid_alpha = alpha;
499  } else if (obj_name_ == "binary") {
500  // validate alpha parameter
501  float alpha = -1.0f;
502  float tmp;
503  for (const auto& str : obj_param_) {
504  auto tokens = Split(str, ':');
505  if (tokens.size() == 2 && tokens[0] == "sigmoid"
506  && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
507  alpha = tmp;
508  break;
509  }
510  }
511  CHECK_GT(alpha, 0.0f)
512  << "Ill-formed LightGBM model file: not a valid binary objective";
513 
514  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
515  model->param.sigmoid_alpha = alpha;
516  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
517  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
518  model->param.sigmoid_alpha = 1.0f;
519  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
520  std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp",
521  sizeof(model->param.pred_transform));
522  } else if (obj_name_ == "poisson" || obj_name_ == "gamma" || obj_name_ == "tweedie") {
523  std::strncpy(model->param.pred_transform, "exponential",
524  sizeof(model->param.pred_transform));
525  } else if (obj_name_ == "regression" || obj_name_ == "regression_l1" || obj_name_ == "huber"
526  || obj_name_ == "fair" || obj_name_ == "quantile" || obj_name_ == "mape") {
527  // Regression family
528  bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(), "sqrt") != obj_param_.cend());
529  if (sqrt) {
530  std::strncpy(model->param.pred_transform, "signed_square",
531  sizeof(model->param.pred_transform));
532  } else {
533  std::strncpy(model->param.pred_transform, "identity",
534  sizeof(model->param.pred_transform));
535  }
536  } else if (obj_name_ == "lambdarank" || obj_name_ == "rank_xendcg" || obj_name_ == "custom") {
537  // Ranking family, or a custom user-defined objective
538  std::strncpy(model->param.pred_transform, "identity",
539  sizeof(model->param.pred_transform));
540  } else {
541  LOG(FATAL) << "Unrecognized objective: " << obj_name_;
542  }
543 
544  // traverse trees
545  for (const auto& lgb_tree : lgb_trees_) {
546  model->trees.emplace_back();
547  treelite::Tree<double, double>& tree = model->trees.back();
548  tree.Init();
549 
550  // assign node ID's so that a breadth-wise traversal would yield
551  // the monotonic sequence 0, 1, 2, ...
552  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
553  if (lgb_tree.num_leaves == 0) {
554  continue;
555  } else if (lgb_tree.num_leaves == 1) {
556  // A constant-value tree with a single root node that's also a leaf
557  Q.push({-1, 0});
558  } else {
559  Q.push({0, 0});
560  }
561  while (!Q.empty()) {
562  int old_id, new_id;
563  std::tie(old_id, new_id) = Q.front(); Q.pop();
564  if (old_id < 0) { // leaf
565  const double leaf_value = lgb_tree.leaf_value[~old_id];
566  tree.SetLeaf(new_id, static_cast<double>(leaf_value));
567  if (!lgb_tree.leaf_count.empty()) {
568  const int data_count = lgb_tree.leaf_count[~old_id];
569  CHECK_GE(data_count, 0);
570  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
571  }
572  } else { // non-leaf
573  const auto split_index = static_cast<unsigned>(lgb_tree.split_feature[old_id]);
574  const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
575 
576  tree.AddChilds(new_id);
577  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
578  // categorical
579  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
580  const std::vector<uint32_t> left_categories
581  = BitsetToList(lgb_tree.cat_threshold.data()
582  + lgb_tree.cat_boundaries[cat_idx],
583  lgb_tree.cat_boundaries[cat_idx + 1]
584  - lgb_tree.cat_boundaries[cat_idx]);
585  const bool missing_value_to_zero = missing_type != MissingType::kNaN;
586  bool default_left = false;
587  if (missing_value_to_zero) {
588  // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
589  // we need to override the default_left flag
590  default_left
591  = (std::find(left_categories.begin(), left_categories.end(),
592  static_cast<uint32_t>(0)) != left_categories.end());
593  }
594  tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false);
595  } else {
596  // numerical
597  const auto threshold = static_cast<double>(lgb_tree.threshold[old_id]);
598  bool default_left
599  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
600  const treelite::Operator cmp_op = treelite::Operator::kLE;
601  const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
602  if (missing_value_to_zero) {
603  // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
604  // we need to override the default_left flag
605  default_left = 0.0 <= threshold;
606  }
607  tree.SetNumericalSplit(new_id, split_index, threshold, default_left, cmp_op);
608  }
609  if (!lgb_tree.internal_count.empty()) {
610  const int data_count = lgb_tree.internal_count[old_id];
611  CHECK_GE(data_count, 0);
612  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
613  }
614  if (!lgb_tree.split_gain.empty()) {
615  tree.SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
616  }
617  Q.push({lgb_tree.left_child[old_id], tree.LeftChild(new_id)});
618  Q.push({lgb_tree.right_child[old_id], tree.RightChild(new_id)});
619  }
620  }
621  }
622  return model_ptr;
623 }
624 
625 } // anonymous namespace
Collection of front-end methods to load or construct an ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:627
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:26
in-memory representation of a decision tree
Definition: tree.h:197
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:560
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:550
int LeftChild(int nid) const
Getters.
Definition: tree.h:326
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:333
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:640
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
Definition: tree_impl.h:694
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:673
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:678
Operator
comparison operators
Definition: base.h:26
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:728