Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <algorithm>
12 #include <map>
13 #include <memory>
14 #include <string>
15 #include <vector>
16 #include <utility>
17 #include <type_traits>
18 #include <limits>
19 #include <cstdint>
20 #include <cstring>
21 #include <cstdio>
22 
// Two-level stringification helpers.
// NOTE(review): identifiers containing a double underscore, or starting with an underscore followed
// by an uppercase letter, are reserved to the implementation in C++; renaming would touch every user.
23 #define __TREELITE_STR(x) #x  // stringify x exactly as spelled (the argument is not macro-expanded)
24 #define _TREELITE_STR(x) __TREELITE_STR(x)  // expand x first, then stringify the result
25 
26 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256  // capacity, in chars, of the ModelParam::pred_transform buffer
27 
28 /* Forward declarations */
29 namespace dmlc {  // declarations only; nothing in this namespace is defined by this header
30 
31 class Stream;  // incomplete type; this header only uses it as dmlc::Stream* (see ReferenceSerialize)
32 float stof(const std::string& value, size_t* pos);  // declaration only; no call to it is visible in this listing
33 
34 } // namespace dmlc
35 
36 namespace treelite {
37 
38 struct PyBufferFrame {  // descriptor of one raw buffer; exchanged by GetPyBuffer() / InitFromPyBuffer()
39  void* buf;  // untyped pointer to the buffer
40  char* format;  // format string for the items — presumably Python struct-style; TODO confirm
41  size_t itemsize;  // presumably the size of one item in bytes — confirm in tree_impl.h
42  size_t nitem;  // presumably the number of items stored in buf — confirm in tree_impl.h
43 };
44 
/*! \brief move-only array of POD elements kept in a single T* buffer; copy only through Clone().
 *  NOTE(review): listing lines 46 and 48 (presumably the class head and a constructor) are absent
 *  from this listing — confirm against the real tree.h before relying on this view. */
45 template <typename T>
47  public:
49  ~ContiguousArray();
50  // NOTE: use Clone to make deep copy; copy constructors disabled
51  ContiguousArray(const ContiguousArray&) = delete;
52  ContiguousArray& operator=(const ContiguousArray&) = delete;
53  ContiguousArray(ContiguousArray&& other) noexcept;
54  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
55  inline ContiguousArray Clone() const;  // returns the deep copy mentioned in the NOTE above, by value
56  inline void UseForeignBuffer(void* prealloc_buf, size_t size);  // presumably adopts a caller-provided buffer (cf. owned_buffer_) — confirm in tree_impl.h
57  inline T* Data();  // pointer to the underlying storage
58  inline const T* Data() const;
59  inline T* End();  // presumably one past the last element — confirm in tree_impl.h
60  inline const T* End() const;
61  inline T& Back();  // presumably the last element — confirm in tree_impl.h
62  inline const T& Back() const;
63  inline size_t Size() const;
64  inline void Reserve(size_t newsize);
65  inline void Resize(size_t newsize);
66  inline void Resize(size_t newsize, T t);  // presumably fills newly added slots with t — confirm
67  inline void Clear();
68  inline void PushBack(T t);
69  inline void Extend(const std::vector<T>& other);  // presumably appends all elements of other — confirm
70  /* Unsafe access, no bounds checking */
71  inline T& operator[](size_t idx);
72  inline const T& operator[](size_t idx) const;
73  /* Safe access, with bounds checking */
74  inline T& at(size_t idx);
75  inline const T& at(size_t idx) const;
76  /* Safe access, with bounds checking + check against non-existent node (<0) */
77  inline T& at(int idx);  // the overload picked by the Tree accessors, which pass an int node id
78  inline const T& at(int idx) const;
79  static_assert(std::is_pod<T>::value, "T must be POD");  // NOTE(review): std::is_pod is deprecated since C++20
80 
81  private:
82  T* buffer_;  // element storage
83  size_t size_;  // presumably the number of elements in use — confirm
84  size_t capacity_;  // presumably the number of elements buffer_ can hold — confirm
85  bool owned_buffer_;  // presumably true when buffer_ was allocated by this object rather than passed to UseForeignBuffer — confirm
86 };
87 
94 enum class TaskType : uint8_t {  // Enum type representing the task type; the underlying type is one byte
102  kBinaryClfRegr = 0,  // presumably binary classification or regression — confirm
121  kMultiClfGrovePerClass = 1,  // presumably multi-class with a subset of trees per class (cf. grove_per_class) — confirm
137  kMultiClfProbDistLeaf = 2,  // presumably multi-class where each leaf holds a probability distribution — confirm
154  kMultiClfCategLeaf = 3  // presumably multi-class where each leaf holds a class label — confirm
155 };
156 
// Body of struct TaskParameter: group of parameters that are dependent on the choice of the task type.
// NOTE(review): the struct head (listing line 158) and the member `bool grove_per_class`
// (listing line 169) are not present in this listing — confirm against the real tree.h.
159  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };  // the possible kinds of leaf output
161  OutputType output_type;  // The type of output from each leaf node
177  unsigned int num_class;  // The number of classes in the target label
184  unsigned int leaf_vector_size;  // Dimension of the output from each leaf node
185 };
186 
187 static_assert(std::is_pod<TaskParameter>::value, "TaskParameter must be POD type");  // compile-time guard: TaskParameter has to stay a POD type
188 
/*! \brief in-memory representation of a decision tree. Nodes live in nodes_ and are addressed by
 *  an integer node id (nid). Copying is deleted; use Clone() to obtain a copy. */
