Treelite
lightgbm.cc
Go to the documentation of this file.
1 
#include <treelite/logging.h>
#include <treelite/frontend.h>
#include <treelite/tree.h>

#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <limits>
#include <memory>
#include <queue>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
16 
17 namespace {
18 
19 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
20 
21 } // anonymous namespace
22 
23 namespace treelite {
24 namespace frontend {
25 
26 
27 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char* filename) {
28  std::ifstream fi(filename, std::ios::in);
29  return ParseStream(fi);
30 }
31 
32 } // namespace frontend
33 } // namespace treelite
34 
35 /* auxiliary data structures to interpret lightgbm model file */
36 namespace {
37 
38 template <typename T>
39 inline T TextToNumber(const std::string& str) {
40  static_assert(std::is_same<T, float>::value
41  || std::is_same<T, double>::value
42  || std::is_same<T, int>::value
43  || std::is_same<T, int8_t>::value
44  || std::is_same<T, uint32_t>::value
45  || std::is_same<T, uint64_t>::value,
46  "unsupported data type for TextToNumber; use float, double, "
47  "int, int8_t, uint32_t, or uint64_t");
48 }
49 
50 template <>
51 inline float TextToNumber(const std::string& str) {
52  errno = 0;
53  char *endptr;
54  float val = std::strtof(str.c_str(), &endptr);
55  if (errno == ERANGE) {
56  TREELITE_LOG(FATAL) << "Range error while converting string to double";
57  } else if (errno != 0) {
58  TREELITE_LOG(FATAL) << "Unknown error";
59  } else if (*endptr != '\0') {
60  TREELITE_LOG(FATAL) << "String does not represent a valid floating-point number";
61  }
62  return val;
63 }
64 
65 template <>
66 inline double TextToNumber(const std::string& str) {
67  errno = 0;
68  char *endptr;
69  double val = std::strtod(str.c_str(), &endptr);
70  if (errno == ERANGE) {
71  TREELITE_LOG(FATAL) << "Range error while converting string to double";
72  } else if (errno != 0) {
73  TREELITE_LOG(FATAL) << "Unknown error";
74  } else if (*endptr != '\0') {
75  TREELITE_LOG(FATAL) << "String does not represent a valid floating-point number";
76  }
77  return val;
78 }
79 
80 template <>
81 inline int TextToNumber(const std::string& str) {
82  errno = 0;
83  char *endptr;
84  auto val = std::strtol(str.c_str(), &endptr, 10);
85  if (errno == ERANGE || val < std::numeric_limits<int>::min()
86  || val > std::numeric_limits<int>::max()) {
87  TREELITE_LOG(FATAL) << "Range error while converting string to int";
88  } else if (errno != 0) {
89  TREELITE_LOG(FATAL) << "Unknown error";
90  } else if (*endptr != '\0') {
91  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
92  }
93  return static_cast<int>(val);
94 }
95 
96 template <>
97 inline int8_t TextToNumber(const std::string& str) {
98  errno = 0;
99  char *endptr;
100  auto val = std::strtol(str.c_str(), &endptr, 10);
101  if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
102  || val > std::numeric_limits<int8_t>::max()) {
103  TREELITE_LOG(FATAL) << "Range error while converting string to int8_t";
104  } else if (errno != 0) {
105  TREELITE_LOG(FATAL) << "Unknown error";
106  } else if (*endptr != '\0') {
107  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
108  }
109  return static_cast<int8_t>(val);
110 }
111 
112 template <>
113 inline uint32_t TextToNumber(const std::string& str) {
114  errno = 0;
115  char *endptr;
116  auto val = std::strtoul(str.c_str(), &endptr, 10);
117  if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
118  TREELITE_LOG(FATAL) << "Range error while converting string to uint32_t";
119  } else if (errno != 0) {
120  TREELITE_LOG(FATAL) << "Unknown error";
121  } else if (*endptr != '\0') {
122  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
123  }
124  return static_cast<uint32_t>(val);
125 }
126 
127 template <>
128 inline uint64_t TextToNumber(const std::string& str) {
129  errno = 0;
130  char *endptr;
131  auto val = std::strtoull(str.c_str(), &endptr, 10);
132  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
133  TREELITE_LOG(FATAL) << "Range error while converting string to uint64_t";
134  } else if (errno != 0) {
135  TREELITE_LOG(FATAL) << "Unknown error";
136  } else if (*endptr != '\0') {
137  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
138  }
139  return static_cast<uint64_t>(val);
140 }
141 
/*!
 * \brief Split text into fields separated by delim.
 *
 * Empty fields between consecutive delimiters (or before a leading delimiter) are kept, but a
 * trailing delimiter does not produce a trailing empty field, and empty text yields no fields.
 * \param text text to split
 * \param delim field separator
 * \return the fields, in order of appearance
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> fields;
  std::size_t begin = 0;
  while (begin < text.size()) {
    const std::size_t end = text.find(delim, begin);
    if (end == std::string::npos) {
      fields.emplace_back(text, begin);  // last field runs to the end of the text
      break;
    }
    fields.emplace_back(text, begin, end - begin);
    begin = end + 1;
  }
  return fields;
}
151 
152 template <typename T>
153 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
154  if (text.empty() && num_entry > 0) {
155  TREELITE_LOG(FATAL) << "Cannot convert empty text into array";
156  }
157  std::vector<T> array;
158  std::istringstream ss(text);
159  std::string token;
160  for (int i = 0; i < num_entry; ++i) {
161  std::getline(ss, token, ' ');
162  array.push_back(TextToNumber<T>(token));
163  }
164  return array;
165 }
166 
/*! \brief Bit flags tested against the decision_type field of a split node */
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // set: categorical split; clear: numerical split
  kDefaultLeftMask = 1 << 1   // set: default (missing-value) direction is the left child
};
174  kZero,
175  kNaN
176 };
177 
/*!
 * \brief Raw fields of one tree section of a LightGBM model file.
 *
 * Per-split arrays hold num_leaves - 1 entries (one per internal node) and per-leaf arrays
 * hold num_leaves entries; optional arrays may be left empty.
 */
