Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <algorithm>
13 #include <map>
14 #include <memory>
15 #include <ostream>
16 #include <sstream>
17 #include <string>
18 #include <vector>
19 #include <utility>
20 #include <type_traits>
21 #include <limits>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <cstdio>
26 
// Two-level stringification: _TREELITE_STR(x) macro-expands its argument before
// __TREELITE_STR applies the # operator, so _TREELITE_STR(SOME_MACRO) yields the expanded
// value of SOME_MACRO as a string literal (a one-level #x would yield "SOME_MACRO").
// NOTE(review): identifiers that begin with "__" or with "_" followed by an uppercase letter are
// reserved for the implementation in C++; consider renaming once no external code depends on them.
27 #define __TREELITE_STR(x) #x
28 #define _TREELITE_STR(x) __TREELITE_STR(x)
29 
// Size, in chars, of the fixed ModelParam::pred_transform buffer.
30 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256
31 
32 namespace treelite {
33 
// Forward declaration: Tree names GTILBridge as a friend class.
34 class GTILBridge;
35 
// Forward declaration: Tree's friend function template DumpModelAsJSON takes a
// const ModelImpl<X, Y>&, so the class template has to be declared before Tree.
36 template <typename ThresholdType, typename LeafOutputType>
37 class ModelImpl;
38 
// Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
// to hold only 1-D arrays with stride 1.
//
//   buf      - address of the first element
//   format   - element format string (presumably PEP 3118 struct syntax, e.g. "f" -- confirm)
//   itemsize - size of one element, in bytes
//   nitem    - number of elements
struct PyBufferFrame {
  void* buf;
  char* format;
  std::size_t itemsize;
  std::size_t nitem;
};

// "POD" is by definition trivial + standard layout. Spell that out instead of using
// std::is_pod, which is deprecated as of C++20 and triggers warnings under newer standards.
static_assert(std::is_trivial<PyBufferFrame>::value
              && std::is_standard_layout<PyBufferFrame>::value,
              "PyBufferFrame must be a POD type");
49 
// ContiguousArray<T>: growable, contiguous array of POD elements (static_assert below) kept in a
// raw buffer. owned_buffer_ records whether the buffer is owned; UseForeignBuffer() presumably
// makes the array wrap a caller-provided buffer instead -- confirm against the implementation.
// NOTE(review): listing lines 51 and 53 are absent from this extract (presumably the
// `class ContiguousArray {` head and the default constructor) -- verify against tree.h.
50 template <typename T>
52  public:
54  ~ContiguousArray();
55  // NOTE: use Clone to make deep copy; copy constructors disabled
56  ContiguousArray(const ContiguousArray&) = delete;
57  ContiguousArray& operator=(const ContiguousArray&) = delete;
58  ContiguousArray(ContiguousArray&& other) noexcept;
59  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
60  inline ContiguousArray Clone() const;
61  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
 // Raw access: Data() is the first element, End() is one past the last, Back() the last element.
62  inline T* Data();
63  inline const T* Data() const;
64  inline T* End();
65  inline const T* End() const;
66  inline T& Back();
67  inline const T& Back() const;
68  inline std::size_t Size() const;
69  inline bool Empty() const;
 // Capacity / size management and appends.
70  inline void Reserve(std::size_t newsize);
71  inline void Resize(std::size_t newsize);
72  inline void Resize(std::size_t newsize, T t);
73  inline void Clear();
74  inline void PushBack(T t);
75  inline void Extend(const std::vector<T>& other);
76  /* Unsafe access, no bounds checking */
77  inline T& operator[](std::size_t idx);
78  inline const T& operator[](std::size_t idx) const;
79  /* Safe access, with bounds checking */
80  inline T& at(std::size_t idx);
81  inline const T& at(std::size_t idx) const;
82  /* Safe access, with bounds checking + check against non-existent node (<0) */
83  inline T& at(int idx);
84  inline const T& at(int idx) const;
85  static_assert(std::is_pod<T>::value, "T must be POD");
86 
87  private:
 // buffer_ holds capacity_ slots of which the first size_ are in use.
88  T* buffer_;
89  std::size_t size_;
90  std::size_t capacity_;
91  bool owned_buffer_;
92 };
93 
// Enum type representing the task type. One byte wide; the enumerator values are fixed
// (0-3) and must not be renumbered.
enum class TaskType : uint8_t {
  kBinaryClfRegr = 0,
  kMultiClfGrovePerClass = 1,
  kMultiClfProbDistLeaf = 2,
  kMultiClfCategLeaf = 3
};

