Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <algorithm>
13 #include <map>
14 #include <memory>
15 #include <string>
16 #include <vector>
17 #include <utility>
18 #include <type_traits>
19 #include <limits>
20 #include <cstddef>
21 #include <cstdint>
22 #include <cstring>
23 #include <cstdio>
24 
25 #define __TREELITE_STR(x) #x
26 #define _TREELITE_STR(x) __TREELITE_STR(x)
27 
28 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256
29 
30 /* Forward declarations */
31 namespace dmlc {
32 
33 class Stream;
34 float stof(const std::string& value, std::size_t* pos);
35 
36 } // namespace dmlc
37 
38 namespace treelite {
39 
40 // Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
41 // to hold only 1-D arrays with stride 1.
42 struct PyBufferFrame {
43  void* buf;
44  char* format;
45  std::size_t itemsize;
46  std::size_t nitem;
47 };
48 
49 static_assert(std::is_pod<PyBufferFrame>::value, "PyBufferFrame must be a POD type");
50 
51 template <typename T>
53  public:
55  ~ContiguousArray();
56  // NOTE: use Clone to make deep copy; copy constructors disabled
57  ContiguousArray(const ContiguousArray&) = delete;
58  ContiguousArray& operator=(const ContiguousArray&) = delete;
59  ContiguousArray(ContiguousArray&& other) noexcept;
60  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
61  inline ContiguousArray Clone() const;
62  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
63  inline T* Data();
64  inline const T* Data() const;
65  inline T* End();
66  inline const T* End() const;
67  inline T& Back();
68  inline const T& Back() const;
69  inline std::size_t Size() const;
70  inline void Reserve(std::size_t newsize);
71  inline void Resize(std::size_t newsize);
72  inline void Resize(std::size_t newsize, T t);
73  inline void Clear();
74  inline void PushBack(T t);
75  inline void Extend(const std::vector<T>& other);
76  /* Unsafe access, no bounds checking */
77  inline T& operator[](std::size_t idx);
78  inline const T& operator[](std::size_t idx) const;
79  /* Safe access, with bounds checking */
80  inline T& at(std::size_t idx);
81  inline const T& at(std::size_t idx) const;
82  /* Safe access, with bounds checking + check against non-existent node (<0) */
83  inline T& at(int idx);
84  inline const T& at(int idx) const;
85  static_assert(std::is_pod<T>::value, "T must be POD");
86 
87  private:
88  T* buffer_;
89  std::size_t size_;
90  std::size_t capacity_;
91  bool owned_buffer_;
92 };
93 
100 enum class TaskType : uint8_t {
108  kBinaryClfRegr = 0,
127  kMultiClfGrovePerClass = 1,
143  kMultiClfProbDistLeaf = 2,
160  kMultiClfCategLeaf = 3
161 };
162 
165  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
167  OutputType output_type;
183  unsigned int num_class;
190  unsigned int leaf_vector_size;
191 };
192 
193 static_assert(std::is_pod<TaskParameter>::value, "TaskParameter must be POD type");
194 
196 template <typename ThresholdType, typename LeafOutputType>
197 class Tree {
198  public:
200  struct Node {
203  inline void Init();
205  union Info {
206  LeafOutputType leaf_value; // for leaf nodes
207  ThresholdType threshold; // for non-leaf nodes
208  };
210  int32_t cleft_, cright_;
215  uint32_t sindex_;
222  uint64_t data_count_;
229  double sum_hess_;
233  double gain_;
248  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
249  * node or the left child node. True if the right child, False otherwise */
250  bool categories_list_right_child_;
251  };
252 
253  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
254  static_assert(std::is_same<ThresholdType, float>::value
255  || std::is_same<ThresholdType, double>::value,
256  "ThresholdType must be either float32 or float64");
257  static_assert(std::is_same<LeafOutputType, uint32_t>::value
258  || std::is_same<LeafOutputType, float>::value
259  || std::is_same<LeafOutputType, double>::value,
260  "LeafOutputType must be one of uint32_t, float32 or float64");
261  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
262  || std::is_same<LeafOutputType, uint32_t>::value,
263  "Unsupported combination of ThresholdType and LeafOutputType");
264  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
265  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
266  "Node size incorrect");
267 
268  Tree() = default;
269  ~Tree() = default;
270  Tree(const Tree&) = delete;
271  Tree& operator=(const Tree&) = delete;
272  Tree(Tree&&) noexcept = default;
273  Tree& operator=(Tree&&) noexcept = default;
274 
275  inline Tree<ThresholdType, LeafOutputType> Clone() const;
276 
277  inline const char* GetFormatStringForNode();
278  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
279  inline void SerializeToFile(FILE* dest_fp);
280  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
281  std::vector<PyBufferFrame>::iterator end);
282  inline void DeserializeFromFile(FILE* src_fp);
283 
284  private:
285  // vector of nodes
286  ContiguousArray<Node> nodes_;
287  ContiguousArray<LeafOutputType> leaf_vector_;
288  ContiguousArray<std::size_t> leaf_vector_offset_;
289  ContiguousArray<uint32_t> matching_categories_;
290  ContiguousArray<std::size_t> matching_categories_offset_;
291 
292  // allocate a new node
293  inline int AllocNode();
294 
295  // utility functions used for serialization, internal use only
296  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
297  inline void
298  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
299  CompositeArrayHandler composite_array_handler);
300  template <typename ScalarHandler, typename ArrayHandler>
301  inline void
302  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
303 
304  public:
306  int num_nodes;
308  inline void Init();
313  inline void AddChilds(int nid);
314 
319  inline std::vector<unsigned> GetCategoricalFeatures() const;
320 
326  inline int LeftChild(int nid) const {
327  return nodes_.at(nid).cleft_;
328  }
333  inline int RightChild(int nid) const {
334  return nodes_.at(nid).cright_;
335  }
340  inline int DefaultChild(int nid) const {
341  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
342  }
347  inline uint32_t SplitIndex(int nid) const {
348  return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
349  }
354  inline bool DefaultLeft(int nid) const {
355  return (nodes_.at(nid).sindex_ >> 31U) != 0;
356  }
361  inline bool IsLeaf(int nid) const {
362  return nodes_.at(nid).cleft_ == -1;
363  }
368  inline LeafOutputType LeafValue(int nid) const {
369  return (nodes_.at(nid).info_).leaf_value;
370  }
375  inline std::vector<LeafOutputType> LeafVector(int nid) const {
376  const std::size_t offset_begin = leaf_vector_offset_.at(nid);
377  const std::size_t offset_end = leaf_vector_offset_.at(nid + 1);
378  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
379  // Return empty vector, to indicate the lack of leaf vector
380  return std::vector<LeafOutputType>();
381  }
382  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
383  &leaf_vector_[offset_end]);
384  // Use unsafe access here, since we may need to take the address of one past the last
385  // element, to follow with the range semantic of std::vector<>.
386  }
391  inline bool HasLeafVector(int nid) const {
392  return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
393  }
398  inline ThresholdType Threshold(int nid) const {
399  return (nodes_.at(nid).info_).threshold;
400  }
405  inline Operator ComparisonOp(int nid) const {
406  return nodes_.at(nid).cmp_;
407  }
416  inline std::vector<uint32_t> MatchingCategories(int nid) const {
417  const std::size_t offset_begin = matching_categories_offset_.at(nid);
418  const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
419  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
420  // Return empty vector, to indicate the lack of any matching categories
421  // The node might be a numerical split
422  return std::vector<uint32_t>();
423  }
424  return std::vector<uint32_t>(&matching_categories_[offset_begin],
425  &matching_categories_[offset_end]);
426  // Use unsafe access here, since we may need to take the address of one past the last
427  // element, to follow with the range semantic of std::vector<>.
428  }
434  inline bool HasMatchingCategories(int nid) const {
435  return matching_categories_offset_.at(nid) != matching_categories_offset_.at(nid + 1);
436  }
441  inline SplitFeatureType SplitType(int nid) const {
442  return nodes_.at(nid).split_type_;
443  }
448  inline bool HasDataCount(int nid) const {
449  return nodes_.at(nid).data_count_present_;
450  }
455  inline uint64_t DataCount(int nid) const {
456  return nodes_.at(nid).data_count_;
457  }
458 
463  inline bool HasSumHess(int nid) const {
464  return nodes_.at(nid).sum_hess_present_;
465  }
470  inline double SumHess(int nid) const {
471  return nodes_.at(nid).sum_hess_;
472  }
477  inline bool HasGain(int nid) const {
478  return nodes_.at(nid).gain_present_;
479  }
484  inline double Gain(int nid) const {
485  return nodes_.at(nid).gain_;
486  }
492  inline bool CategoriesListRightChild(int nid) const {
493  return nodes_.at(nid).categories_list_right_child_;
494  }
495 
506  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
507  bool default_left, Operator cmp);
520  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
521  const std::vector<uint32_t>& categories_list,
522  bool categories_list_right_child);
528  inline void SetLeaf(int nid, LeafOutputType value);
534  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
540  inline void SetSumHess(int nid, double sum_hess) {
541  Node& node = nodes_.at(nid);
542  node.sum_hess_ = sum_hess;
543  node.sum_hess_present_ = true;
544  }
550  inline void SetDataCount(int nid, uint64_t data_count) {
551  Node& node = nodes_.at(nid);
552  node.data_count_ = data_count;
553  node.data_count_present_ = true;
554  }
560  inline void SetGain(int nid, double gain) {
561  Node& node = nodes_.at(nid);
562  node.gain_ = gain;
563  node.gain_present_ = true;
564  }
565 
566  void ReferenceSerialize(dmlc::Stream* fo) const;
567 };
568 
569 struct ModelParam {
591  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
606  float global_bias;
609  ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {
610  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
611  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
612  }
613  ~ModelParam() = default;
614  ModelParam(const ModelParam&) = default;
615  ModelParam& operator=(const ModelParam&) = default;
616  ModelParam(ModelParam&&) = default;
617  ModelParam& operator=(ModelParam&&) = default;
618 
619  template<typename Container>
620  inline std::vector<std::pair<std::string, std::string>>
621  InitAllowUnknown(const Container &kwargs);
622  inline std::map<std::string, std::string> __DICT__() const;
623 };
624 
625 static_assert(std::is_standard_layout<ModelParam>::value,
626  "ModelParam must be in the standard layout");
627 
628 inline void InitParamAndCheck(ModelParam* param,
629  const std::vector<std::pair<std::string, std::string>>& cfg);
630 
632 class Model {
633  public:
635  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
636  patch_ver_(TREELITE_VER_PATCH) {}
637  virtual ~Model() = default;
638  Model(const Model&) = delete;
639  Model& operator=(const Model&) = delete;
640  Model(Model&&) = default;
641  Model& operator=(Model&&) = default;
642 
643  template <typename ThresholdType, typename LeafOutputType>
644  inline static std::unique_ptr<Model> Create();
645  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
646  inline TypeInfo GetThresholdType() const {
647  return threshold_type_;
648  }
649  inline TypeInfo GetLeafOutputType() const {
650  return leaf_output_type_;
651  }
652  template <typename Func>
653  inline auto Dispatch(Func func);
654  template <typename Func>
655  inline auto Dispatch(Func func) const;
656 
657  virtual std::size_t GetNumTree() const = 0;
658  virtual void SetTreeLimit(std::size_t limit) = 0;
659  virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0;
660 
661  /* In-memory serialization, zero-copy */
662  std::vector<PyBufferFrame> GetPyBuffer();
663  static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
664 
665  /* Serialization to a file stream */
666  void SerializeToFile(FILE* dest_fp);
667  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
668 
682 
683  private:
684  int major_ver_, minor_ver_, patch_ver_;
685  TypeInfo threshold_type_;
686  TypeInfo leaf_output_type_;
687  // Internal functions for serialization
688  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
689  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
690  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
691  std::vector<PyBufferFrame>::iterator end) = 0;
692  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
693  template <typename HeaderPrimitiveFieldHandlerFunc>
694  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
695  template <typename HeaderPrimitiveFieldHandlerFunc>
696  inline static void DeserializeTemplate(
697  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
698  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
699 };
700 
701 template <typename ThresholdType, typename LeafOutputType>
702 class ModelImpl : public Model {
703  public:
705  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
706 
708  ModelImpl() = default;
709  ~ModelImpl() override = default;
710  ModelImpl(const ModelImpl&) = delete;
711  ModelImpl& operator=(const ModelImpl&) = delete;
712  ModelImpl(ModelImpl&&) noexcept = default;
713  ModelImpl& operator=(ModelImpl&&) noexcept = default;
714 
715  void ReferenceSerialize(dmlc::Stream* fo) const override;
716  inline std::size_t GetNumTree() const override {
717  return trees.size();
718  }
719  void SetTreeLimit(std::size_t limit) override {
720  return trees.resize(limit);
721  }
722 
723  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
724  inline void SerializeToFileImpl(FILE* dest_fp) override;
725  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
726  std::vector<PyBufferFrame>::iterator end) override;
727  inline void DeserializeFromFileImpl(FILE* src_fp) override;
728 
729  private:
730  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
731  typename TreeHandlerFunc>
732  inline void SerializeTemplate(
733  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
734  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
735  TreeHandlerFunc tree_handler);
736  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
737  inline void DeserializeTemplate(
738  size_t num_tree,
739  HeaderFieldHandlerFunc header_field_handler,
740  TreeHandlerFunc tree_handler);
741 };
742 
743 } // namespace treelite
744 
745 #include "tree_impl.h"
746 
747 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:681
SplitFeatureType split_type_
feature split type
Definition: tree.h:235
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:405
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:247
SplitFeatureType
feature split type
Definition: base.h:22
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:222
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:448
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:477
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:241
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:375
TaskType
Enum type representing the task type.
Definition: tree.h:100
bool average_tree_output
whether to average tree outputs
Definition: tree.h:677
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:599
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree.h:340
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:243
tree node
Definition: tree.h:200
int32_t cleft_
pointer to left and right children
Definition: tree.h:210
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:540
in-memory representation of a decision tree
Definition: tree.h:197
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:229
float global_bias
global bias of the model
Definition: tree.h:606
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:233
TaskType task_type
Task type.
Definition: tree.h:675
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:347
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:164
store either leaf value or decision threshold
Definition: tree.h:205
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:705
Definition: tree.h:31
double SumHess(int nid) const
get hessian sum
Definition: tree.h:470
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:560
TaskParameter task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:679
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:550
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:492
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:441
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:368
Model()
disable copy; use default move
Definition: tree.h:635
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:416
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:183
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:455
double Gain(int nid) const
get gain value
Definition: tree.h:484
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:333
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:175
bool HasMatchingCategories(int nid) const
tests whether the node has a non-empty list for matching categories. See MatchingCategories() for the...
Definition: tree.h:434
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:245
thin wrapper for tree ensemble model
Definition: tree.h:632
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:354
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:463
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:361
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:167
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:673
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
Definition: tree.h:215
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:190
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:398
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:391
Info info_
storage for leaf value or decision threshold
Definition: tree.h:217
Operator
comparison operators
Definition: base.h:26