Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <algorithm>
13 #include <map>
14 #include <memory>
15 #include <ostream>
16 #include <sstream>
17 #include <string>
18 #include <vector>
19 #include <utility>
20 #include <type_traits>
21 #include <limits>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <cstdio>
26 
27 #define __TREELITE_STR(x) #x
28 #define _TREELITE_STR(x) __TREELITE_STR(x)
29 
30 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256
31 
32 /* Indicator that certain functions should be visible from a library (Windows only) */
33 #if defined(_MSC_VER) || defined(_WIN32)
34 #define TREELITE_DLL_EXPORT __declspec(dllexport)
35 #else
36 #define TREELITE_DLL_EXPORT
37 #endif
38 
39 namespace treelite {
40 
41 class GTILBridge;
42 
43 template <typename ThresholdType, typename LeafOutputType>
44 class ModelImpl;
45 
/*!
 * \brief One frame in the Python buffer protocol (PEP 3118).
 *
 * This is a simplified representation that only holds 1-D arrays with stride 1.
 * The struct is a plain aggregate; nothing in it manages the lifetime of the memory
 * referred to by buf or format.
 */
struct PyBufferFrame {
  void* buf;             // start of the array data
  char* format;          // element format string (presumably PEP 3118 syntax -- confirm)
  std::size_t itemsize;  // size of a single element (presumably in bytes, as in PEP 3118)
  std::size_t nitem;     // number of elements in the array
};

// std::is_pod is deprecated since C++20. A POD type is by definition both trivial and
// standard-layout, so checking the two traits separately is equivalent.
static_assert(std::is_trivial<PyBufferFrame>::value
              && std::is_standard_layout<PyBufferFrame>::value,
              "PyBufferFrame must be a POD type");
56 
57 template <typename T>
59  public:
61  ~ContiguousArray();
62  // NOTE: use Clone to make deep copy; copy constructors disabled
63  ContiguousArray(const ContiguousArray&) = delete;
64  ContiguousArray& operator=(const ContiguousArray&) = delete;
65  ContiguousArray(ContiguousArray&& other) noexcept;
66  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
67  inline ContiguousArray Clone() const;
68  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
69  inline T* Data();
70  inline const T* Data() const;
71  inline T* End();
72  inline const T* End() const;
73  inline T& Back();
74  inline const T& Back() const;
75  inline std::size_t Size() const;
76  inline bool Empty() const;
77  inline void Reserve(std::size_t newsize);
78  inline void Resize(std::size_t newsize);
79  inline void Resize(std::size_t newsize, T t);
80  inline void Clear();
81  inline void PushBack(T t);
82  inline void Extend(const std::vector<T>& other);
83  /* Unsafe access, no bounds checking */
84  inline T& operator[](std::size_t idx);
85  inline const T& operator[](std::size_t idx) const;
86  /* Safe access, with bounds checking */
87  inline T& at(std::size_t idx);
88  inline const T& at(std::size_t idx) const;
89  /* Safe access, with bounds checking + check against non-existent node (<0) */
90  inline T& at(int idx);
91  inline const T& at(int idx) const;
92  static_assert(std::is_pod<T>::value, "T must be POD");
93 
94  private:
95  T* buffer_;
96  std::size_t size_;
97  std::size_t capacity_;
98  bool owned_buffer_;
99 };
100 
/*!
 * \brief Enum type representing the task type.
 *
 * The underlying type is fixed to a single byte and every enumerator has an explicit
 * numeric value.
 */
enum class TaskType : std::uint8_t {
  kBinaryClfRegr = 0,
  kMultiClfGrovePerClass = 1,
  kMultiClfProbDistLeaf = 2,
  kMultiClfCategLeaf = 3
};

/*!
 * \brief Get the string representation of a task type.
 * \param type Task type to convert
 * \return Name of the enumerator without the leading "k"; an empty string if the value is
 *         not one of the enumerators of TaskType
 */