// Human-readable name of a TaskType: the enumerator name without its leading 'k'.
// A value that matches no named enumerator yields an empty string.
inline std::string TaskTypeToString(TaskType type) {
  const char* label = "";
  switch (type) {
    case TaskType::kBinaryClfRegr:
      label = "BinaryClfRegr";
      break;
    case TaskType::kMultiClfGrovePerClass:
      label = "MultiClfGrovePerClass";
      break;
    case TaskType::kMultiClfProbDistLeaf:
      label = "MultiClfProbDistLeaf";
      break;
    case TaskType::kMultiClfCategLeaf:
      label = "MultiClfCategLeaf";
      break;
    default:
      break;
  }
  return std::string(label);
}
172 
// Group of parameters that are dependent on the choice of the task type.
// Required to be a POD type (a static_assert follows OutputTypeToString).
174 struct TaskParam {
 // Leaf output type tag, stored in one byte: kFloat = 0, kInt = 1.
175  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
 // The type of output from each leaf node.
177  OutputType output_type;
 // NOTE(review): listing line 185 is absent from this extract. Per the tooltip it declares
 // `bool grove_per_class` (whether a subset of the trees is designated to compute the
 // prediction for each class).
 // The number of classes in the target label.
193  unsigned int num_class;
 // Dimension of the output from each leaf node.
200  unsigned int leaf_vector_size;
201 };
202 
203 inline std::string OutputTypeToString(TaskParam::OutputType type) {
204  switch (type) {
205  case TaskParam::OutputType::kFloat: return "float";
206  case TaskParam::OutputType::kInt: return "int";
207  default: return "";
208  }
209 }
210 
211 static_assert(std::is_pod<TaskParam>::value, "TaskParameter must be POD type");
212 
// In-memory representation of a decision tree. Nodes are kept in the flat array nodes_ and are
// addressed by integer node id (nid). The static_asserts below restrict ThresholdType to
// float/double and LeafOutputType to uint32_t or the same type as ThresholdType.
214 template <typename ThresholdType, typename LeafOutputType>
215 class Tree {
216  public:
 // Tree node; must stay POD and of a fixed size (checked by the static_asserts after it).
218  struct Node {
221  inline void Init();
 // Shared storage: a leaf node uses leaf_value, an internal node uses threshold.
223  union Info {
224  LeafOutputType leaf_value; // for leaf nodes
225  ThresholdType threshold; // for non-leaf nodes
226  };
 // Child node ids; IsLeaf() treats cleft_ == -1 as "this node has no children".
228  std::int32_t cleft_, cright_;
 // Packed field: bit 31 is the default_left flag, bits 0-30 are the split feature index
 // (decoded by DefaultLeft(), DefaultChild() and SplitIndex()).
233  std::uint32_t sindex_;
 // Optional statistics. The Has*() accessors report whether each was set; the Tree::Set*
 // methods for these fields also set the corresponding *_present_ flag.
240  std::uint64_t data_count_;
247  double sum_hess_;
251  double gain_;
 // NOTE(review): this listing omits Node members that the tooltips document and the accessors
 // below read: info_ (line 235), split_type_ (253), cmp_ (259), data_count_present_ (261),
 // sum_hess_present_ (263) and gain_present_ (265).
266  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
267  * node or the left child node. True if the right child, False otherwise */
268  bool categories_list_right_child_;
269 
 // Getters.
271  inline int LeftChild() const {
272  return cleft_;
273  }
274  inline int RightChild() const {
275  return cright_;
276  }
277  inline bool DefaultLeft() const {
278  // Extract the most significant bit (MSB) of sindex_, which encodes the default_left field
279  return (sindex_ >> 31U) != 0;
280  }
281  inline int DefaultChild() const {
282  // Extract the most significant bit (MSB) of sindex_, which encodes the default_left field
283  return ((sindex_ >> 31U) != 0) ? cleft_ : cright_;
284  }
285  inline std::uint32_t SplitIndex() const {
286  // Extract all bits except the most significant bit (MSB) from sindex_.
287  return (sindex_ & ((1U << 31U) - 1U));
288  }
289  inline bool IsLeaf() const {
290  return cleft_ == -1;
291  }
 // LeafValue()/Threshold() read the union member matching the node kind; calling the other
 // one reads the inactive member.
292  inline LeafOutputType LeafValue() const {
293  return info_.leaf_value;
294  }
295  inline ThresholdType Threshold() const {
296  return info_.threshold;
297  }
298  inline Operator ComparisonOp() const {
299  return cmp_;
300  }
301  inline SplitFeatureType SplitType() const {
302  return split_type_;
303  }
304  inline bool HasDataCount() const {
305  return data_count_present_;
306  }
307  inline std::uint64_t DataCount() const {
308  return data_count_;
309  }
310  inline bool HasSumHess() const {
311  return sum_hess_present_;
312  }
313  inline double SumHess() const {
314  return sum_hess_;
315  }
316  inline bool HasGain() const {
317  return gain_present_;
318  }
319  inline double Gain() const {
320  return gain_;
321  }
322  inline bool CategoriesListRightChild() const {
323  return categories_list_right_child_;
324  }
325  };
326 
 // Node layout is pinned to 48 bytes (float thresholds) or 56 bytes (double thresholds),
 // presumably because nodes_ is exported as raw bytes via GetFormatStringForNode() /
 // GetPyBuffer() -- TODO confirm.
327  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
328  static_assert(std::is_same<ThresholdType, float>::value
329  || std::is_same<ThresholdType, double>::value,
330  "ThresholdType must be either float32 or float64");
331  static_assert(std::is_same<LeafOutputType, uint32_t>::value
332  || std::is_same<LeafOutputType, float>::value
333  || std::is_same<LeafOutputType, double>::value,
334  "LeafOutputType must be one of uint32_t, float32 or float64");
335  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
336  || std::is_same<LeafOutputType, uint32_t>::value,
337  "Unsupported combination of ThresholdType and LeafOutputType");
338  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
339  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
340  "Node size incorrect");
341 
 // Copy is disabled (use Clone() instead); move is defaulted.