190 template <typename ThresholdType, typename LeafOutputType>
191 class Tree {
192  public:
194  struct Node {  // tree node; kept POD with a fixed size (see the static_asserts below)
197  inline void Init();  // body is not in this listing — presumably in tree_impl.h
199  union Info {  // store either leaf value or decision threshold
200  LeafOutputType leaf_value; // for leaf nodes
201  ThresholdType threshold; // for non-leaf nodes
202  };
204  int32_t cleft_, cright_;  // ids of the left and right children; cleft_ == -1 marks a leaf (see IsLeaf)
209  uint32_t sindex_;  // low 31 bits: feature index of the split; highest bit: default direction for missing values
216  uint64_t data_count_;  // number of data points whose traversal paths include this node
223  double sum_hess_;  // sum of hessian values for all data points whose traversal paths include this node
227  double gain_;  // change in loss that is attributed to a particular split
 // NOTE(review): info_, split_type_, cmp_, data_count_present_, sum_hess_present_ and gain_present_
 // are used by the accessors below, but their declarations (listing lines 211-241) are not shown here.
242  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
243  * node or the left child node. True if the right child, False otherwise */
244  bool categories_list_right_child_;
245  };
246 
247  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
248  static_assert(std::is_same<ThresholdType, float>::value
249  || std::is_same<ThresholdType, double>::value,
250  "ThresholdType must be either float32 or float64");
251  static_assert(std::is_same<LeafOutputType, uint32_t>::value
252  || std::is_same<LeafOutputType, float>::value
253  || std::is_same<LeafOutputType, double>::value,
254  "LeafOutputType must be one of uint32_t, float32 or float64");
 // Allowed pairings: leaf output has the same type as the threshold, or leaf output is uint32_t.
255  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
256  || std::is_same<LeafOutputType, uint32_t>::value,
257  "Unsupported combination of ThresholdType and LeafOutputType");
 // Pins the Node layout: exactly 48 bytes with float thresholds, 56 bytes with double thresholds.
258  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
259  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
260  "Node size incorrect");
261 
262  Tree() = default;
263  ~Tree() = default;
264  Tree(const Tree&) = delete;  // copying is disabled; Clone() is the explicit alternative
265  Tree& operator=(const Tree&) = delete;
266  Tree(Tree&&) noexcept = default;
267  Tree& operator=(Tree&&) noexcept = default;
268 
269  inline Tree<ThresholdType, LeafOutputType> Clone() const;  // returns a copy of this tree by value
270 
 // Buffer-frame (PyBufferFrame) import/export helpers; bodies are not in this listing.
