Treelite
lightgbm.cc
Go to the documentation of this file.
1 
8 #include <treelite/logging.h>
9 #include <treelite/frontend.h>
10 #include <treelite/tree.h>
11 #include <unordered_map>
12 #include <limits>
13 #include <queue>
14 #include <string>
15 #include <fstream>
16 #include <sstream>
17 
18 namespace {
19 
20 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi);
21 
22 } // anonymous namespace
23 
24 namespace treelite {
25 namespace frontend {
26 
27 
28 std::unique_ptr<treelite::Model> LoadLightGBMModel(const char* filename) {
29  std::ifstream fi(filename, std::ios::in);
30  return ParseStream(fi);
31 }
32 
33 std::unique_ptr<treelite::Model> LoadLightGBMModelFromString(const char* model_str) {
34  std::istringstream is(model_str);
35  return ParseStream(is);
36 }
37 
38 } // namespace frontend
39 } // namespace treelite
40 
41 /* auxiliary data structures to interpret lightgbm model file */
42 namespace {
43 
// Primary template of the text-to-number converter.
// The conversions themselves live in the explicit specializations that follow; this generic
// version only rejects, at compile time, any type outside the supported list.
template <typename T>
inline T TextToNumber(const std::string& str) {
  constexpr bool kIsSupportedType =
      std::is_same<T, float>::value
      || std::is_same<T, double>::value
      || std::is_same<T, int>::value
      || std::is_same<T, int8_t>::value
      || std::is_same<T, uint32_t>::value
      || std::is_same<T, uint64_t>::value;
  static_assert(kIsSupportedType,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t");
}
55 
56 template <>
57 inline float TextToNumber(const std::string& str) {
58  errno = 0;
59  char *endptr;
60  float val = std::strtof(str.c_str(), &endptr);
61  if (errno == ERANGE) {
62  TREELITE_LOG(FATAL) << "Range error while converting string to double";
63  } else if (errno != 0) {
64  TREELITE_LOG(FATAL) << "Unknown error";
65  } else if (*endptr != '\0') {
66  TREELITE_LOG(FATAL) << "String does not represent a valid floating-point number";
67  }
68  return val;
69 }
70 
71 template <>
72 inline double TextToNumber(const std::string& str) {
73  errno = 0;
74  char *endptr;
75  double val = std::strtod(str.c_str(), &endptr);
76  if (errno == ERANGE) {
77  TREELITE_LOG(FATAL) << "Range error while converting string to double";
78  } else if (errno != 0) {
79  TREELITE_LOG(FATAL) << "Unknown error";
80  } else if (*endptr != '\0') {
81  TREELITE_LOG(FATAL) << "String does not represent a valid floating-point number";
82  }
83  return val;
84 }
85 
86 template <>
87 inline int TextToNumber(const std::string& str) {
88  errno = 0;
89  char *endptr;
90  auto val = std::strtol(str.c_str(), &endptr, 10);
91  if (errno == ERANGE || val < std::numeric_limits<int>::min()
92  || val > std::numeric_limits<int>::max()) {
93  TREELITE_LOG(FATAL) << "Range error while converting string to int";
94  } else if (errno != 0) {
95  TREELITE_LOG(FATAL) << "Unknown error";
96  } else if (*endptr != '\0') {
97  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
98  }
99  return static_cast<int>(val);
100 }
101 
102 template <>
103 inline int8_t TextToNumber(const std::string& str) {
104  errno = 0;
105  char *endptr;
106  auto val = std::strtol(str.c_str(), &endptr, 10);
107  if (errno == ERANGE || val < std::numeric_limits<int8_t>::min()
108  || val > std::numeric_limits<int8_t>::max()) {
109  TREELITE_LOG(FATAL) << "Range error while converting string to int8_t";
110  } else if (errno != 0) {
111  TREELITE_LOG(FATAL) << "Unknown error";
112  } else if (*endptr != '\0') {
113  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
114  }
115  return static_cast<int8_t>(val);
116 }
117 
118 template <>
119 inline uint32_t TextToNumber(const std::string& str) {
120  errno = 0;
121  char *endptr;
122  auto val = std::strtoul(str.c_str(), &endptr, 10);
123  if (errno == ERANGE || val > std::numeric_limits<uint32_t>::max()) {
124  TREELITE_LOG(FATAL) << "Range error while converting string to uint32_t";
125  } else if (errno != 0) {
126  TREELITE_LOG(FATAL) << "Unknown error";
127  } else if (*endptr != '\0') {
128  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
129  }
130  return static_cast<uint32_t>(val);
131 }
132 
133 template <>
134 inline uint64_t TextToNumber(const std::string& str) {
135  errno = 0;
136  char *endptr;
137  auto val = std::strtoull(str.c_str(), &endptr, 10);
138  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
139  TREELITE_LOG(FATAL) << "Range error while converting string to uint64_t";
140  } else if (errno != 0) {
141  TREELITE_LOG(FATAL) << "Unknown error";
142  } else if (*endptr != '\0') {
143  TREELITE_LOG(FATAL) << "String does not represent a valid integer";
144  }
145  return static_cast<uint64_t>(val);
146 }
147 
// Break `text` into the pieces separated by `delim`.
// Consecutive delimiters yield empty pieces, but a delimiter at the very end of `text` does
// not add a trailing empty piece. An empty `text` yields an empty vector.
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type begin = 0;
  while (begin < text.size()) {
    const std::string::size_type end = text.find(delim, begin);
    if (end == std::string::npos) {
      // last piece: everything up to the end of the text
      pieces.push_back(text.substr(begin));
      break;
    }
    pieces.push_back(text.substr(begin, end - begin));
    begin = end + 1;
  }
  return pieces;
}
157 
158 template <typename T>
159 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
160  if (text.empty() && num_entry > 0) {
161  TREELITE_LOG(FATAL) << "Cannot convert empty text into array";
162  }
163  std::vector<T> array;
164  std::istringstream ss(text);
165  std::string token;
166  for (int i = 0; i < num_entry; ++i) {
167  std::getline(ss, token, ' ');
168  array.push_back(TextToNumber<T>(token));
169  }
170  return array;
171 }
172 
// Bit flags that GetDecisionType() tests against a node's decision_type value
enum Masks : uint8_t {
  kCategoricalMask = 1 << 0,  // bit 0
  kDefaultLeftMask = 1 << 1   // bit 1
};
177 
// Missing-value mode, as decoded from a decision_type value by GetMissingType()
enum class MissingType : uint8_t {
  kNone = 0,
  kZero = 1,
  kNaN = 2
};
183 
// One tree as read from the LightGBM text model, before conversion to a Treelite tree.
// Per-leaf arrays hold num_leaves entries and per-split arrays hold num_leaves - 1 entries;
// the optional arrays are left empty when the model file omits them.
struct LGBTree {
  int num_leaves{0};  // zero-initialized so a default-constructed tree is never indeterminate
  int num_cat{0};  // number of categorical splits
  std::vector<double> leaf_value;  // output value of each leaf
  std::vector<int8_t> decision_type;  // per-split flags, see GetDecisionType() / GetMissingType()
  std::vector<uint64_t> cat_boundaries;  // num_cat + 1 offsets delimiting slices of cat_threshold
  std::vector<uint32_t> cat_threshold;  // 32-bit bitset words of all categorical splits
  std::vector<int> split_feature;  // feature index tested by each split
  std::vector<double> threshold;  // numerical threshold, or index into cat_boundaries
  std::vector<int> left_child;  // child ID; a negative value v denotes leaf number ~v
  std::vector<int> right_child;  // child ID; a negative value v denotes leaf number ~v
  std::vector<float> split_gain;  // optional: gain of each split
  std::vector<int> internal_count;  // optional: data count of each split node
  std::vector<int> leaf_count;  // optional: data count of each leaf
};
199 
// Tell whether the bits selected by `mask` yield a positive value in `decision_type`
inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
  const int selected_bits = decision_type & mask;
  return selected_bits > 0;
}
203 
204 inline MissingType GetMissingType(int8_t decision_type) {
205  return static_cast<MissingType>((decision_type >> 2) & 3);
206 }
207 
// List the positions of all set bits in a bitset stored as `nslots` consecutive 32-bit words.
// Bit k of word w is reported as position 32 * w + k; positions come out in increasing order.
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
                                          size_t nslots) {
  std::vector<uint32_t> positions;
  for (size_t slot = 0; slot < nslots; ++slot) {
    const uint32_t word = bits[slot];
    for (uint32_t offset = 0; offset < 32; ++offset) {
      if ((word >> offset) & 1) {
        positions.push_back(static_cast<uint32_t>(slot * 32 + offset));
      }
    }
  }
  return positions;
}
221 
// Read every line of `fi` into a vector, without the line terminators.
// A trailing carriage return is removed as well, so that text written with CRLF line endings
// yields the same lines as text written with LF line endings.
inline std::vector<std::string> LoadText(std::istream& fi) {
  std::vector<std::string> lines;
  std::string line;
  while (std::getline(fi, line)) {
    // std::getline() strips only '\n'; a left-over '\r' would end up inside the parsed value
    if (!line.empty() && line.back() == '\r') {
      line.pop_back();
    }
    lines.push_back(line);
  }

  return lines;
}
231 
232 inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
233  std::vector<LGBTree> lgb_trees_;
234  int max_feature_idx_;
235  int num_class_;
236  bool average_output_;
237  std::string obj_name_;
238  std::vector<std::string> obj_param_;
239 
240  /* 1. Parse input stream */
241  std::vector<std::string> lines = LoadText(fi);
242  std::unordered_map<std::string, std::string> global_dict;
243  std::vector<std::unordered_map<std::string, std::string>> tree_dict;
244 
245  bool in_tree = false; // is current entry part of a tree?
246  for (const auto& line : lines) {
247  std::istringstream ss(line);
248  std::string key, value, rest;
249  std::getline(ss, key, '=');
250  std::getline(ss, value, '=');
251  std::getline(ss, rest);
252  if (!rest.empty()) {
253  value += "=";
254  value += rest;
255  }
256  if (key == "Tree") {
257  in_tree = true;
258  tree_dict.emplace_back();
259  } else {
260  if (in_tree) {
261  tree_dict.back()[key] = value;
262  } else {
263  global_dict[key] = value;
264  }
265  }
266  }
267 
268  {
269  auto it = global_dict.find("objective");
270  if (it == global_dict.end()) { // custom objective (fobj)
271  obj_name_ = "custom";
272  } else {
273  auto obj_strs = Split(it->second, ' ');
274  obj_name_ = obj_strs[0];
275  obj_param_ = std::vector<std::string>(obj_strs.begin() + 1, obj_strs.end());
276  }
277 
278  it = global_dict.find("max_feature_idx");
279  TREELITE_CHECK(it != global_dict.end())
280  << "Ill-formed LightGBM model file: need max_feature_idx";
281  max_feature_idx_ = TextToNumber<int>(it->second);
282  it = global_dict.find("num_class");
283  TREELITE_CHECK(it != global_dict.end())
284  << "Ill-formed LightGBM model file: need num_class";
285  num_class_ = TextToNumber<int>(it->second);
286 
287  it = global_dict.find("average_output");
288  average_output_ = (it != global_dict.end());
289  }
290 
291  for (const auto& dict : tree_dict) {
292  lgb_trees_.emplace_back();
293  LGBTree& tree = lgb_trees_.back();
294 
295  auto it = dict.find("num_leaves");
296  TREELITE_CHECK(it != dict.end())
297  << "Ill-formed LightGBM model file: need num_leaves";
298  tree.num_leaves = TextToNumber<int>(it->second);
299 
300  it = dict.find("num_cat");
301  TREELITE_CHECK(it != dict.end()) << "Ill-formed LightGBM model file: need num_cat";
302  tree.num_cat = TextToNumber<int>(it->second);
303 
304  it = dict.find("leaf_value");
305  TREELITE_CHECK(it != dict.end() && !it->second.empty())
306  << "Ill-formed LightGBM model file: need leaf_value";
307  tree.leaf_value
308  = TextToArray<double>(it->second, tree.num_leaves);
309 
310  it = dict.find("decision_type");
311  if (tree.num_leaves <= 1) {
312  tree.decision_type = std::vector<int8_t>();
313  } else {
314  TREELITE_CHECK_GT(tree.num_leaves, 1);
315  if (it == dict.end()) {
316  tree.decision_type = std::vector<int8_t>(tree.num_leaves - 1, 0);
317  } else {
318  TREELITE_CHECK(!it->second.empty())
319  << "Ill-formed LightGBM model file: decision_type cannot be empty string";
320  tree.decision_type = TextToArray<int8_t>(it->second, tree.num_leaves - 1);
321  }
322  }
323 
324  if (tree.num_cat > 0) {
325  it = dict.find("cat_boundaries");
326  TREELITE_CHECK(it != dict.end() && !it->second.empty())
327  << "Ill-formed LightGBM model file: need cat_boundaries";
328  tree.cat_boundaries
329  = TextToArray<uint64_t>(it->second, tree.num_cat + 1);
330  it = dict.find("cat_threshold");
331  TREELITE_CHECK(it != dict.end() && !it->second.empty())
332  << "Ill-formed LightGBM model file: need cat_threshold";
333  tree.cat_threshold
334  = TextToArray<uint32_t>(it->second, static_cast<uint32_t>(tree.cat_boundaries.back()));
335  }
336 
337  it = dict.find("split_feature");
338  if (tree.num_leaves <= 1) {
339  tree.split_feature = std::vector<int>();
340  } else {
341  TREELITE_CHECK_GT(tree.num_leaves, 1);
342  TREELITE_CHECK(it != dict.end() && !it->second.empty())
343  << "Ill-formed LightGBM model file: need split_feature";
344  tree.split_feature = TextToArray<int>(it->second, tree.num_leaves - 1);
345  }
346 
347  it = dict.find("threshold");
348  if (tree.num_leaves <= 1) {
349  tree.threshold = std::vector<double>();
350  } else {
351  TREELITE_CHECK_GT(tree.num_leaves, 1);
352  TREELITE_CHECK(it != dict.end() && !it->second.empty())
353  << "Ill-formed LightGBM model file: need threshold";
354  tree.threshold = TextToArray<double>(it->second, tree.num_leaves - 1);
355  }
356 
357  it = dict.find("split_gain");
358  if (tree.num_leaves <= 1) {
359  tree.split_gain = std::vector<float>();
360  } else {
361  TREELITE_CHECK_GT(tree.num_leaves, 1);
362  if (it != dict.end()) {
363  TREELITE_CHECK(!it->second.empty())
364  << "Ill-formed LightGBM model file: split_gain cannot be empty string";
365  tree.split_gain = TextToArray<float>(it->second, tree.num_leaves - 1);
366  } else {
367  tree.split_gain = std::vector<float>();
368  }
369  }
370 
371  it = dict.find("internal_count");
372  if (tree.num_leaves <= 1) {
373  tree.internal_count = std::vector<int>();
374  } else {
375  TREELITE_CHECK_GT(tree.num_leaves, 1);
376  if (it != dict.end()) {
377  TREELITE_CHECK(!it->second.empty())
378  << "Ill-formed LightGBM model file: internal_count cannot be empty string";
379  tree.internal_count = TextToArray<int>(it->second, tree.num_leaves - 1);
380  } else {
381  tree.internal_count = std::vector<int>();
382  }
383  }
384 
385  it = dict.find("leaf_count");
386  if (tree.num_leaves == 0) {
387  tree.leaf_count = std::vector<int>();
388  } else {
389  TREELITE_CHECK_GT(tree.num_leaves, 0);
390  if (it != dict.end() && !it->second.empty()) {
391  tree.leaf_count = TextToArray<int>(it->second, tree.num_leaves);
392  } else {
393  tree.leaf_count = std::vector<int>();
394  }
395  }
396 
397  it = dict.find("left_child");
398  if (tree.num_leaves <= 1) {
399  tree.left_child = std::vector<int>();
400  } else {
401  TREELITE_CHECK_GT(tree.num_leaves, 1);
402  TREELITE_CHECK(it != dict.end() && !it->second.empty())
403  << "Ill-formed LightGBM model file: need left_child";
404  tree.left_child = TextToArray<int>(it->second, tree.num_leaves - 1);
405  }
406 
407  it = dict.find("right_child");
408  if (tree.num_leaves <= 1) {
409  tree.right_child = std::vector<int>();
410  } else {
411  TREELITE_CHECK_GT(tree.num_leaves, 1);
412  TREELITE_CHECK(it != dict.end() && !it->second.empty())
413  << "Ill-formed LightGBM model file: need right_child";
414  tree.right_child = TextToArray<int>(it->second, tree.num_leaves - 1);
415  }
416  }
417 
418  /* 2. Export model */
419  std::unique_ptr<treelite::Model> model_ptr = treelite::Model::Create<double, double>();
420  auto* model = dynamic_cast<treelite::ModelImpl<double, double>*>(model_ptr.get());
421  model->num_feature = max_feature_idx_ + 1;
422  model->average_tree_output = average_output_;
423  if (num_class_ > 1) {
424  // multi-class classifier
425  model->task_type = treelite::TaskType::kMultiClfGrovePerClass;
426  model->task_param.grove_per_class = true;
427  } else {
428  // binary classifier or regressor
429  model->task_type = treelite::TaskType::kBinaryClfRegr;
430  model->task_param.grove_per_class = false;
431  }
432  model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
433  model->task_param.num_class = num_class_;
434  model->task_param.leaf_vector_size = 1;
435 
436  // set correct prediction transform function, depending on objective function
437  if (obj_name_ == "multiclass") {
438  // validate num_class parameter
439  int num_class = -1;
440  int tmp;
441  for (const auto& str : obj_param_) {
442  auto tokens = Split(str, ':');
443  if (tokens.size() == 2 && tokens[0] == "num_class"
444  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
445  num_class = tmp;
446  break;
447  }
448  }
449  TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class)
450  << "Ill-formed LightGBM model file: not a valid multiclass objective";
451 
452  std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform));
453  } else if (obj_name_ == "multiclassova") {
454  // validate num_class and alpha parameters
455  int num_class = -1;
456  float alpha = -1.0f;
457  int tmp;
458  float tmp2;
459  for (const auto& str : obj_param_) {
460  auto tokens = Split(str, ':');
461  if (tokens.size() == 2) {
462  if (tokens[0] == "num_class"
463  && (tmp = TextToNumber<int>(tokens[1])) >= 0) {
464  num_class = tmp;
465  } else if (tokens[0] == "sigmoid"
466  && (tmp2 = TextToNumber<float>(tokens[1])) > 0.0f) {
467  alpha = tmp2;
468  }
469  }
470  }
471  TREELITE_CHECK(num_class >= 0 && static_cast<size_t>(num_class) == model->task_param.num_class
472  && alpha > 0.0f)
473  << "Ill-formed LightGBM model file: not a valid multiclassova objective";
474 
475  std::strncpy(model->param.pred_transform, "multiclass_ova",
476  sizeof(model->param.pred_transform));
477  model->param.sigmoid_alpha = alpha;
478  } else if (obj_name_ == "binary") {
479  // validate alpha parameter
480  float alpha = -1.0f;
481  float tmp;
482  for (const auto& str : obj_param_) {
483  auto tokens = Split(str, ':');
484  if (tokens.size() == 2 && tokens[0] == "sigmoid"
485  && (tmp = TextToNumber<float>(tokens[1])) > 0.0f) {
486  alpha = tmp;
487  break;
488  }
489  }
490  TREELITE_CHECK_GT(alpha, 0.0f)
491  << "Ill-formed LightGBM model file: not a valid binary objective";
492 
493  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
494  model->param.sigmoid_alpha = alpha;
495  } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") {
496  std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform));
497  model->param.sigmoid_alpha = 1.0f;
498  } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
499  std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp",
500  sizeof(model->param.pred_transform));
501  } else if (obj_name_ == "poisson" || obj_name_ == "gamma" || obj_name_ == "tweedie") {
502  std::strncpy(model->param.pred_transform, "exponential",
503  sizeof(model->param.pred_transform));
504  } else if (obj_name_ == "regression" || obj_name_ == "regression_l1" || obj_name_ == "huber"
505  || obj_name_ == "fair" || obj_name_ == "quantile" || obj_name_ == "mape") {
506  // Regression family
507  bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(), "sqrt") != obj_param_.cend());
508  if (sqrt) {
509  std::strncpy(model->param.pred_transform, "signed_square",
510  sizeof(model->param.pred_transform));
511  } else {
512  std::strncpy(model->param.pred_transform, "identity",
513  sizeof(model->param.pred_transform));
514  }
515  } else if (obj_name_ == "lambdarank" || obj_name_ == "rank_xendcg" || obj_name_ == "custom") {
516  // Ranking family, or a custom user-defined objective
517  std::strncpy(model->param.pred_transform, "identity",
518  sizeof(model->param.pred_transform));
519  } else {
520  TREELITE_LOG(FATAL) << "Unrecognized objective: " << obj_name_;
521  }
522 
523  // traverse trees
524  for (const auto& lgb_tree : lgb_trees_) {
525  model->trees.emplace_back();
526  treelite::Tree<double, double>& tree = model->trees.back();
527  tree.Init();
528 
529  // assign node ID's so that a breadth-wise traversal would yield
530  // the monotonic sequence 0, 1, 2, ...
531  std::queue<std::pair<int, int>> Q; // (old ID, new ID) pair
532  if (lgb_tree.num_leaves == 0) {
533  continue;
534  } else if (lgb_tree.num_leaves == 1) {
535  // A constant-value tree with a single root node that's also a leaf
536  Q.push({-1, 0});
537  } else {
538  Q.push({0, 0});
539  }
540  while (!Q.empty()) {
541  int old_id, new_id;
542  std::tie(old_id, new_id) = Q.front(); Q.pop();
543  if (old_id < 0) { // leaf
544  const double leaf_value = lgb_tree.leaf_value[~old_id];
545  tree.SetLeaf(new_id, static_cast<double>(leaf_value));
546  if (!lgb_tree.leaf_count.empty()) {
547  const int data_count = lgb_tree.leaf_count[~old_id];
548  TREELITE_CHECK_GE(data_count, 0);
549  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
550  }
551  } else { // non-leaf
552  const auto split_index = static_cast<unsigned>(lgb_tree.split_feature[old_id]);
553  const auto missing_type = GetMissingType(lgb_tree.decision_type[old_id]);
554 
555  tree.AddChilds(new_id);
556  if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
557  // categorical split
558  const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
559  const std::vector<uint32_t> left_categories
560  = BitsetToList(lgb_tree.cat_threshold.data()
561  + lgb_tree.cat_boundaries[cat_idx],
562  lgb_tree.cat_boundaries[cat_idx + 1]
563  - lgb_tree.cat_boundaries[cat_idx]);
564  // For categorical splits, we ignore the missing type field. NaNs always get mapped to
565  // the right child node.
566  bool default_left = false;
567  tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false);
568  } else {
569  // numerical split
570  const auto threshold = static_cast<double>(lgb_tree.threshold[old_id]);
571  bool default_left
572  = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask);
573  const treelite::Operator cmp_op = treelite::Operator::kLE;
574  const bool missing_value_to_zero = (missing_type != MissingType::kNaN);
575  if (missing_value_to_zero) {
576  // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so
577  // we need to override the default_left flag
578  default_left = 0.0 <= threshold;
579  }
580  tree.SetNumericalSplit(new_id, split_index, threshold, default_left, cmp_op);
581  }
582  if (!lgb_tree.internal_count.empty()) {
583  const int data_count = lgb_tree.internal_count[old_id];
584  TREELITE_CHECK_GE(data_count, 0);
585  tree.SetDataCount(new_id, static_cast<size_t>(data_count));
586  }
587  if (!lgb_tree.split_gain.empty()) {
588  tree.SetGain(new_id, static_cast<double>(lgb_tree.split_gain[old_id]));
589  }
590  Q.push({lgb_tree.left_child[old_id], tree.LeftChild(new_id)});
591  Q.push({lgb_tree.right_child[old_id], tree.RightChild(new_id)});
592  }
593  }
594  }
595  return model_ptr;
596 }
597 
598 } // anonymous namespace
Collection of front-end methods to load or construct ensemble model.
void Init()
initialize the model with a single root node
Definition: tree_impl.h:649
model structure for tree ensemble
std::unique_ptr< treelite::Model > LoadLightGBMModel(const char *filename)
load a model file generated by LightGBM (Microsoft/LightGBM). The model file must contain a decision ...
Definition: lightgbm.cc:28
in-memory representation of a decision tree
Definition: tree.h:214
logging facility for Treelite
std::unique_ptr< treelite::Model > LoadLightGBMModelFromString(const char *model_str)
Load a LightGBM model from a string. The string should be created with the model_to_string() method i...
Definition: lightgbm.cc:33
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:589
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:579
int LeftChild(int nid) const
Getters.
Definition: tree.h:352
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:359
void AddChilds(int nid)
add child nodes to node
Definition: tree_impl.h:663
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, const std::vector< uint32_t > &categories_list, bool categories_list_right_child)
create a categorical split
Definition: tree_impl.h:717
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:714
void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp)
Setters.
Definition: tree_impl.h:701
Operator
comparison operators
Definition: base.h:26
void SetLeaf(int nid, LeafOutputType value)
set the leaf value of the node
Definition: tree_impl.h:753