Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <treelite/optional.h>
13 #include <algorithm>
14 #include <map>
15 #include <memory>
16 #include <ostream>
17 #include <sstream>
18 #include <string>
19 #include <vector>
20 #include <utility>
21 #include <type_traits>
22 #include <limits>
23 #include <cstddef>
24 #include <cstdint>
25 #include <cstring>
26 #include <cstdio>
27 
// Two-level stringification helpers: __TREELITE_STR turns its argument into a string literal
// exactly as written, while _TREELITE_STR forwards through it so that an argument which is
// itself a macro gets expanded before being stringified.
28 #define __TREELITE_STR(x) #x
29 #define _TREELITE_STR(x) __TREELITE_STR(x)
30 
// Length, in chars, of the fixed-size pred_transform buffer declared in ModelParam.
31 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256
32 
33 namespace treelite {
34 
// Forward declaration of ModelImpl, needed because Tree names it in a friend declaration
// before the full class definition appears further down in this file.
35 template <typename ThresholdType, typename LeafOutputType>
36 class ModelImpl;
37 
38 // Represent a frame in the Python buffer protocol (PEP 3118). We use a simplified representation
39 // to hold only 1-D arrays with stride 1.
// Fields as declared: an untyped data pointer (buf), a C string (format) and two sizes.
// NOTE(review): presumably format is the PEP 3118 format string, itemsize the size of one
// element in bytes and nitem the number of elements - confirm against the serializer code.
40 struct PyBufferFrame {
41  void* buf;
42  char* format;
43  std::size_t itemsize;
44  std::size_t nitem;
45 };
46 
// Compile-time guard: the frame must remain a POD type.
47 static_assert(std::is_pod<PyBufferFrame>::value, "PyBufferFrame must be a POD type");
48 
// Array container template over an element type T that must be POD (see the static_assert at
// the end of the public section). Copy construction and copy assignment are deleted; Clone()
// is the documented way to obtain a deep copy. Member functions are only declared here.
// NOTE(review): embedded line numbers 50 and 52 are skipped in this listing, so the class
// header line (and presumably a default constructor) is not shown. The class name is taken
// from the destructor and the copy/move declarations below - confirm against tree.h.
49 template <typename T>
51  public:
53  ~ContiguousArray();
54  // NOTE: use Clone to make deep copy; copy constructors disabled
55  ContiguousArray(const ContiguousArray&) = delete;
56  ContiguousArray& operator=(const ContiguousArray&) = delete;
// Move construction and move assignment are declared noexcept.
57  ContiguousArray(ContiguousArray&& other) noexcept;
58  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
59  inline ContiguousArray Clone() const;
// NOTE(review): presumably switches the array to a caller-provided buffer that it does not
// own (compare the owned_buffer_ member) - confirm in tree_impl.h.
60  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
// Raw pointer access, each in a mutable and a const overload.
61  inline T* Data();
62  inline const T* Data() const;
63  inline T* End();
64  inline const T* End() const;
65  inline T& Back();
66  inline const T& Back() const;
// Size queries and size-changing operations.
67  inline std::size_t Size() const;
68  inline bool Empty() const;
69  inline void Reserve(std::size_t newsize);
70  inline void Resize(std::size_t newsize);
71  inline void Resize(std::size_t newsize, T t);
72  inline void Clear();
73  inline void PushBack(T t);
74  inline void Extend(const std::vector<T>& other);
75  /* Unsafe access, no bounds checking */
76  inline T& operator[](std::size_t idx);
77  inline const T& operator[](std::size_t idx) const;
78  /* Safe access, with bounds checking */
79  inline T& at(std::size_t idx);
80  inline const T& at(std::size_t idx) const;
81  /* Safe access, with bounds checking + check against non-existent node (<0) */
82  inline T& at(int idx);
83  inline const T& at(int idx) const;
84  static_assert(std::is_pod<T>::value, "T must be POD");
85 
86  private:
// NOTE(review): presumably buffer_ is the storage pointer, size_ the element count, capacity_
// the allocated element count and owned_buffer_ whether this object owns buffer_ - confirm.
87  T* buffer_;
88  std::size_t size_;
89  std::size_t capacity_;
90  bool owned_buffer_;
91 };
92 
// Enum type representing the task type. Stored in a uint8_t, with explicit values 0 to 3.
// NOTE(review): the embedded line numbers jump between enumerators, so the per-enumerator
// documentation of the original file is not part of this listing.
99 enum class TaskType : uint8_t {
107  kBinaryClfRegr = 0,
126  kMultiClfGrovePerClass = 1,
142  kMultiClfProbDistLeaf = 2,
159  kMultiClfCategLeaf = 3
160 };
161 
// Convert a TaskType value to its name as a string (the enumerator name without the leading k).
// Returns an empty string for any value that is not one of the four enumerators.
162 inline std::string TaskTypeToString(TaskType type) {
163  switch (type) {
164  case TaskType::kBinaryClfRegr: return "BinaryClfRegr";
165  case TaskType::kMultiClfGrovePerClass: return "MultiClfGrovePerClass";
166  case TaskType::kMultiClfProbDistLeaf: return "MultiClfProbDistLeaf";
167  case TaskType::kMultiClfCategLeaf: return "MultiClfCategLeaf";
168  default: return "";
169  }
170 }
171 
// Group of parameters that are dependent on the choice of the task type.
// NOTE(review): the cross-reference index at the end of this page also lists a member
// bool grove_per_class (tree.h:184) that does not appear in this listing.
173 struct TaskParam {
// Kind of value a leaf produces: floating-point or integer.
174  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
// The type of output from each leaf node.
176  OutputType output_type;
// The number of classes in the target label.
192  unsigned int num_class;
// Dimension of the output from each leaf node.
199  unsigned int leaf_vector_size;
200 };
201 
// Convert a TaskParam::OutputType value to a string: float for kFloat, int for kInt.
// Returns an empty string for any other value.
202 inline std::string OutputTypeToString(TaskParam::OutputType type) {
203  switch (type) {
204  case TaskParam::OutputType::kFloat: return "float";
205  case TaskParam::OutputType::kInt: return "int";
206  default: return "";
207  }
208 }
209 
// Compile-time guard: TaskParam must remain a POD type.
210 static_assert(std::is_pod<TaskParam>::value, "TaskParameter must be POD type");
211 
// In-memory representation of a decision tree, templated on the type used for thresholds and
// the type used for leaf outputs. The static_asserts below restrict ThresholdType to float or
// double and LeafOutputType to uint32_t, float or double. Copying is deleted and moving is
// defaulted; Clone() returns a Tree by value.
213 template <typename ThresholdType, typename LeafOutputType>
214 class Tree {
215  public:
// Tree node.
// NOTE(review): this listing skips several embedded line numbers inside Node. The index at the
// end of the page lists further members (info_, split_type_, cmp_, data_count_present_,
// sum_hess_present_, gain_present_) that the accessors below read but that are not shown here.
217  struct Node {
220  inline void Init();
// Store either leaf value or decision threshold.
222  union Info {
223  LeafOutputType leaf_value; // for leaf nodes
224  ThresholdType threshold; // for non-leaf nodes
225  };
// Left and right children. IsLeaf() below treats cleft_ == -1 as the marker of a leaf.
227  int32_t cleft_, cright_;
// Feature index used for the split; the highest bit indicates the default direction for
// missing values (see SplitIndex and DefaultLeft below).
232  uint32_t sindex_;
// Number of data points whose traversal paths include this node.
239  uint64_t data_count_;
// Sum of hessian values for all data points whose traversal paths include this node.
246  double sum_hess_;
// Change in loss that is attributed to a particular split.
250  double gain_;
265  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
266  * node or the left child node. True if the right child, False otherwise */
267  bool categories_list_right_child_;
268  };
269 
// Compile-time checks: Node is POD, the two type parameters come from the supported sets,
// they are either equal or LeafOutputType is uint32_t, and Node occupies 48 bytes when
// ThresholdType is float or 56 bytes when it is double.
270  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
271  static_assert(std::is_same<ThresholdType, float>::value
272  || std::is_same<ThresholdType, double>::value,
273  "ThresholdType must be either float32 or float64");
274  static_assert(std::is_same<LeafOutputType, uint32_t>::value
275  || std::is_same<LeafOutputType, float>::value
276  || std::is_same<LeafOutputType, double>::value,
277  "LeafOutputType must be one of uint32_t, float32 or float64");
278  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
279  || std::is_same<LeafOutputType, uint32_t>::value,
280  "Unsupported combination of ThresholdType and LeafOutputType");
281  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
282  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
283  "Node size incorrect");
284 
// Default-constructible and movable; copy construction and copy assignment are deleted.
285  Tree() = default;
286  ~Tree() = default;
287  Tree(const Tree&) = delete;
288  Tree& operator=(const Tree&) = delete;
289  Tree(Tree&&) noexcept = default;
290  Tree& operator=(Tree&&) noexcept = default;
291 
292  inline Tree<ThresholdType, LeafOutputType> Clone() const;
293 
// Serialization entry points, declared here only: one pair works on PyBufferFrame vectors,
// the other on a FILE stream.
294  inline const char* GetFormatStringForNode();
295  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
296  inline void SerializeToFile(FILE* dest_fp);
297  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
298  std::vector<PyBufferFrame>::iterator end);
299  inline void DeserializeFromFile(FILE* src_fp);
300 
301  private:
302  // vector of nodes
303  ContiguousArray<Node> nodes_;
304  ContiguousArray<LeafOutputType> leaf_vector_;
305  // Map nid to the start and end index in leaf_vector_
306  // We could use std::pair, but it is not POD, so easier to use two vectors
307  // here
308  ContiguousArray<std::size_t> leaf_vector_begin_;
309  ContiguousArray<std::size_t> leaf_vector_end_;
// Category values for categorical splits, stored flat; MatchingCategories(nid) reads the
// slice between matching_categories_offset_[nid] and matching_categories_offset_[nid + 1].
310  ContiguousArray<uint32_t> matching_categories_;
311  ContiguousArray<std::size_t> matching_categories_offset_;
312 
// The JSON dump helpers are friends so that they can read the private arrays above.
313  template <typename WriterType, typename X, typename Y>
314  friend void DumpModelAsJSON(WriterType& writer, const ModelImpl<X, Y>& model);
315  template <typename WriterType, typename X, typename Y>
316  friend void DumpTreeAsJSON(WriterType& writer, const Tree<X, Y>& tree);
317 
318  // allocate a new node
319  inline int AllocNode();
320 
321  // utility functions used for serialization, internal use only
322  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
323  inline void
324  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
325  CompositeArrayHandler composite_array_handler);
326  template <typename ScalarHandler, typename ArrayHandler>
327  inline void
328  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
329 
330  public:
// NOTE(review): presumably the number of nodes currently in the tree - confirm in tree_impl.h.
332  int num_nodes;
// Init and AddChilds are declared here only.
334  inline void Init();
339  inline void AddChilds(int nid);
340 
345  inline std::vector<unsigned> GetCategoricalFeatures() const;
346 
// All accessors below look the node up with the bounds-checked nodes_.at(nid).
// Returns the cleft_ field of node nid.
352  inline int LeftChild(int nid) const {
353  return nodes_.at(nid).cleft_;
354  }
// Index of the right child of the node.
359  inline int RightChild(int nid) const {
360  return nodes_.at(nid).cright_;
361  }
// Index of the default child of the node, used when the feature is missing.
366  inline int DefaultChild(int nid) const {
367  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
368  }
// Feature index of the split condition of the node: the low 31 bits of sindex_.
373  inline uint32_t SplitIndex(int nid) const {
374  return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
375  }
// Whether to use the left child node when the feature in the split condition is missing:
// true when the top bit of sindex_ is set.
380  inline bool DefaultLeft(int nid) const {
381  return (nodes_.at(nid).sindex_ >> 31U) != 0;
382  }
// Whether the node is a leaf node (cleft_ equals -1).
387  inline bool IsLeaf(int nid) const {
388  return nodes_.at(nid).cleft_ == -1;
389  }
// Get leaf value of the leaf node.
394  inline LeafOutputType LeafValue(int nid) const {
395  return (nodes_.at(nid).info_).leaf_value;
396  }
// Get leaf vector of the leaf node, copied into a std::vector. An empty vector is returned
// when the stored offsets fall outside leaf_vector_.
401  inline std::vector<LeafOutputType> LeafVector(int nid) const {
402  const std::size_t offset_begin = leaf_vector_begin_.at(nid);
403  const std::size_t offset_end = leaf_vector_end_.at(nid);
404  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
405  // Return empty vector, to indicate the lack of leaf vector
406  return std::vector<LeafOutputType>();
407  }
408  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
409  &leaf_vector_[offset_end]);
410  // Use unsafe access here, since we may need to take the address of one past the last
411  // element, to follow with the range semantic of std::vector<>.
412  }
// Tests whether the leaf node has a non-empty leaf vector (begin and end offsets differ).
417  inline bool HasLeafVector(int nid) const {
418  return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid);
419  }
// Get threshold of the node.
424  inline ThresholdType Threshold(int nid) const {
425  return (nodes_.at(nid).info_).threshold;
426  }
// Get comparison operator.
431  inline Operator ComparisonOp(int nid) const {
432  return nodes_.at(nid).cmp_;
433  }
// Get list of all categories belonging to the left/right child node; which child is given by
// CategoriesListRightChild(nid). An empty vector is returned when the stored offsets fall
// outside matching_categories_.
442  inline std::vector<uint32_t> MatchingCategories(int nid) const {
443  const std::size_t offset_begin = matching_categories_offset_.at(nid);
444  const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
445  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
446  // Return empty vector, to indicate the lack of any matching categories
447  // The node might be a numerical split
448  return std::vector<uint32_t>();
449  }
450  return std::vector<uint32_t>(&matching_categories_[offset_begin],
451  &matching_categories_[offset_end]);
452  // Use unsafe access here, since we may need to take the address of one past the last
453  // element, to follow with the range semantic of std::vector<>.
454  }
// Get the largest category value used in all categorical splits in this tree. Returns an
// empty optional when matching_categories_ is empty.
// NOTE(review): the signature line (embedded number 459) is missing from this listing; the
// index at the end of the page gives it as optional<uint32_t> MaxCategory() const.
460  if (matching_categories_.Empty()) {
461  return optional<uint32_t>{};
462  }
463  return optional<uint32_t>{*std::max_element(matching_categories_.Data(),
464  matching_categories_.End())};
465  }
// Get feature split type.
470  inline SplitFeatureType SplitType(int nid) const {
471  return nodes_.at(nid).split_type_;
472  }
// Test whether this node has data count.
477  inline bool HasDataCount(int nid) const {
478  return nodes_.at(nid).data_count_present_;
479  }
// Get data count.
484  inline uint64_t DataCount(int nid) const {
485  return nodes_.at(nid).data_count_;
486  }
487 
// Test whether this node has hessian sum.
492  inline bool HasSumHess(int nid) const {
493  return nodes_.at(nid).sum_hess_present_;
494  }
// Get hessian sum.
499  inline double SumHess(int nid) const {
500  return nodes_.at(nid).sum_hess_;
501  }
// Test whether this node has gain value.
506  inline bool HasGain(int nid) const {
507  return nodes_.at(nid).gain_present_;
508  }
// Get gain value.
513  inline double Gain(int nid) const {
514  return nodes_.at(nid).gain_;
515  }
// Test whether the list given by MatchingCategories(nid) is associated with the right child
// node (true) or the left child node (false).
521  inline bool CategoriesListRightChild(int nid) const {
522  return nodes_.at(nid).categories_list_right_child_;
523  }
524 
// Node setters that are declared here only.
535  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
536  bool default_left, Operator cmp);
549  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
550  const std::vector<uint32_t>& categories_list,
551  bool categories_list_right_child);
557  inline void SetLeaf(int nid, LeafOutputType value);
563  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
// Set the hessian sum of the node and mark it as present.
569  inline void SetSumHess(int nid, double sum_hess) {
570  Node& node = nodes_.at(nid);
571  node.sum_hess_ = sum_hess;
572  node.sum_hess_present_ = true;
573  }
// Set the data count of the node and mark it as present.
579  inline void SetDataCount(int nid, uint64_t data_count) {
580  Node& node = nodes_.at(nid);
581  node.data_count_ = data_count;
582  node.data_count_present_ = true;
583  }
// Set the gain value of the node and mark it as present.
589  inline void SetGain(int nid, double gain) {
590  Node& node = nodes_.at(nid);
591  node.gain_ = gain;
592  node.gain_present_ = true;
593  }
594 };
595 
// Parameter block for a model. The default constructor sets sigmoid_alpha to 1.0, ratio_c to
// 1.0, global_bias to 0.0 and pred_transform to the string identity.
// NOTE(review): the constructor initializes a member sigmoid_alpha whose declaration is not
// shown in this listing; the index at the end of the page places it at tree.h:626.
596 struct ModelParam {
// Fixed-size, zero-initialized character buffer of TREELITE_MAX_PRED_TRANSFORM_LENGTH chars.
618  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
// Scaling parameter for exponential standard ratio transformation
// expstdratio(x) = exp2(-x / c).
634  float ratio_c;
// Global bias of the model.
641  float global_bias;
644  ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
// Zero the whole buffer first, so the text copied next is followed by zero bytes.
645  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
646  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
647  }
// Destructor, copy and move operations are all defaulted.
648  ~ModelParam() = default;
649  ModelParam(const ModelParam&) = default;
650  ModelParam& operator=(const ModelParam&) = default;
651  ModelParam(ModelParam&&) = default;
652  ModelParam& operator=(ModelParam&&) = default;
653 
// Declared here only. NOTE(review): presumably InitAllowUnknown applies the recognized
// key/value pairs from kwargs and returns the unrecognized ones, and __DICT__ exports the
// parameters as a string map - confirm in tree_impl.h.
654  template<typename Container>
655  inline std::vector<std::pair<std::string, std::string>>
656  InitAllowUnknown(const Container &kwargs);
657  inline std::map<std::string, std::string> __DICT__() const;
658 };
659 
// Compile-time guard: ModelParam must have standard layout.
660 static_assert(std::is_standard_layout<ModelParam>::value,
661  "ModelParam must be in the standard layout");
662 
// Declared here only; takes the ModelParam to fill in and a list of key/value string pairs.
// NOTE(review): presumably rejects keys that ModelParam does not recognize - confirm in
// tree_impl.h.
663 inline void InitParamAndCheck(ModelParam* param,
664  const std::vector<std::pair<std::string, std::string>>& cfg);
665 
// Thin wrapper for tree ensemble model. Abstract base class: tree count, tree limit, JSON dump
// and the serialization hooks are pure virtual and implemented by ModelImpl.
// NOTE(review): embedded line numbers 710 to 722 are skipped in this listing. The index at the
// end of the page lists public members in that range (num_feature, task_type,
// average_tree_output, task_param, param) that are not shown here.
667 class Model {
668  public:
// Records the Treelite version macros in the version fields.
670  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
671  patch_ver_(TREELITE_VER_PATCH) {}
672  virtual ~Model() = default;
// Disable copy; use default move.
673  Model(const Model&) = delete;
674  Model& operator=(const Model&) = delete;
675  Model(Model&&) = default;
676  Model& operator=(Model&&) = default;
677 
// Factory functions returning an owning pointer to a new model. The element types are chosen
// either at compile time (template arguments) or at run time (TypeInfo values).
678  template <typename ThresholdType, typename LeafOutputType>
679  inline static std::unique_ptr<Model> Create();
680  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
681  inline TypeInfo GetThresholdType() const {
682  return threshold_type_;
683  }
684  inline TypeInfo GetLeafOutputType() const {
685  return leaf_output_type_;
686  }
// Declared here only. NOTE(review): presumably invokes func on the concrete ModelImpl that
// matches the stored threshold and leaf output types - confirm in tree_impl.h.
687  template <typename Func>
688  inline auto Dispatch(Func func);
689  template <typename Func>
690  inline auto Dispatch(Func func) const;
691 
692  virtual std::size_t GetNumTree() const = 0;
693  virtual void SetTreeLimit(std::size_t limit) = 0;
694  virtual void DumpAsJSON(std::ostream& fo, bool pretty_print) const = 0;
695 
// Convenience overload: runs the stream-based DumpAsJSON into a string stream and returns
// the accumulated text.
696  inline std::string DumpAsJSON(bool pretty_print) const {
697  std::ostringstream oss;
698  DumpAsJSON(oss, pretty_print);
699  return oss.str();
700  }
701 
702  /* In-memory serialization, zero-copy */
703  std::vector<PyBufferFrame> GetPyBuffer();
704  static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
705 
706  /* Serialization to a file stream */
707  void SerializeToFile(FILE* dest_fp);
708  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
709 
723 
724  private:
// Version fields set by the constructor, followed by the element types reported by
// GetThresholdType and GetLeafOutputType.
725  int major_ver_, minor_ver_, patch_ver_;
726  TypeInfo threshold_type_;
727  TypeInfo leaf_output_type_;
728  // Internal functions for serialization
729  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
730  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
731  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
732  std::vector<PyBufferFrame>::iterator end) = 0;
733  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
734  template <typename HeaderPrimitiveFieldHandlerFunc>
735  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
736  template <typename HeaderPrimitiveFieldHandlerFunc>
737  inline static void DeserializeTemplate(
738  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
739  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
740 };
741 
// Concrete Model for one combination of threshold type and leaf output type. Holds the trees
// and overrides the pure virtual functions of Model. Copying is deleted; moving is defaulted.
742 template <typename ThresholdType, typename LeafOutputType>
743 class ModelImpl : public Model {
744  public:
// Member trees.
746  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
747 
749  ModelImpl() = default;
750  ~ModelImpl() override = default;
751  ModelImpl(const ModelImpl&) = delete;
752  ModelImpl& operator=(const ModelImpl&) = delete;
753  ModelImpl(ModelImpl&&) noexcept = default;
754  ModelImpl& operator=(ModelImpl&&) noexcept = default;
755 
756  void DumpAsJSON(std::ostream& fo, bool pretty_print) const override;
// Number of trees currently held in the trees vector.
757  inline std::size_t GetNumTree() const override {
758  return trees.size();
759  }
// Resizes the trees vector to exactly limit entries.
760  void SetTreeLimit(std::size_t limit) override {
761  return trees.resize(limit);
762  }
763 
// Overrides of the serialization hooks declared in Model; declared here only.
764  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
765  inline void SerializeToFileImpl(FILE* dest_fp) override;
766  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
767  std::vector<PyBufferFrame>::iterator end) override;
768  inline void DeserializeFromFileImpl(FILE* src_fp) override;
769 
770  private:
// Shared serialization and deserialization drivers parameterized by handler callables;
// declared here only.
771  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
772  typename TreeHandlerFunc>
773  inline void SerializeTemplate(
774  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
775  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
776  TreeHandlerFunc tree_handler);
777  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
778  inline void DeserializeTemplate(
779  size_t num_tree,
780  HeaderFieldHandlerFunc header_field_handler,
781  TreeHandlerFunc tree_handler);
782 };
783 
784 } // namespace treelite
785 
786 #include "tree_impl.h"
787 
788 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:722
SplitFeatureType split_type_
feature split type
Definition: tree.h:252
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:431
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:264
SplitFeatureType
feature split type
Definition: base.h:22
Backport of std::optional from C++17.
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:239
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:477
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:506
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:184
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:173
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:258
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:401
TaskType
Enum type representing the task type.
Definition: tree.h:99
bool average_tree_output
whether to average tree outputs
Definition: tree.h:718
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:626
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree.h:366
optional< uint32_t > MaxCategory() const
Get the largest category value used in all categorical splits in this tree. If there are no categoric...
Definition: tree.h:459
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:260
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:199
tree node
Definition: tree.h:217
int32_t cleft_
pointer to left and right children
Definition: tree.h:227
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:569
in-memory representation of a decision tree
Definition: tree.h:214
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
Definition: tree.h:246
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:192
float global_bias
global bias of the model
Definition: tree.h:641
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:250
TaskType task_type
Task type.
Definition: tree.h:716
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:373
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
store either leaf value or decision threshold
Definition: tree.h:222
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:746
double SumHess(int nid) const
get hessian sum
Definition: tree.h:499
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:589
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:579
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:521
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:470
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:394
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:634
Model()
disable copy; use default move
Definition: tree.h:670
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:720
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:442
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:484
double Gain(int nid) const
get gain value
Definition: tree.h:513
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:359
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:176
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:262
thin wrapper for tree ensemble model
Definition: tree.h:667
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:380
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:492
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:387
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:714
uint32_t sindex_
feature index used for the split; highest bit indicates default direction for missing values ...
Definition: tree.h:232
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:424
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:417
Info info_
storage for leaf value or decision threshold
Definition: tree.h:234
Operator
comparison operators
Definition: base.h:26