271  inline const char* GetFormatStringForNode();
272  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
273  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
274  std::vector<PyBufferFrame>::iterator end);
275 
276  private:
277  // vector of nodes
278  ContiguousArray<Node> nodes_;
279  ContiguousArray<LeafOutputType> leaf_vector_;  // leaf vectors of all nodes, stored back to back
280  ContiguousArray<size_t> leaf_vector_offset_;  // node nid owns leaf_vector_[offset[nid], offset[nid + 1])
281  ContiguousArray<uint32_t> matching_categories_;  // category lists of all nodes, stored back to back
282  ContiguousArray<size_t> matching_categories_offset_;  // node nid owns matching_categories_[offset[nid], offset[nid + 1])
283 
284  // allocate a new node
285  inline int AllocNode();
286 
287  public:
289  int num_nodes;  // presumably the number of nodes in the tree — confirm in tree_impl.h
291  inline void Init();  // body is not in this listing
296  inline void AddChilds(int nid);  // presumably attaches two new child nodes to node nid — confirm
297 
302  inline std::vector<unsigned> GetCategoricalFeatures() const;  // body is not in this listing
303 
309  inline int LeftChild(int nid) const {  // index of the node's left child
310  return nodes_.at(nid).cleft_;
311  }
316  inline int RightChild(int nid) const {  // index of the node's right child
317  return nodes_.at(nid).cright_;
318  }
323  inline int DefaultChild(int nid) const {  // index of the node's "default" child, used when feature is missing
324  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
325  }
330  inline uint32_t SplitIndex(int nid) const {  // feature index of the node's split condition
331  return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));  // mask off the highest bit, which holds the default direction
332  }
337  inline bool DefaultLeft(int nid) const {  // whether to use the left child node when the split feature is missing
338  return (nodes_.at(nid).sindex_ >> 31U) != 0;  // the highest bit of sindex_ is the default-left flag
339  }
344  inline bool IsLeaf(int nid) const {  // whether the node is leaf node
345  return nodes_.at(nid).cleft_ == -1;
346  }
351  inline LeafOutputType LeafValue(int nid) const {  // get leaf value of the leaf node
352  return (nodes_.at(nid).info_).leaf_value;
353  }
358  inline std::vector<LeafOutputType> LeafVector(int nid) const {  // get leaf vector of the leaf node (copied into a new vector)
359  const size_t offset_begin = leaf_vector_offset_.at(nid);
360  const size_t offset_end = leaf_vector_offset_.at(nid + 1);
361  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
362  // Return empty vector, to indicate the lack of leaf vector
363  return std::vector<LeafOutputType>();
364  }
365  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
366  &leaf_vector_[offset_end]);
367  // Use unsafe access here, since we may need to take the address of one past the last
368  // element, to follow with the range semantic of std::vector<>.
369  }
374  inline bool HasLeafVector(int nid) const {  // tests whether the leaf node has a non-empty leaf vector
375  return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
376  }
381  inline ThresholdType Threshold(int nid) const {  // get threshold of the node
382  return (nodes_.at(nid).info_).threshold;
383  }
388  inline Operator ComparisonOp(int nid) const {  // get comparison operator
389  return nodes_.at(nid).cmp_;
390  }
399  inline std::vector<uint32_t> MatchingCategories(int nid) const {  // get the list of categories belonging to the left/right child (see CategoriesListRightChild)
400  const size_t offset_begin = matching_categories_offset_.at(nid);
401  const size_t offset_end = matching_categories_offset_.at(nid + 1);
402  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
403  // Return empty vector, to indicate the lack of any matching categories
404  // The node might be a numerical split
405  return std::vector<uint32_t>();
406  }
407  return std::vector<uint32_t>(&matching_categories_[offset_begin],
408  &matching_categories_[offset_end]);
409  // Use unsafe access here, since we may need to take the address of one past the last
410  // element, to follow with the range semantic of std::vector<>.
411  }
417  inline bool HasMatchingCategories(int nid) const {  // tests whether the node has a non-empty list for matching categories
418  return matching_categories_offset_.at(nid) != matching_categories_offset_.at(nid + 1);
419  }
424  inline SplitFeatureType SplitType(int nid) const {  // get feature split type
425  return nodes_.at(nid).split_type_;
426  }
431  inline bool HasDataCount(int nid) const {  // test whether this node has data count
432  return nodes_.at(nid).data_count_present_;
433  }
438  inline uint64_t DataCount(int nid) const {  // get data count
439  return nodes_.at(nid).data_count_;
440  }
441 
446  inline bool HasSumHess(int nid) const {  // test whether this node has hessian sum
447  return nodes_.at(nid).sum_hess_present_;
448  }
453  inline double SumHess(int nid) const {  // get hessian sum
454  return nodes_.at(nid).sum_hess_;
455  }
460  inline bool HasGain(int nid) const {  // test whether this node has gain value
461  return nodes_.at(nid).gain_present_;
462  }
467  inline double Gain(int nid) const {  // get gain value
468  return nodes_.at(nid).gain_;
469  }
475  inline bool CategoriesListRightChild(int nid) const {  // true if MatchingCategories(nid) is associated with the right child
476  return nodes_.at(nid).categories_list_right_child_;
477  }
478 
 // Setters whose bodies are not in this listing; presumably defined in tree_impl.h.
