Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <algorithm>
13 #include <map>
14 #include <memory>
15 #include <ostream>
16 #include <string>
17 #include <vector>
18 #include <utility>
19 #include <type_traits>
20 #include <limits>
21 #include <cstddef>
22 #include <cstdint>
23 #include <cstring>
24 #include <cstdio>
25 
26 #define __TREELITE_STR(x) #x  // stringizes x as written (x itself is not macro-expanded); NOTE(review): names starting with "__" or "_" + uppercase are reserved identifiers in C++
27 #define _TREELITE_STR(x) __TREELITE_STR(x)  // extra level of indirection: x is macro-expanded first, then stringized
28 
29 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256  // size of the ModelParam::pred_transform char buffer
30 
31 namespace treelite {
32 
33 // Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
34 // to hold only 1-D arrays with stride 1.
35 struct PyBufferFrame {
36  void* buf;  // start of the contiguous 1-D data; raw pointer, this POD struct never frees it
37  char* format;  // element format string in the PEP 3118 sense
38  std::size_t itemsize;  // size of one element in bytes (PEP 3118 itemsize)
39  std::size_t nitem;  // number of elements in the frame
40 };
41 
42 static_assert(std::is_pod<PyBufferFrame>::value, "PyBufferFrame must be a POD type");  // compile-time guard: the frame must stay POD
43 
44 template <typename T>
45 class ContiguousArray {  // contiguous array of POD elements; either owns its buffer or wraps a foreign one
46  public:
47  ContiguousArray();  // required: the move ctor below suppresses the implicit default ctor
48  ~ContiguousArray();
49  // NOTE: use Clone to make deep copy; copy constructors disabled
50  ContiguousArray(const ContiguousArray&) = delete;
51  ContiguousArray& operator=(const ContiguousArray&) = delete;
52  ContiguousArray(ContiguousArray&& other) noexcept;
53  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
54  inline ContiguousArray Clone() const;
55  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);  // presumably wraps prealloc_buf without taking ownership (see owned_buffer_) -- TODO confirm
56  inline T* Data();
57  inline const T* Data() const;
58  inline T* End();
59  inline const T* End() const;
60  inline T& Back();
61  inline const T& Back() const;
62  inline std::size_t Size() const;
63  inline void Reserve(std::size_t newsize);
64  inline void Resize(std::size_t newsize);
65  inline void Resize(std::size_t newsize, T t);
66  inline void Clear();
67  inline void PushBack(T t);
68  inline void Extend(const std::vector<T>& other);
69  /* Unsafe access, no bounds checking */
70  inline T& operator[](std::size_t idx);
71  inline const T& operator[](std::size_t idx) const;
72  /* Safe access, with bounds checking */
73  inline T& at(std::size_t idx);
74  inline const T& at(std::size_t idx) const;
75  /* Safe access, with bounds checking + check against non-existent node (<0) */
76  inline T& at(int idx);
77  inline const T& at(int idx) const;
78  static_assert(std::is_pod<T>::value, "T must be POD");
79 
80  private:
81  T* buffer_;  // element storage
82  std::size_t size_;  // number of elements reported by Size()
83  std::size_t capacity_;  // presumably the number of allocated element slots (see Reserve) -- TODO confirm
84  bool owned_buffer_;  // whether this object owns buffer_ (false for a foreign buffer, presumably)
85 };
86 
93 enum class TaskType : uint8_t {  // task type of the model; stored in a single byte
101  kBinaryClfRegr = 0,
120  kMultiClfGrovePerClass = 1,
136  kMultiClfProbDistLeaf = 2,
153  kMultiClfCategLeaf = 3  // NOTE(review): values are spelled out explicitly, presumably because they are persisted -- keep them stable
154 };
155 
157 struct TaskParam {  // group of parameters that depend on the choice of task type
158  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
160  OutputType output_type;  // type of the output from each leaf node
168  bool grove_per_class;  // whether a subset of the trees is designated to compute the prediction for each class
176  unsigned int num_class;  // number of classes in the target label
183  unsigned int leaf_vector_size;  // dimension of the output from each leaf node
184 };
185 
186 static_assert(std::is_pod<TaskParam>::value, "TaskParam must be a POD type");
187 
189 template <typename ThresholdType, typename LeafOutputType>
190 class Tree {
191  public:
193  struct Node {
196  inline void Init();
198  union Info {
199  LeafOutputType leaf_value; // for leaf nodes
200  ThresholdType threshold; // for non-leaf nodes
201  };
203  int32_t cleft_, cright_;
208  uint32_t sindex_;
215  uint64_t data_count_;
222  double sum_hess_;
226  double gain_;
241  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
242  * node or the left child node. True if the right child, False otherwise */
243  bool categories_list_right_child_;
244  };
245 
246  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
247  static_assert(std::is_same<ThresholdType, float>::value
248  || std::is_same<ThresholdType, double>::value,
249  "ThresholdType must be either float32 or float64");
250  static_assert(std::is_same<LeafOutputType, uint32_t>::value
251  || std::is_same<LeafOutputType, float>::value
252  || std::is_same<LeafOutputType, double>::value,
253  "LeafOutputType must be one of uint32_t, float32 or float64");
254  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
255  || std::is_same<LeafOutputType, uint32_t>::value,
256  "Unsupported combination of ThresholdType and LeafOutputType");
257  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
258  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
259  "Node size incorrect");
260 
261  Tree() = default;
262  ~Tree() = default;
263  Tree(const Tree&) = delete;
264  Tree& operator=(const Tree&) = delete;
265  Tree(Tree&&) noexcept = default;
266  Tree& operator=(Tree&&) noexcept = default;
267 
268  inline Tree<ThresholdType, LeafOutputType> Clone() const;
269 
270  inline const char* GetFormatStringForNode();
271  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
272  inline void SerializeToFile(FILE* dest_fp);
273  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
274  std::vector<PyBufferFrame>::iterator end);
275  inline void DeserializeFromFile(FILE* src_fp);
276 
277  private:
278  // vector of nodes
279  ContiguousArray<Node> nodes_;
280  ContiguousArray<LeafOutputType> leaf_vector_;
281  ContiguousArray<std::size_t> leaf_vector_offset_;
282  ContiguousArray<uint32_t> matching_categories_;
283  ContiguousArray<std::size_t> matching_categories_offset_;
284 
285  template <typename WriterType, typename X, typename Y>
286  friend void SerializeTreeToJSON(WriterType& writer, const Tree<X, Y>& tree);
287 
288  // allocate a new node
289  inline int AllocNode();
290 
291  // utility functions used for serialization, internal use only
292  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
293  inline void
294  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
295  CompositeArrayHandler composite_array_handler);
296  template <typename ScalarHandler, typename ArrayHandler>
297  inline void
298  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
299 
300  public:
302  int num_nodes;
304  inline void Init();
309  inline void AddChilds(int nid);
310 
315  inline std::vector<unsigned> GetCategoricalFeatures() const;
316 
322  inline int LeftChild(int nid) const {
323  return nodes_.at(nid).cleft_;
324  }
329  inline int RightChild(int nid) const {
330  return nodes_.at(nid).cright_;
331  }
336  inline int DefaultChild(int nid) const {
337  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
338  }
343  inline uint32_t SplitIndex(int nid) const {
344  return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
345  }
350  inline bool DefaultLeft(int nid) const {
351  return (nodes_.at(nid).sindex_ >> 31U) != 0;
352  }
357  inline bool IsLeaf(int nid) const {
358  return nodes_.at(nid).cleft_ == -1;
359  }
364  inline LeafOutputType LeafValue(int nid) const {
365  return (nodes_.at(nid).info_).leaf_value;
366  }
371  inline std::vector<LeafOutputType> LeafVector(int nid) const {
372  const std::size_t offset_begin = leaf_vector_offset_.at(nid);
373  const std::size_t offset_end = leaf_vector_offset_.at(nid + 1);
374  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
375  // Return empty vector, to indicate the lack of leaf vector
376  return std::vector<LeafOutputType>();
377  }
378  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
379  &leaf_vector_[offset_end]);
380  // Use unsafe access here, since we may need to take the address of one past the last
381  // element, to follow with the range semantic of std::vector<>.
382  }
387  inline bool HasLeafVector(int nid) const {
388  return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
389  }
394  inline ThresholdType Threshold(int nid) const {
395  return (nodes_.at(nid).info_).threshold;
396  }
401  inline Operator ComparisonOp(int nid) const {
402  return nodes_.at(nid).cmp_;
403  }
412  inline std::vector<uint32_t> MatchingCategories(int nid) const {
413  const std::size_t offset_begin = matching_categories_offset_.at(nid);
414  const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
415  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
416  // Return empty vector, to indicate the lack of any matching categories
417  // The node might be a numerical split
418  return std::vector<uint32_t>();
419  }
420  return std::vector<uint32_t>(&matching_categories_[offset_begin],
421  &matching_categories_[offset_end]);
422  // Use unsafe access here, since we may need to take the address of one past the last
423  // element, to follow with the range semantic of std::vector<>.
424  }
430  inline bool HasMatchingCategories(int nid) const {
431  return matching_categories_offset_.at(nid) != matching_categories_offset_.at(nid + 1);
432  }
437  inline SplitFeatureType SplitType(int nid) const {
438  return nodes_.at(nid).split_type_;
439  }
444  inline bool HasDataCount(int nid) const {
445  return nodes_.at(nid).data_count_present_;
446  }
451  inline uint64_t DataCount(int nid) const {
452  return nodes_.at(nid).data_count_;
453  }
454 
459  inline bool HasSumHess(int nid) const {
460  return nodes_.at(nid).sum_hess_present_;
461  }
466  inline double SumHess(int nid) const {
467  return nodes_.at(nid).sum_hess_;
468  }
473  inline bool HasGain(int nid) const {
474  return nodes_.at(nid).gain_present_;
475  }
480  inline double Gain(int nid) const {
481  return nodes_.at(nid).gain_;
482  }
488  inline bool CategoriesListRightChild(int nid) const {
489  return nodes_.at(nid).categories_list_right_child_;
490  }
491 
502  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
503  bool default_left, Operator cmp);
516  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
517  const std::vector<uint32_t>& categories_list,
518  bool categories_list_right_child);
524  inline void SetLeaf(int nid, LeafOutputType value);
530  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
536  inline void SetSumHess(int nid, double sum_hess) {
537  Node& node = nodes_.at(nid);
538  node.sum_hess_ = sum_hess;
539  node.sum_hess_present_ = true;
540  }
546  inline void SetDataCount(int nid, uint64_t data_count) {
547  Node& node = nodes_.at(nid);
548  node.data_count_ = data_count;
549  node.data_count_present_ = true;
550  }
556  inline void SetGain(int nid, double gain) {
557  Node& node = nodes_.at(nid);
558  node.gain_ = gain;
559  node.gain_present_ = true;
560  }
561 };
562 
563 struct ModelParam {  // extra model parameters; must remain standard-layout (see static_assert below)
585  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};  // NUL-terminated transform name; the constructor sets "identity"
593  float sigmoid_alpha;  // scaling parameter for sigmoid(x) = 1 / (1 + exp(-alpha * x))
600  float global_bias;  // global bias of the model
603  ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {
604  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
605  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
606  }
607  ~ModelParam() = default;
608  ModelParam(const ModelParam&) = default;
609  ModelParam& operator=(const ModelParam&) = default;
610  ModelParam(ModelParam&&) = default;
611  ModelParam& operator=(ModelParam&&) = default;
612 
613  template<typename Container>
614  inline std::vector<std::pair<std::string, std::string>>
615  InitAllowUnknown(const Container &kwargs);
616  inline std::map<std::string, std::string> __DICT__() const;
617 };
618 
619 static_assert(std::is_standard_layout<ModelParam>::value,
620  "ModelParam must be in the standard layout");
621 
622 inline void InitParamAndCheck(ModelParam* param,  // NOTE(review): presumably applies the key/value pairs in cfg to *param and validates them -- TODO confirm
623  const std::vector<std::pair<std::string, std::string>>& cfg);  // declared only; the definition is out of line (tree_impl.h is included at the end of this header)
624 
626 class Model {  // thin wrapper for a tree ensemble model; the trees live in the ModelImpl subclass
627  public:
629  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
630  patch_ver_(TREELITE_VER_PATCH) {}
631  virtual ~Model() = default;
632  Model(const Model&) = delete;
633  Model& operator=(const Model&) = delete;
634  Model(Model&&) noexcept = default;
635  Model& operator=(Model&&) noexcept = default;
636 
637  template <typename ThresholdType, typename LeafOutputType>
638  inline static std::unique_ptr<Model> Create();
639  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
640  inline TypeInfo GetThresholdType() const {
641  return threshold_type_;
642  }
643  inline TypeInfo GetLeafOutputType() const {
644  return leaf_output_type_;
645  }
646  template <typename Func>
647  inline auto Dispatch(Func func);
648  template <typename Func>
649  inline auto Dispatch(Func func) const;
650 
651  virtual std::size_t GetNumTree() const = 0;
652  virtual void SetTreeLimit(std::size_t limit) = 0;
653  virtual void SerializeToJSON(std::ostream& fo) const = 0;
654 
655  /* In-memory serialization, zero-copy */
656  std::vector<PyBufferFrame> GetPyBuffer();
657  static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
658 
659  /* Serialization to a file stream */
660  void SerializeToFile(FILE* dest_fp);
661  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
662 
667  int num_feature;  // number of features used by the model
669  TaskType task_type;  // task type
671  bool average_tree_output;  // whether to average tree outputs
673  TaskParam task_param;  // parameters specific to the task type
675  ModelParam param;  // extra parameters
676 
677  private:
678  int major_ver_, minor_ver_, patch_ver_;
679  TypeInfo threshold_type_;
680  TypeInfo leaf_output_type_;
681  // Internal functions for serialization
682  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
683  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
684  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
685  std::vector<PyBufferFrame>::iterator end) = 0;
686  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
687  template <typename HeaderPrimitiveFieldHandlerFunc>
688  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
689  template <typename HeaderPrimitiveFieldHandlerFunc>
690  inline static void DeserializeTemplate(
691  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
692  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
693 };
694 
695 template <typename ThresholdType, typename LeafOutputType>
696 class ModelImpl : public Model {  // concrete Model that owns the trees for one (threshold, leaf output) type pair
697  public:
699  std::vector<Tree<ThresholdType, LeafOutputType>> trees;  // member trees
700 
702  ModelImpl() = default;
703  ~ModelImpl() override = default;
704  ModelImpl(const ModelImpl&) = delete;
705  ModelImpl& operator=(const ModelImpl&) = delete;
706  ModelImpl(ModelImpl&&) noexcept = default;
707  ModelImpl& operator=(ModelImpl&&) noexcept = default;
708 
709  void SerializeToJSON(std::ostream& fo) const override;
710  inline std::size_t GetNumTree() const override {
711  return trees.size();
712  }
713  void SetTreeLimit(std::size_t limit) override {  // keep at most `limit` trees
714  if (limit < trees.size()) trees.resize(limit);  // truncate only: growing would append empty, default-constructed trees
715  }
716 
717  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
718  inline void SerializeToFileImpl(FILE* dest_fp) override;
719  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
720  std::vector<PyBufferFrame>::iterator end) override;
721  inline void DeserializeFromFileImpl(FILE* src_fp) override;
722 
723  private:
724  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
725  typename TreeHandlerFunc>
726  inline void SerializeTemplate(
727  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
728  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
729  TreeHandlerFunc tree_handler);
730  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
731  inline void DeserializeTemplate(
732  size_t num_tree,
733  HeaderFieldHandlerFunc header_field_handler,
734  TreeHandlerFunc tree_handler);
735 };
736 
737 } // namespace treelite
738 
739 #include "tree_impl.h"
740 
741 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:675
SplitFeatureType split_type_
feature split type
Definition: tree.h:228
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:401
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:240
SplitFeatureType
feature split type
Definition: base.h:22
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:215
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:444
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:473
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:168
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:157
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:234
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:371
TaskType
Enum type representing the task type.
Definition: tree.h:93
bool average_tree_output
whether to average tree outputs
Definition: tree.h:671
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:593
int DefaultChild(int nid) const
index of the node&#39;s "default" child, used when feature is missing
Definition: tree.h:336
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:236
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:183
tree node
Definition: tree.h:193
int32_t cleft_
indices of the left and right children
Definition: tree.h:203
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:536
in-memory representation of a decision tree
Definition: tree.h:190
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store these statistics.
Definition: tree.h:222
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:176
float global_bias
global bias of the model
Definition: tree.h:600
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:226
TaskType task_type
Task type.
Definition: tree.h:669
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:343
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
store either leaf value or decision threshold
Definition: tree.h:198
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:699
double SumHess(int nid) const
get hessian sum
Definition: tree.h:466
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:556
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:546
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:488
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:437
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:364
Model()
disable copy; use default move
Definition: tree.h:629
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:673
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:412
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:451
double Gain(int nid) const
get gain value
Definition: tree.h:480
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:329
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:160
bool HasMatchingCategories(int nid) const
tests whether the node has a non-empty list for matching categories. See MatchingCategories() for the...
Definition: tree.h:430
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:238
thin wrapper for tree ensemble model
Definition: tree.h:626
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:350
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:459
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:357
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:667
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
Definition: tree.h:208
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:394
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:387
Info info_
storage for leaf value or decision threshold
Definition: tree.h:210
Operator
comparison operators
Definition: base.h:26