342  Tree() = default;
343  ~Tree() = default;
344  Tree(const Tree&) = delete;
345  Tree& operator=(const Tree&) = delete;
346  Tree(Tree&&) noexcept = default;
347  Tree& operator=(Tree&&) noexcept = default;
348 
349  inline Tree<ThresholdType, LeafOutputType> Clone() const;
350 
 // Serialization entry points (definitions are not part of this listing).
351  inline const char* GetFormatStringForNode();
352  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
353  inline void SerializeToFile(FILE* dest_fp);
354  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
355  std::vector<PyBufferFrame>::iterator end);
356  inline void DeserializeFromFile(FILE* src_fp);
357 
358  private:
359  // vector of nodes
360  ContiguousArray<Node> nodes_;
361  ContiguousArray<LeafOutputType> leaf_vector_;
362  // Map nid to the start and end index in leaf_vector_
363  // We could use std::pair, but it is not POD, so easier to use two vectors
364  // here
365  ContiguousArray<std::size_t> leaf_vector_begin_;
366  ContiguousArray<std::size_t> leaf_vector_end_;
 // Flattened category lists; node nid's list spans
 // [matching_categories_offset_[nid], matching_categories_offset_[nid + 1]).
367  ContiguousArray<std::uint32_t> matching_categories_;
368  ContiguousArray<std::size_t> matching_categories_offset_;
369  bool has_categorical_split_{false};
370 
371  template <typename WriterType, typename X, typename Y>
372  friend void DumpModelAsJSON(WriterType& writer, const ModelImpl<X, Y>& model);
373  template <typename WriterType, typename X, typename Y>
374  friend void DumpTreeAsJSON(WriterType& writer, const Tree<X, Y>& tree);
375 
376  // allocate a new node
377  inline int AllocNode();
378 
379  // utility functions used for serialization, internal use only
380  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
381  inline void
382  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
383  CompositeArrayHandler composite_array_handler);
384  template <typename ScalarHandler, typename ArrayHandler>
385  inline void
386  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
387 
388  friend class GTILBridge; // bridge to enable optimized access to nodes from GTIL
389 
390  public:
 // NOTE(review): listing line 392 (`int num_nodes`, per the tooltip) is absent from this extract.
394  inline void Init();
399  inline void AddChilds(int nid);
400 
 // Per-node getters. nid is NOT bounds-checked here: nodes_[nid] uses ContiguousArray's
 // unchecked operator[].