inline std::string TaskTypeToString(TaskType type) {
  // No default label on purpose: this lets the compiler (-Wswitch) flag an enumerator that
  // was added to TaskType but forgotten here. Unknown values fall through to the final return.
  switch (type) {
    case TaskType::kBinaryClfRegr:
      return "BinaryClfRegr";
    case TaskType::kMultiClfGrovePerClass:
      return "MultiClfGrovePerClass";
    case TaskType::kMultiClfProbDistLeaf:
      return "MultiClfProbDistLeaf";
    case TaskType::kMultiClfCategLeaf:
      return "MultiClfCategLeaf";
  }
  return "";
}
179 
181 struct TaskParam {
182  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
184  OutputType output_type;
200  unsigned int num_class;
207  unsigned int leaf_vector_size;
208 };
209 
210 inline std::string OutputTypeToString(TaskParam::OutputType type) {
211  switch (type) {
212  case TaskParam::OutputType::kFloat: return "float";
213  case TaskParam::OutputType::kInt: return "int";
214  default: return "";
215  }
216 }
217 
218 static_assert(std::is_pod<TaskParam>::value, "TaskParameter must be POD type");
219 
221 template <typename ThresholdType, typename LeafOutputType>
222 class Tree {
223  public:
225  struct Node {
228  inline void Init();
230  union Info {
231  LeafOutputType leaf_value; // for leaf nodes
232  ThresholdType threshold; // for non-leaf nodes
233  };
235  std::int32_t cleft_, cright_;
240  std::uint32_t sindex_;
247  std::uint64_t data_count_;
254  double sum_hess_;
258  double gain_;
273 /*! \brief whether the list given by MatchingCategories(nid) is associated with the right child
274 * node or the left child node. True if the right child, False otherwise */
275  bool categories_list_right_child_;
276 
278  inline int LeftChild() const {
279  return cleft_;
280  }
281  inline int RightChild() const {
282  return cright_;
283  }
284  inline bool DefaultLeft() const {
285  // Extract the most significant bit (MSB) of sindex_, which encodes the default_left field
286  return (sindex_ >> 31U) != 0;
287  }
288  inline int DefaultChild() const {
289  // Extract the most significant bit (MSB) of sindex_, which encodes the default_left field
290  return ((sindex_ >> 31U) != 0) ? cleft_ : cright_;
291  }
292  inline std::uint32_t SplitIndex() const {
293  // Extract all bits except the most significant bit (MSB) from sindex_.
294  return (sindex_ & ((1U << 31U) - 1U));
295  }
296  inline bool IsLeaf() const {
297  return cleft_ == -1;
298  }
299  inline LeafOutputType LeafValue() const {
300  return info_.leaf_value;
301  }
302  inline ThresholdType Threshold() const {
303  return info_.threshold;
304  }
305  inline Operator ComparisonOp() const {
306  return cmp_;
307  }
308  inline SplitFeatureType SplitType() const {
309  return split_type_;
310  }
311  inline bool HasDataCount() const {
312  return data_count_present_;
313  }
314  inline std::uint64_t DataCount() const {
315  return data_count_;
316  }
317  inline bool HasSumHess() const {
318  return sum_hess_present_;
319  }
320  inline double SumHess() const {
321  return sum_hess_;
322  }
323  inline bool HasGain() const {
324  return gain_present_;
325  }
326  inline double Gain() const {
327  return gain_;
328  }
329  inline bool CategoriesListRightChild() const {
330  return categories_list_right_child_;
331  }
332  };
333 
334  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
335  static_assert(std::is_same<ThresholdType, float>::value
336  || std::is_same<ThresholdType, double>::value,
337  "ThresholdType must be either float32 or float64");
338  static_assert(std::is_same<LeafOutputType, uint32_t>::value
339  || std::is_same<LeafOutputType, float>::value
340  || std::is_same<LeafOutputType, double>::value,
341  "LeafOutputType must be one of uint32_t, float32 or float64");
342  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
343  || std::is_same<LeafOutputType, uint32_t>::value,
344  "Unsupported combination of ThresholdType and LeafOutputType");
345  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
346  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
347  "Node size incorrect");
348 
349  explicit Tree(bool use_opt_field = true);
350  ~Tree() = default;
351  Tree(const Tree&) = delete;
352  Tree& operator=(const Tree&) = delete;
353  Tree(Tree&&) noexcept = default;
354  Tree& operator=(Tree&&) noexcept = default;
355 
356  inline Tree<ThresholdType, LeafOutputType> Clone() const;
357 
358  inline const char* GetFormatStringForNode();
359  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
360  inline void SerializeToFile(FILE* dest_fp);
361  // Load a Tree object from a sequence of PyBuffer frames
362  // Returns the updated position of the cursor in the sequence
363  inline std::vector<PyBufferFrame>::iterator
364  InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it);
365  inline void DeserializeFromFile(FILE* src_fp);
366 
367  private:
368  // vector of nodes
369  ContiguousArray<Node> nodes_;
370  ContiguousArray<LeafOutputType> leaf_vector_;
371  // Map nid to the start and end index in leaf_vector_
372  // We could use std::pair, but it is not POD, so easier to use two vectors
373  // here
374  ContiguousArray<std::size_t> leaf_vector_begin_;
375  ContiguousArray<std::size_t> leaf_vector_end_;
376  ContiguousArray<std::uint32_t> matching_categories_;
377  ContiguousArray<std::size_t> matching_categories_offset_;
378  bool has_categorical_split_{false};
379 
380  /* Note: the following member fields shall be re-computed at serialization time */
381  // Whether to use optional fields
382  bool use_opt_field_{false};
383  // Number of optional fields in the extension slots
384  int32_t num_opt_field_per_tree_{0};
385  int32_t num_opt_field_per_node_{0};
386 
387  template <typename WriterType, typename X, typename Y>
388  friend void DumpModelAsJSON(WriterType& writer, const ModelImpl<X, Y>& model);
389  template <typename WriterType, typename X, typename Y>
390  friend void DumpTreeAsJSON(WriterType& writer, const Tree<X, Y>& tree);
391 
392  // allocate a new node
393  inline int AllocNode();
394 
395  // utility functions used for serialization, internal use only
396  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
397  inline void
398  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
399  CompositeArrayHandler composite_array_handler);
400  template <typename ScalarHandler, typename ArrayHandler, typename SkipOptFieldHandlerFunc>
401  inline void
402  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler,
403  SkipOptFieldHandlerFunc skip_opt_field_handler);
404 
405  friend class GTILBridge; // bridge to enable optimized access to nodes from GTIL
406 
407  public:
409  int num_nodes{0};
411  inline void Init();
416  inline void AddChilds(int nid);
417 
423  inline int LeftChild(int nid) const {
424  return nodes_[nid].LeftChild();
425  }
430  inline int RightChild(int nid) const {
431  return nodes_[nid].RightChild();
432  }
437  inline int DefaultChild(int nid) const {
438  return nodes_[nid].DefaultChild();
439  }
444  inline std::uint32_t SplitIndex(int nid) const {
445  return nodes_[nid].SplitIndex();
446  }
451  inline bool DefaultLeft(int nid) const {
452  return nodes_[nid].DefaultLeft();
453  }
458  inline bool IsLeaf(int nid) const {
459  return nodes_[nid].IsLeaf();
460  }
465  inline LeafOutputType LeafValue(int nid) const {
466  return nodes_[nid].LeafValue();
467  }
472  inline std::vector<LeafOutputType> LeafVector(int nid) const {
473  const std::size_t offset_begin = leaf_vector_begin_[nid];
474  const std::size_t offset_end = leaf_vector_end_[nid];
475  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
476  // Return empty vector, to indicate the lack of leaf vector
477  return std::vector<LeafOutputType>();
478  }
479  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
480  &leaf_vector_[offset_end]);
481  // Use unsafe access here, since we may need to take the address of one past the last
482  // element, to follow with the range semantic of std::vector<>.
483  }
488  inline bool HasLeafVector(int nid) const {
489  return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
490  }
495  inline ThresholdType Threshold(int nid) const {
496  return nodes_[nid].Threshold();
497  }
502  inline Operator ComparisonOp(int nid) const {
503  return nodes_[nid].ComparisonOp();
504  }
513  inline std::vector<std::uint32_t> MatchingCategories(int nid) const {
514  const std::size_t offset_begin = matching_categories_offset_[nid];
515  const std::size_t offset_end = matching_categories_offset_[nid + 1];
516  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
517  // Return empty vector, to indicate the lack of any matching categories
518  // The node might be a numerical split
519  return std::vector<std::uint32_t>();
520  }
521  return std::vector<std::uint32_t>(&matching_categories_[offset_begin],
522  &matching_categories_[offset_end]);
523  // Use unsafe access here, since we may need to take the address of one past the last
524  // element, to follow with the range semantic of std::vector<>.
525  }
530  inline SplitFeatureType SplitType(int nid) const {
531  return nodes_[nid].SplitType();
532  }
537  inline bool HasDataCount(int nid) const {
538  return nodes_[nid].HasDataCount();
539  }
544  inline std::uint64_t DataCount(int nid) const {
545  return nodes_[nid].DataCount();
546  }
547 
552  inline bool HasSumHess(int nid) const {
553  return nodes_[nid].HasSumHess();
554  }
559  inline double SumHess(int nid) const {
560  return nodes_[nid].SumHess();
561  }
566  inline bool HasGain(int nid) const {
567  return nodes_[nid].HasGain();
568  }
573  inline double Gain(int nid) const {
574  return nodes_[nid].Gain();
575  }
581  inline bool CategoriesListRightChild(int nid) const {
582  return nodes_[nid].CategoriesListRightChild();
583  }
584 
588  inline bool HasCategoricalSplit() const {
589  return has_categorical_split_;
590  }
591 
602  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
603  bool default_left, Operator cmp);
616  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
617  const std::vector<uint32_t>& categories_list,
618  bool categories_list_right_child);
624  inline void SetLeaf(int nid, LeafOutputType value);
630  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
636  inline void SetSumHess(int nid, double sum_hess) {
637  Node& node = nodes_.at(nid);
638  node.sum_hess_ = sum_hess;
639  node.sum_hess_present_ = true;
640  }
646  inline void SetDataCount(int nid, uint64_t data_count) {
647  Node& node = nodes_.at(nid);
648  node.data_count_ = data_count;
649  node.data_count_present_ = true;
650  }
656  inline void SetGain(int nid, double gain) {
657  Node& node = nodes_.at(nid);
658  node.gain_ = gain;
659  node.gain_present_ = true;
660  }
661 };
662 
663 struct ModelParam {
685  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
701  float ratio_c;
708  float global_bias;
711  ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
712  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
713  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
714  }
715  ~ModelParam() = default;
716  ModelParam(const ModelParam&) = default;
717  ModelParam& operator=(const ModelParam&) = default;
718  ModelParam(ModelParam&&) = default;
719  ModelParam& operator=(ModelParam&&) = default;
720 
721  template<typename Container>
722  inline std::vector<std::pair<std::string, std::string>>
723  InitAllowUnknown(const Container &kwargs);
724  inline std::map<std::string, std::string> __DICT__() const;
725 };
726 
727 static_assert(std::is_standard_layout<ModelParam>::value,
728  "ModelParam must be in the standard layout");
729 
730 inline void InitParamAndCheck(ModelParam* param,
731  const std::vector<std::pair<std::string, std::string>>& cfg);
732 
734 class Model {
735  public:
737  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
738  patch_ver_(TREELITE_VER_PATCH) {}
739  virtual ~Model() = default;
740  Model(const Model&) = delete;
741  Model& operator=(const Model&) = delete;
742  Model(Model&&) = default;
743  Model& operator=(Model&&) = default;
744 
745  template <typename ThresholdType, typename LeafOutputType>
746  inline static std::unique_ptr<Model> Create();
747  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
748  inline TypeInfo GetThresholdType() const {
749  return threshold_type_;
750  }
751  inline TypeInfo GetLeafOutputType() const {
752  return leaf_output_type_;
753  }
754  template <typename Func>
755  inline auto Dispatch(Func func);
756  template <typename Func>
757  inline auto Dispatch(Func func) const;
758 
759  virtual std::size_t GetNumTree() const = 0;
760  virtual void SetTreeLimit(std::size_t limit) = 0;
761  virtual void DumpAsJSON(std::ostream& fo, bool pretty_print) const = 0;
762 
763  inline std::string DumpAsJSON(bool pretty_print) const {
764  std::ostringstream oss;
765  DumpAsJSON(oss, pretty_print);
766  return oss.str();
767  }
768 
769  /* Compatibility Matrix:
770  +------------------+----------+----------+----------------+-----------+
771  | | To: =2.4 | To: =3.0 | To: >=3.1,<4.0 | To: >=4.0 |
772  +------------------+----------+----------+----------------+-----------+
773  | From: <2.4 | No | No | No | No |
774  | From: =2.4 | Yes | Yes | Yes | No |
775  | From: =3.0 | No | Yes | Yes | Yes |
776  | From: >=3.1,<4.0 | No | Yes | Yes | Yes |
777  | From: >=4.0 | No | No | No | Yes |
778  +------------------+----------+----------+----------------+-----------+ */
779 
780  /* In-memory serialization, zero-copy */
781  TREELITE_DLL_EXPORT std::vector<PyBufferFrame> GetPyBuffer();
782  TREELITE_DLL_EXPORT static std::unique_ptr<Model>
783  CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
784 
785  /* Serialization to a file stream */
786  void SerializeToFile(FILE* dest_fp);
787  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
788 
793  int32_t num_feature{0};
797  bool average_tree_output{false};
799  TaskParam task_param{};
801  ModelParam param{};
802 
803  protected:
804  /* Note: the following member fields shall be re-computed at serialization time */
805  // Number of trees
806  uint64_t num_tree_{0};
807  // Number of optional fields in the extension slot
808  int32_t num_opt_field_per_model_{0};
809  // Which Treelite version produced this model
810  int32_t major_ver_;
811  int32_t minor_ver_;
812  int32_t patch_ver_;
813 
814  private:
815  TypeInfo threshold_type_{TypeInfo::kInvalid};
816  TypeInfo leaf_output_type_{TypeInfo::kInvalid};
817  // Internal functions for serialization
818  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
819  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
820  // Load a Model object from a sequence of PyBuffer frames
821  // Returns the updated position of the cursor in the sequence
822  virtual std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
823  std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) = 0;
824  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
825  template <typename HeaderPrimitiveFieldHandlerFunc>
826  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
827  template <typename HeaderPrimitiveFieldHandlerFunc>
828  inline static void DeserializeTemplate(
829  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
830  int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
831  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
832 };
833 
834 template <typename ThresholdType, typename LeafOutputType>
835 class ModelImpl : public Model {
836  public:
838  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
839 
841  ModelImpl() = default;
842  ~ModelImpl() override = default;
843  ModelImpl(const ModelImpl&) = delete;
844  ModelImpl& operator=(const ModelImpl&) = delete;
845  ModelImpl(ModelImpl&&) noexcept = default;
846  ModelImpl& operator=(ModelImpl&&) noexcept = default;
847 
848  void DumpAsJSON(std::ostream& fo, bool pretty_print) const override;
849  inline std::size_t GetNumTree() const override {
850  return trees.size();
851  }
852  void SetTreeLimit(std::size_t limit) override {
853  return trees.resize(limit);
854  }
855 
856  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
857  inline void SerializeToFileImpl(FILE* dest_fp) override;
858  // Load a ModelImpl object from a sequence of PyBuffer frames
859  // Returns the updated position of the cursor in the sequence
860  inline std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
861  std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) override;
862  inline void DeserializeFromFileImpl(FILE* src_fp) override;
863 
864  private:
865  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
866  typename TreeHandlerFunc>
867  inline void SerializeTemplate(
868  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
869  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
870  TreeHandlerFunc tree_handler);
871  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc,
872  typename SkipOptFieldHandlerFunc>
873  inline void DeserializeTemplate(
874  size_t num_tree,
875  HeaderFieldHandlerFunc header_field_handler,
876  TreeHandlerFunc tree_handler,
877  SkipOptFieldHandlerFunc skip_opt_field_handler);
878 };
879 
886 std::unique_ptr<Model> ConcatenateModelObjects(const std::vector<const Model*>& objs);
887 
888 } // namespace treelite
889 
890 #include "tree_impl.h"
891 
892 #endif // TREELITE_TREE_H_
SplitFeatureType split_type_
feature split type
Definition: tree.h:260
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:502
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:272
SplitFeatureType
feature split type
Definition: base.h:23
std::int32_t cleft_
pointer to left and right children
Definition: tree.h:235
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:537
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:566
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:192
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:181
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:266
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:472
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition: tree.h:588
TaskType
Enum type representing the task type.
Definition: tree.h:107
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:693
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree.h:437
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:268
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:207
tree node
Definition: tree.h:225
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:636
in-memory representation of a decision tree
Definition: tree.h:222
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:254
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:200
float global_bias
global bias of the model
Definition: tree.h:708
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:258
TaskType task_type
Task type.
Definition: tree.h:795
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
store either leaf value or decision threshold
Definition: tree.h:230
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:838
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:444
double SumHess(int nid) const
get hessian sum
Definition: tree.h:559
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:656
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:646
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:581
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:530
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:465
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:701
Model()
disable copy; use default move
Definition: tree.h:737
std::uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:247
int LeftChild(int nid) const
Getters.
Definition: tree.h:423
defines configuration macros of Treelite
std::uint64_t DataCount(int nid) const
get data count
Definition: tree.h:544
double Gain(int nid) const
get gain value
Definition: tree.h:573
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:430
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:184
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:270
thin wrapper for tree ensemble model
Definition: tree.h:734
std::uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
Definition: tree.h:240
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:451
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:552
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:458
int LeftChild() const
Getters.
Definition: tree.h:278
std::unique_ptr< Model > ConcatenateModelObjects(const std::vector< const Model *> &objs)
Concatenate multiple model objects into a single model object by copying all member trees into the de...
Definition: model_concat.cc:16
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:495
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:488
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:513
Info info_
storage for leaf value or decision threshold
Definition: tree.h:242
Operator
comparison operators
Definition: base.h:27