489  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
490  bool default_left, Operator cmp);
503  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
504  const std::vector<uint32_t>& categories_list,
505  bool categories_list_right_child);
511  inline void SetLeaf(int nid, LeafOutputType value);
517  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
523  inline void SetSumHess(int nid, double sum_hess) {  // set the hessian sum of the node
524  Node& node = nodes_.at(nid);
525  node.sum_hess_ = sum_hess;
526  node.sum_hess_present_ = true;  // record that the value is now available (read by HasSumHess)
527  }
533  inline void SetDataCount(int nid, uint64_t data_count) {  // set the data count of the node
534  Node& node = nodes_.at(nid);
535  node.data_count_ = data_count;
536  node.data_count_present_ = true;  // record that the value is now available (read by HasDataCount)
537  }
543  inline void SetGain(int nid, double gain) {  // set the gain value of the node
544  Node& node = nodes_.at(nid);
545  node.gain_ = gain;
546  node.gain_present_ = true;  // record that the value is now available (read by HasGain)
547  }
548 
549  void ReferenceSerialize(dmlc::Stream* fo) const;  // presumably writes this tree to the stream fo — confirm in the implementation
550 };
551 
552 struct ModelParam {  // extra parameters of a model; must stay standard-layout (asserted below)
574  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};  // NUL-terminated string, set to "identity" by the constructor — presumably the name of the prediction transform
589  float global_bias;  // global bias of the model
 // NOTE(review): sigmoid_alpha (scaling parameter for sigmoid(x) = 1 / (1 + exp(-alpha * x))) is
 // initialised below, but its declaration (listing line 582) is not shown in this listing.
592  ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {
593  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));  // zero the whole buffer first
594  std::strncpy(pred_transform, "identity", sizeof(pred_transform));  // "identity" is shorter than the buffer, so the result stays NUL-terminated
595  }
596  ~ModelParam() = default;
597  ModelParam(const ModelParam&) = default;
598  ModelParam& operator=(const ModelParam&) = default;
599  ModelParam(ModelParam&&) = default;
600  ModelParam& operator=(ModelParam&&) = default;
601 
602  template<typename Container>
603  inline std::vector<std::pair<std::string, std::string>>
604  InitAllowUnknown(const Container &kwargs);  // presumably sets fields from kwargs and returns the unrecognised pairs — confirm in tree_impl.h
605  inline std::map<std::string, std::string> __DICT__() const;  // presumably returns the parameters as name -> string value — confirm in tree_impl.h
606 };
607 
608 static_assert(std::is_standard_layout<ModelParam>::value,
609  "ModelParam must be in the standard layout");
610 
// Declaration only; the body is not in this listing. Presumably applies the key/value pairs in cfg
// to *param and rejects unknown keys — confirm in tree_impl.h.
611 inline void InitParamAndCheck(ModelParam* param,
612  const std::vector<std::pair<std::string, std::string>>& cfg);
613 
/*! \brief thin wrapper for tree ensemble model. Abstract base class: the pure virtual functions
 *  are overridden by ModelImpl<ThresholdType, LeafOutputType>. Copying is deleted. */