406  inline int LeftChild(int nid) const {
407  return nodes_[nid].LeftChild();
408  }
413  inline int RightChild(int nid) const {
414  return nodes_[nid].RightChild();
415  }
420  inline int DefaultChild(int nid) const {
421  return nodes_[nid].DefaultChild();
422  }
427  inline std::uint32_t SplitIndex(int nid) const {
428  return nodes_[nid].SplitIndex();
429  }
434  inline bool DefaultLeft(int nid) const {
435  return nodes_[nid].DefaultLeft();
436  }
441  inline bool IsLeaf(int nid) const {
442  return nodes_[nid].IsLeaf();
443  }
448  inline LeafOutputType LeafValue(int nid) const {
449  return nodes_[nid].LeafValue();
450  }
 // Copy of node nid's leaf vector. Returns an empty vector when the stored [begin, end)
 // offsets fall outside leaf_vector_.
455  inline std::vector<LeafOutputType> LeafVector(int nid) const {
456  const std::size_t offset_begin = leaf_vector_begin_[nid];
457  const std::size_t offset_end = leaf_vector_end_[nid];
458  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
459  // Return empty vector, to indicate the lack of leaf vector
460  return std::vector<LeafOutputType>();
461  }
462  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
463  &leaf_vector_[offset_end]);
464  // Use unsafe access here, since we may need to take the address of one past the last
465  // element, to follow with the range semantic of std::vector<>.
466  }
 // True iff node nid's stored [begin, end) range is non-empty. The range is not checked
 // against leaf_vector_.Size().
471  inline bool HasLeafVector(int nid) const {
472  return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
473  }
478  inline ThresholdType Threshold(int nid) const {
479  return nodes_[nid].Threshold();
480  }
485  inline Operator ComparisonOp(int nid) const {
486  return nodes_[nid].ComparisonOp();
487  }
 // Copy of the category list stored for node nid; empty when its offsets fall outside
 // matching_categories_. Reads matching_categories_offset_[nid + 1], so the offset array
 // presumably holds one more entry than there are nodes -- TODO confirm.
496  inline std::vector<std::uint32_t> MatchingCategories(int nid) const {
497  const std::size_t offset_begin = matching_categories_offset_[nid];
498  const std::size_t offset_end = matching_categories_offset_[nid + 1];
499  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
500  // Return empty vector, to indicate the lack of any matching categories
501  // The node might be a numerical split
502  return std::vector<std::uint32_t>();
503  }
504  return std::vector<std::uint32_t>(&matching_categories_[offset_begin],
505  &matching_categories_[offset_end]);
506  // Use unsafe access here, since we may need to take the address of one past the last
507  // element, to follow with the range semantic of std::vector<>.
508  }
513  inline SplitFeatureType SplitType(int nid) const {
514  return nodes_[nid].SplitType();
515  }
520  inline bool HasDataCount(int nid) const {
521  return nodes_[nid].HasDataCount();
522  }
527  inline std::uint64_t DataCount(int nid) const {
528  return nodes_[nid].DataCount();
529  }
530 
535  inline bool HasSumHess(int nid) const {
536  return nodes_[nid].HasSumHess();
537  }
542  inline double SumHess(int nid) const {
543  return nodes_[nid].SumHess();
544  }
549  inline bool HasGain(int nid) const {
550  return nodes_[nid].HasGain();
551  }
556  inline double Gain(int nid) const {
557  return nodes_[nid].Gain();
558  }
564  inline bool CategoriesListRightChild(int nid) const {
565  return nodes_[nid].CategoriesListRightChild();
566  }
567 
 // Whether this tree contains any categorical splits.
571  inline bool HasCategoricalSplit() const {
572  return has_categorical_split_;
573  }
574 
 // Setters (definitions are not part of this listing).
585  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
586  bool default_left, Operator cmp);
599  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
600  const std::vector<uint32_t>& categories_list,
601  bool categories_list_right_child);
607  inline void SetLeaf(int nid, LeafOutputType value);
613  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
 // The statistics setters below use nodes_.at(nid): bounds-checked and rejecting a negative
 // nid. Each one also marks its statistic as present.