struct LGBTree {
  int num_leaves = 0;
  int num_cat = 0;                      // number of categorical splits
  std::vector<double> leaf_value;       // per leaf
  std::vector<int8_t> decision_type;    // per split: flag bits (see Masks, GetMissingType)
  std::vector<uint64_t> cat_boundaries; // num_cat + 1 offsets into cat_threshold
  std::vector<uint32_t> cat_threshold;  // bitsets of categories, 32 categories per word
  std::vector<int> split_feature;       // per split
  std::vector<double> threshold;        // per split
  std::vector<int> left_child;          // per split
  std::vector<int> right_child;         // per split
  std::vector<float> split_gain;        // per split (optional)
  std::vector<int> internal_count;      // per split (optional)
  std::vector<int> leaf_count;          // per leaf (optional)
};
193 
/*!
 * \brief Test a flag of a split's decision_type.
 * \return true if decision_type & mask is positive (for the single-bit masks in Masks, this
 *         means the flag is set)
 */
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int selected_bits = decision_type & mask;
  return selected_bits > 0;
}
197 
198 inline MissingType GetMissingType(int8_t decision_type) {
199  return static_cast<MissingType>((decision_type >> 2) & 3);
200 }
201 
/*!
 * \brief List the positions of all set bits in a bitset made of 32-bit words.
 * \param bits pointer to nslots words; bit j of word k stands for position 32 * k + j
 * \param nslots number of 32-bit words
 * \return positions of the set bits, in increasing order
 */
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> positions;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t offset = 0; offset < 32; ++offset) {
      if ((word >> offset) & 1U) {
        positions.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return positions;
}
215 
/*!
 * \brief Read all lines of a stream.
 *
 * A trailing carriage return is stripped from each line. Without this, a model file with
 * Windows (CRLF) line endings leaves '\r' at the end of every value, and every numeric field
 * then fails to parse.
 * \param fi input stream
 * \return the lines of the stream, without line terminators
 */
inline std::vector<std::string> LoadText(std::istream& fi) {
  std::vector<std::string> lines;
  std::string line;
  while (std::getline(fi, line)) {
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    lines.push_back(line);
  }
  return lines;
}
225 
226 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
227  std::vector<LGBTree> lgb_trees_;
228  int max_feature_idx_;
229  int num_class_;
230  bool average_output_;
231  std::string obj_name_;
232  std::vector<std::string> obj_param_;
233 
234  /* 1. Parse input stream */
235  std::vector<std::string> lines = LoadText(fi);
236  std::unordered_map<std::string, std::string> global_dict;
237  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
238 
239  bool in_tree = false; // is current entry part of a tree?
240  for (const auto& line : lines) {
241  std::istringstream ss(line);
242  std::string key, value, rest;
243  std::getline(ss, key, '=');
244  std::getline(ss, value, '=');
245  std::getline(ss, rest);
246  if (!rest.empty()) {
247  value += "=";
248  value += rest;
249  }
250  if (key == "Tree") {
251  in_tree = true;
252  tree_dict.emplace_back();
253  } else {
254  if (in_tree) {
255  tree_dict.back()[key] = value;
256  } else {
257  global_dict[key] = value;
258  }
259  }
260  }
261 
262  {
263  auto it = global_dict.find("objective");
264  if (it == global_dict.end()) { // custom objective (fobj)
265  obj_name_ = "custom";
266  } else {
267  auto obj_strs = Split(it->second, ' ');
268  obj_name_ = obj_strs[0];
269  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
270  }
271 
272  it = global_dict.find("max_feature_idx");
273  TREELITE_CHECK(it != global_dict.end())
274  << "Ill-formed LightGBM model file: need max_feature_idx";
275  max_feature_idx_ = TextToNumber<int>(it->second);
276  it = global_dict.find("num_class");
277  TREELITE_CHECK(it != global_dict.end())
278  << "Ill-formed LightGBM model file: need num_class";
279  num_class_ = TextToNumber<int>(it->second);
280 
281  it = global_dict.find("average_output");
282  average_output_ = (it != global_dict.end());
283  }
284 
285  for (const auto& dict : tree_dict) {
286  lgb_trees_.emplace_back();
287  LGBTree& tree = lgb_trees_.back();
288 
289  auto it = dict.find("num_leaves");
290  TREELITE_CHECK(it != dict.end())
291  << "Ill-formed LightGBM model file: need num_leaves";
292  tree.num_leaves = TextToNumber<int>(it->second);
293 
294  it = dict.find("num_cat");
295  TREELITE_CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
296  tree.num_cat = TextToNumber<int>(it->second);
297 
298  it = dict.find("leaf_value");
299  TREELITE_CHECK(it != dict.end() && !it->second.empty())
300  << "Ill-formed LightGBM model file: need leaf_value";
301  tree.leaf_value
302  = TextToArray<double>(it->second, tree.num_leaves);
303 
304  it = dict.find("decision_type");
305  if (tree.num_leaves <= 1) {
306  tree.decision_type = std::vector<int8_t>();
307  } else {
308  TREELITE_CHECK_GT(tree.num_leaves, 1);
309  if (it == dict.end()) {
310  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
311  } else {
312  TREELITE_CHECK(!it->second.empty())
313  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
314  tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
315  }
316  }
317 
318  if (tree.num_cat > 0) {
319  it = dict.find("cat_boundaries");
320  TREELITE_CHECK(it != dict.end() && !it->second.empty())
321  << "Ill-formed LightGBM model file: need cat_boundaries";
322  tree.cat_boundaries
323  = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
324  it = dict.find("cat_threshold");
325  TREELITE_CHECK(it != dict.end() && !it->second.empty())
326  << "Ill-formed LightGBM model file: need cat_threshold";
327  tree.cat_threshold
328  = TextToArray<uint32_t>(it->second, static_cast<uint32_t>(tree.cat_boundaries.back()));
329  }
330 
331  it = dict.find("split_feature");
332  if (tree.num_leaves <= 1) {
333  tree.split_feature = std::vector<int>();
334  } else {
335  TREELITE_CHECK_GT(tree.num_leaves, 1);
336  TREELITE_CHECK(it != dict.end() && !it->second.empty())
337  << "Ill-formed LightGBM model file: need split_feature";
338  tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
339  }
340 
341  it = dict.find("threshold");
342  if (tree.num_leaves <= 1) {
343  tree.threshold = std::vector<double>();
344  } else {
345  TREELITE_CHECK_GT(tree.num_leaves, 1);
346  TREELITE_CHECK(it != dict.end() && !it->second.empty())
347  << "Ill-formed LightGBM model file: need threshold";
348  tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
349  }
350 
351  it = dict.find("split_gain");
352  if (tree.num_leaves <= 1) {
353  tree.split_gain = std::vector<float>();
354  } else {
355  TREELITE_CHECK_GT(tree.num_leaves, 1);
356  if (it != dict.end()) {
357  TREELITE_CHECK(!it->second.empty())
358  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
359  tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
360  } else {
361  tree.split_gain = std::vector<float>();
362  }
363  }
364 
365  it = dict.find("internal_count");
366  if (tree.num_leaves <= 1) {
367  tree.internal_count = std::vector<int>();
368  } else {
369  TREELITE_CHECK_GT(tree.num_leaves, 1);
370  if (it != dict.end()) {
371  TREELITE_CHECK(!it->second.empty())
372  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
373  tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
374  } else {
375  tree.internal_count = std::vector<int>();
376  }
377  }
378 
379  it = dict.find("leaf_count");
380  if (tree.num_leaves == 0) {
381  tree.leaf_count = std::vector<int>();
382  } else {
383  TREELITE_CHECK_GT(tree.num_leaves, 0);
384  if (it != dict.end() && !it->second.empty()) {
385  tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
386  } else {
387  tree.leaf_count = std::vector<int>();
388  }
389  }
390 
391  it = dict.find("left_child");
392  if (tree.num_leaves <= 1) {
393  tree.left_child = std::vector<int>();
394  } else {
395  TREELITE_CHECK_GT(tree.num_leaves, 1);
396  TREELITE_CHECK(it != dict.end() && !it->second.empty())
397  << "Ill-formed LightGBM model file: need left_child";
398  tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
399  }
400 
401  it = dict.find("right_child");
402  if (tree.num_leaves <= 1) {
403  tree.right_child = std::vector<int>();
404  } else {
405  TREELITE_CHECK_GT(tree.num_leaves, 1);
406  TREELITE_CHECK(it != dict.end() && !it->second.empty())
407  << "Ill-formed LightGBM model file: need right_child";
408  tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
409  }
410  }
411 
412  /* 2. Export model */
413  std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
414  auto* model = dynamic_cast<treelite::ModelImpl<double, double>*>(model_ptr.get());
415  model->num_feature = max_feature_idx_ + 1;
416  model->average_tree_output = average_output_;
417  if (num_class_ > 1) {
418  // multi-class classifier
419  model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
420  model->task_param.grove_per_class = true;
421  } else {
422  // binary classifier or regressor
423  model->task_type = treelite::TaskType::kBinaryClfRegr;
424  model->task_param.grove_per_class = false;
425  }
426  model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
427  model->task_param.num_class = num_class_;
428  model->task_param.leaf_vector_size = 1;
429 
430  // set correct prediction transform function, depending on objective function
431  if (obj_name_ == "multiclass") {
432  // validate num_class parameter
433  int num_class = -1;
434  int tmp;
435  for (const auto& str : obj_param_) {
436  auto tokens = Split(str, ':');
437  if (tokens.size() == 2 && tokens[0] == "num_class"
438  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
439  num_class = tmp;
440  break;
441  }
442  }
443  TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
444  << "Ill-formed LightGBM model file: not a valid multiclass objective";
445 
446  std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform));
447  } else if (obj_name_ == "multiclassova") {
448  // validate num_class and alpha parameters
449  int num_class = -1;
450  float alpha = -1.0f;
451  int tmp;
452  float tmp2;
453  for (const auto& str : obj_param_) {
454  auto tokens = Split(str, ':');
455  if (tokens.size() == 2) {
456  if (tokens[0] == "num_class"
457  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
458  num_class = tmp;
459  } else if (tokens[0] == "sigmoid"
460  && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
461  alpha = tmp2;
462  }
463  }
464  }
465  TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
466  && alpha > 0.0f)
467  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
468 
469  std::strncpy(model->param.pred_transform, "multiclass_ova",
470  sizeof(model->param.pred_transform));
471  model->param.sigmoid_alpha = alpha;
472  } else if (obj_name_ == "binary") {
473  // validate alpha parameter
474  float alpha = -1.0f;
475  float tmp;
476  for (const auto& str : obj_param_) {
477  auto tokens = Split(str, ':');
478  if (tokens.size() == 2 && tokens[0] == "sigmoid"
479  && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
480  alpha = tmp;
481  break;
482  }
483  }
484  TREELITE_CHECK_GT(alpha, 0.0f)
485  << "Ill-formed LightGBM model file: not a valid binary objective";
486 
487  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
488  model->param.sigmoid_alpha = alpha;
489  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
490  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
491  model->param.sigmoid_alpha = 1.0f;
492  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
493  std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp",
494  sizeof(model->param.pred_transform));
495  } else if (obj_name_ == "poisson" || obj_name_ == "gamma" || obj_name_ == "tweedie") {
496  std::strncpy(model->param.pred_transform, "exponential",
497  sizeof(model->param.pred_transform));
498  } else if (obj_name_ == "regression" || obj_name_ == "regression_l1" || obj_name_ == "huber"
499  || obj_name_ == "fair" || obj_name_ == "quantile" || obj_name_ == "mape") {
500  // Regression family
501  bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(), "sqrt") != obj_param_.cend());
502  if (sqrt) {
503  std::strncpy(model->param.pred_transform, "signed_square",
504  sizeof(model->param.pred_transform));
505  } else {
506  std::strncpy(model->param.pred_transform, "identity",
507  sizeof(model->param.pred_transform));
508  }
509  } else if (obj_name_ == "lambdarank" || obj_name_ == "rank_xendcg" || obj_name_ == "custom") {
510  // Ranking family, or a custom user-defined objective
511  std::strncpy(model->param.pred_transform, "identity",
512  sizeof(model->param.pred_transform));
513  } else {
514  TREELITE_LOG(FATAL) << "Unrecognized objective: " << obj_name_;
515  }
516 
517  // traverse trees
518  for (const auto& lgb_tree : lgb_trees_) {
519  model->trees.emplace_back();
520  treelite::Tree<double, double>& tree = model->trees.back();
521  tree.Init();
522 
523  // assign node ID's so that a breadth-wise traversal would yield
524  // the monotonic sequence 0, 1, 2, ...
525  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
526  if (lgb_tree.num_leaves == 0) {
527  continue;
528  } else if (lgb_tree.num_leaves == 1) {
529  // A constant-value tree with a single root node that's also a leaf
530  Q.push({-1, 0});
531  } else {
532  Q.push({0, 0});
533  }
534  while (!Q.empty()) {
535  int old_id, new_id;
536  std::tie(old_id, new_id) = Q.front(); Q.pop();
537  if (old_id < 0) { // leaf
538  const double leaf_value = lgb_tree.leaf_value[~old_id];
539  tree.SetLeaf(new_id, static_cast<double>(leaf_value));
540  if (!lgb_tree.leaf_count.empty()) {
541  const int data_count = lgb_tree.leaf_count[~old_id];
542  TREELITE_CHECK_GE(data_count, 0);
543  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
544  }
545  } else { // non-leaf
546  const auto split_index = static_cast<unsigned>(lgb_tree.split_feature[old_id]);
547  const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
548 
549  tree.AddChilds(new_id);
550  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
551  // categorical
552  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
553  const std::vector<uint32_t> left_categories
554  = BitsetToList(lgb_tree.cat_threshold.data()
555  + lgb_tree.cat_boundaries[cat_idx],
556  lgb_tree.cat_boundaries[cat_idx + 1]
557  - lgb_tree.cat_boundaries[cat_idx]);
558  const bool missing_value_to_zero = missing_type != MissingType::kNaN;
559  bool default_left = false;
560  if (missing_value_to_zero) {
561  // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
562  // we need to override the default_left flag
563  default_left
564  = (std::find(left_categories.begin(), left_categories.end(),
565  static_cast<uint32_t>(0)) != left_categories.end());
566  }
567  tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false);
568  } else {
569  // numerical
570  const auto threshold = static_cast<double>(lgb_tree.threshold[old_id]);
571  bool default_left
572  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
573  const treelite::Operator cmp_op = treelite::Operator::kLE;
574  const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
575  if (missing_value_to_zero) {
576  // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
577  // we need to override the default_left flag
578  default_left = 0.0 <= threshold;
579  }
580  tree.SetNumericalSplit(new_id, split_index, threshold, default_left, cmp_op);
581  }
582  if (!lgb_tree.internal_count.empty()) {
583  const int data_count = lgb_tree.internal_count[old_id];
584  TREELITE_CHECK_GE(data_count, 0);
585  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
586  }
587  if (!lgb_tree.split_gain.empty()) {
588  tree.SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
589  }
590  Q.push({lgb_tree.left_child[old_id], tree.LeftChild(new_id)});
591  Q.push({lgb_tree.right_child[old_id], tree.RightChild(new_id)});
592  }
593  }
594  }
595  return model_ptr;
596 }
597 
598 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the tree with a single root node
Definition: tree_impl.h:627
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:27
in-memory representation of a decision tree
Definition: tree.h:190
logging facility for Treelite
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:556
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:546
int LeftChild(int nid) const
Getters: index of the node's left child
Definition: tree.h:322
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:329
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:640
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
Definition: tree_impl.h:694
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:667
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters: create a numerical split
Definition: tree_impl.h:678
Operator
comparison operators
Definition: base.h:26
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:728