615 class Model {
616  public:
618  Model() = default;  // NOTE(review): threshold_type_ / leaf_output_type_ have no initialiser here — presumably set by Create(); confirm
619  virtual ~Model() = default;
620  Model(const Model&) = delete;
621  Model& operator=(const Model&) = delete;
622  Model(Model&&) = default;
623  Model& operator=(Model&&) = default;
624 
625  template <typename ThresholdType, typename LeafOutputType>
626  inline static std::unique_ptr<Model> Create();  // factory selected by compile-time types
627  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);  // factory selected by run-time type tags
628  inline TypeInfo GetThresholdType() const {
629  return threshold_type_;
630  }
631  inline TypeInfo GetLeafOutputType() const {
632  return leaf_output_type_;
633  }
634  template <typename Func>
635  inline auto Dispatch(Func func);  // presumably calls func with the concrete ModelImpl matching the stored type tags — confirm in tree_impl.h
636  template <typename Func>
637  inline auto Dispatch(Func func) const;
638 
639  virtual size_t GetNumTree() const = 0;
640  virtual void SetTreeLimit(size_t limit) = 0;
641  virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0;
642 
643  inline std::vector<PyBufferFrame> GetPyBuffer();  // public entry point; the private virtual overload below does the per-type work
644  inline static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
645 
 // NOTE(review): the public data members num_feature, task_type, average_tree_output, task_param
 // and param (listing lines 650-658) are not shown in this listing.
659 
660  private:
661  TypeInfo threshold_type_;  // returned by GetThresholdType()
662  TypeInfo leaf_output_type_;  // returned by GetLeafOutputType()
663  // Internal functions for serialization
664  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
665  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
666  std::vector<PyBufferFrame>::iterator end) = 0;
667 };
668 
/*! \brief concrete Model for one (ThresholdType, LeafOutputType) pair; owns the member trees.
 *  Move-only: copy construction and copy assignment are deleted. */
669 template <typename ThresholdType, typename LeafOutputType>
670 class ModelImpl : public Model {
671  public:
673  std::vector<Tree<ThresholdType, LeafOutputType>> trees;  // member trees
674 
676  ModelImpl() = default;
677  ~ModelImpl() override = default;
678  ModelImpl(const ModelImpl&) = delete;
679  ModelImpl& operator=(const ModelImpl&) = delete;
680  ModelImpl(ModelImpl&&) noexcept = default;
681  ModelImpl& operator=(ModelImpl&&) noexcept = default;
682 
683  void ReferenceSerialize(dmlc::Stream* fo) const override;
684  inline size_t GetNumTree() const override {  // number of member trees
685  return trees.size();
686  }
687  void SetTreeLimit(size_t limit) override {  // resizes trees to exactly limit entries
688  return trees.resize(limit);  // std::vector::resize: a limit above the current count appends default-constructed trees
689  }
690 
691  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
692  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
693  std::vector<PyBufferFrame>::iterator end) override;
694 };
695 
696 } // namespace treelite
697 
698 #include "tree_impl.h"
699 
700 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:658
SplitFeatureType split_type_
feature split type
Definition: tree.h:229
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:388
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:241
SplitFeatureType
feature split type
Definition: base.h:22
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:216
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:431
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:460
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:235
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:358
TaskType
Enum type representing the task type.
Definition: tree.h:94
bool average_tree_output
whether to average tree outputs
Definition: tree.h:654
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:582
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree.h:323
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:237
tree node
Definition: tree.h:194
int32_t cleft_
pointer to left and right children
Definition: tree.h:204
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:523
in-memory representation of a decision tree
Definition: tree.h:191
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:223
float global_bias
global bias of the model
Definition: tree.h:589
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:227
TaskType task_type
Task type.
Definition: tree.h:652
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:330
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:158
store either leaf value or decision threshold
Definition: tree.h:199
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:673
Definition: tree.h:29
double SumHess(int nid) const
get hessian sum
Definition: tree.h:453
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:543
TaskParameter task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:656
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:533
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:475
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:424
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:351
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:399
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:177
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:438
double Gain(int nid) const
get gain value
Definition: tree.h:467
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:316
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:169
bool HasMatchingCategories(int nid) const
tests whether the node has a non-empty list for matching categories. See MatchingCategories() for the...
Definition: tree.h:417
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:239
thin wrapper for tree ensemble model
Definition: tree.h:615
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:337
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:446
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:344
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:161
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:650
uint32_t sindex_
feature index used for the split; highest bit indicates default direction for missing values ...
Definition: tree.h:209
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:184
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:381
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:374
Info info_
storage for leaf value or decision threshold
Definition: tree.h:211
Operator
comparison operators
Definition: base.h:26