619  inline void SetSumHess(int nid, double sum_hess) {
620  Node& node = nodes_.at(nid);
621  node.sum_hess_ = sum_hess;
622  node.sum_hess_present_ = true;
623  }
629  inline void SetDataCount(int nid, uint64_t data_count) {
630  Node& node = nodes_.at(nid);
631  node.data_count_ = data_count;
632  node.data_count_present_ = true;
633  }
639  inline void SetGain(int nid, double gain) {
640  Node& node = nodes_.at(nid);
641  node.gain_ = gain;
642  node.gain_present_ = true;
643  }
644 };
645 
// Model-level ("extra") parameters. Must remain standard-layout (static_assert below).
646 struct ModelParam {
 // Fixed-size, zero-initialized char buffer; the constructor stores "identity" in it.
668  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
 // NOTE(review): listing line 676 is absent from this extract. Per the tooltip it declares
 // `float sigmoid_alpha`, the scaling parameter for sigmoid(x) = 1 / (1 + exp(-alpha * x)).
 // The constructor below initializes it.
 // Scaling parameter for expstdratio(x) = exp2(-x / c).
684  float ratio_c;
 // Global bias of the model.
691  float global_bias;
 // Defaults: sigmoid_alpha = 1, ratio_c = 1, global_bias = 0, pred_transform = "identity".
 // NOTE(review): the memset repeats the `= {0}` member initializer and is redundant but
 // harmless. strncpy with sizeof(pred_transform) zero-fills the rest of the buffer, so the
 // short literal stays NUL-terminated.
694  ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
695  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
696  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
697  }
698  ~ModelParam() = default;
699  ModelParam(const ModelParam&) = default;
700  ModelParam& operator=(const ModelParam&) = default;
701  ModelParam(ModelParam&&) = default;
702  ModelParam& operator=(ModelParam&&) = default;
703 
 // Initialize from key/value pairs. Presumably returns the pairs that were not recognized
 // (name suggests "allow unknown") -- confirm against the definition.
704  template<typename Container>
705  inline std::vector<std::pair<std::string, std::string>>
706  InitAllowUnknown(const Container &kwargs);
 // Presumably exports the parameters as a name -> value string map -- confirm.
 // NOTE(review): identifiers containing "__" are reserved for the implementation in C++.
707  inline std::map<std::string, std::string> __DICT__() const;
708 };
709 
710 static_assert(std::is_standard_layout<ModelParam>::value,
711  "ModelParam must be in the standard layout");
712 
// Declared here; the definition is not in this listing (this header includes "tree_impl.h" at
// the end). Presumably applies cfg to *param and validates it -- confirm.
713 inline void InitParamAndCheck(ModelParam* param,
714  const std::vector<std::pair<std::string, std::string>>& cfg);
715 
// Thin, abstract wrapper for a tree ensemble model. The concrete tree storage is in
// ModelImpl<ThresholdType, LeafOutputType>, which derives from this class; the base records
// the Treelite version and the threshold / leaf-output TypeInfo tags.
717 class Model {
718  public:
 // Initializes the version fields from the TREELITE_VER_* macros.
 // NOTE(review): threshold_type_ and leaf_output_type_ are not initialized here; presumably
 // Create() assigns them -- confirm.
720  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
721  patch_ver_(TREELITE_VER_PATCH) {}
722  virtual ~Model() = default;
723  Model(const Model&) = delete;
724  Model& operator=(const Model&) = delete;
725  Model(Model&&) = default;
726  Model& operator=(Model&&) = default;
727 
 // Factories: statically typed, or selected at runtime from two TypeInfo tags.
728  template <typename ThresholdType, typename LeafOutputType>
729  inline static std::unique_ptr<Model> Create();
730  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
731  inline TypeInfo GetThresholdType() const {
732  return threshold_type_;
733  }
734  inline TypeInfo GetLeafOutputType() const {
735  return leaf_output_type_;
736  }
 // Presumably invokes func with the concrete ModelImpl<...> matching the stored type tags --
 // confirm against the definition.
737  template <typename Func>
738  inline auto Dispatch(Func func);
739  template <typename Func>
740  inline auto Dispatch(Func func) const;
741 
742  virtual std::size_t GetNumTree() const = 0;
743  virtual void SetTreeLimit(std::size_t limit) = 0;
744  virtual void DumpAsJSON(std::ostream& fo, bool pretty_print) const = 0;
745 
 // Convenience overload: renders the stream overload into a std::string.
746  inline std::string DumpAsJSON(bool pretty_print) const {
747  std::ostringstream oss;
748  DumpAsJSON(oss, pretty_print);
749  return oss.str();
750  }
751 
752  /* In-memory serialization, zero-copy */
753  std::vector<PyBufferFrame> GetPyBuffer();
754  static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
755 
756  /* Serialization to a file stream */
757  void SerializeToFile(FILE* dest_fp);
758  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
759 
 // NOTE(review): listing lines 760-772 are absent from this extract. Per the tooltips they
 // declare num_feature (764), task_type (766), average_tree_output (768), task_param (770)
 // and param (772).
773 
774  private:
775  int major_ver_, minor_ver_, patch_ver_;
776  TypeInfo threshold_type_;
777  TypeInfo leaf_output_type_;
778  // Internal functions for serialization
779  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
780  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
781  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
782  std::vector<PyBufferFrame>::iterator end) = 0;
783  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
784  template <typename HeaderPrimitiveFieldHandlerFunc>
785  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
786  template <typename HeaderPrimitiveFieldHandlerFunc>
787  inline static void DeserializeTemplate(
788  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
789  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
790 };
791 
792 template <typename ThresholdType, typename LeafOutputType>
793 class ModelImpl : public Model {
794  public:
796  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
797 
799  ModelImpl() = default;
800  ~ModelImpl() override = default;
801  ModelImpl(const ModelImpl&) = delete;
802  ModelImpl& operator=(const ModelImpl&) = delete;
803  ModelImpl(ModelImpl&&) noexcept = default;
804  ModelImpl& operator=(ModelImpl&&) noexcept = default;
805 
806  void DumpAsJSON(std::ostream& fo, bool pretty_print) const override;
807  inline std::size_t GetNumTree() const override {
808  return trees.size();
809  }
810  void SetTreeLimit(std::size_t limit) override {
811  return trees.resize(limit);
812  }
813 
814  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
815  inline void SerializeToFileImpl(FILE* dest_fp) override;
816  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
817  std::vector<PyBufferFrame>::iterator end) override;
818  inline void DeserializeFromFileImpl(FILE* src_fp) override;
819 
820  private:
821  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
822  typename TreeHandlerFunc>
823  inline void SerializeTemplate(
824  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
825  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
826  TreeHandlerFunc tree_handler);
827  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
828  inline void DeserializeTemplate(
829  size_t num_tree,
830  HeaderFieldHandlerFunc header_field_handler,
831  TreeHandlerFunc tree_handler);
832 };
833 
834 } // namespace treelite
835 
836 #include "tree_impl.h"
837 
838 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:772
SplitFeatureType split_type_
feature split type
Definition: tree.h:253
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:485
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:265
SplitFeatureType
feature split type
Definition: base.h:22
std::int32_t cleft_
pointer to left and right children
Definition: tree.h:228
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:520
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:549
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:185
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:174
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:259
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:455
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition: tree.h:571
TaskType
Enum type representing the task type.
Definition: tree.h:100
bool average_tree_output
whether to average tree outputs
Definition: tree.h:768
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:676
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
Definition: tree.h:420
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:261
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:200
tree node
Definition: tree.h:218
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:619
in-memory representation of a decision tree
Definition: tree.h:215
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally positively correlated with the data count. XGBoost models natively store these statistics.
Definition: tree.h:247
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:193
float global_bias
global bias of the model
Definition: tree.h:691
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:251
TaskType task_type
Task type.
Definition: tree.h:766
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
store either leaf value or decision threshold
Definition: tree.h:223
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:796
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:427
double SumHess(int nid) const
get hessian sum
Definition: tree.h:542
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:639
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:629
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:564
int num_nodes
number of nodes
Definition: tree.h:392
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:513
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:448
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:684
Model()
disable copy; use default move
Definition: tree.h:720
std::uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:240
int LeftChild(int nid) const
Getters.
Definition: tree.h:406
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:770
defines configuration macros of Treelite
std::uint64_t DataCount(int nid) const
get data count
Definition: tree.h:527
double Gain(int nid) const
get gain value
Definition: tree.h:556
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:413
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:177
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:263
thin wrapper for tree ensemble model
Definition: tree.h:717
std::uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
Definition: tree.h:233
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:434
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:535
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:441
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:764
int LeftChild() const
Getters.
Definition: tree.h:271
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:478
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:471
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:496
Info info_
storage for leaf value or decision threshold
Definition: tree.h:235
Operator
comparison operators
Definition: base.h:26