Treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
10 #include <treelite/base.h>
11 #include <treelite/version.h>
12 #include <algorithm>
13 #include <map>
14 #include <memory>
15 #include <ostream>
16 #include <sstream>
17 #include <string>
18 #include <vector>
19 #include <utility>
20 #include <type_traits>
21 #include <limits>
22 #include <cstddef>
23 #include <cstdint>
24 #include <cstring>
25 #include <cstdio>
26 
// Two-level stringification: _TREELITE_STR(x) macro-expands its argument first and then
// __TREELITE_STR turns the result into a string literal, so
// _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) yields "256", not the macro's name.
// NOTE(review): identifiers starting with "__" or "_<Uppercase>" are reserved for the
// implementation; the names are kept only because existing callers use them.
#define __TREELITE_STR(x) #x
#define _TREELITE_STR(x) __TREELITE_STR(x)

// Size (in chars, including the terminating NUL) of the ModelParam::pred_transform buffer.
#define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256
31 
32 namespace treelite {
33 
template <typename ThresholdType, typename LeafOutputType>
class ModelImpl;

/*!
 * \brief A frame in the Python buffer protocol (PEP 3118).
 *
 * We use a simplified representation that only describes 1-D arrays with stride 1.
 */
struct PyBufferFrame {
  void* buf;              // start of the array's memory
  char* format;           // item format string, as in PEP 3118
  std::size_t itemsize;   // size of one item in bytes, as in PEP 3118
  std::size_t nitem;      // number of items in the 1-D array
};

static_assert(std::is_pod<PyBufferFrame>::value, "PyBufferFrame must be a POD type");
47 
/*!
 * \brief Contiguous array of POD elements with explicit capacity management.
 *
 * The buffer is presumably either owned by the array or supplied by the caller through
 * UseForeignBuffer (tracked by owned_buffer_) — confirm in tree_impl.h, where all member
 * functions are defined.
 * \tparam T element type; must be POD
 */
template <typename T>
class ContiguousArray {
 public:
  // Must be user-declared: declaring the move constructor below suppresses the implicit
  // default constructor, and Tree default-constructs its ContiguousArray members.
  ContiguousArray();
  ~ContiguousArray();
  // NOTE: use Clone to make deep copy; copy constructors disabled
  ContiguousArray(const ContiguousArray&) = delete;
  ContiguousArray& operator=(const ContiguousArray&) = delete;
  ContiguousArray(ContiguousArray&& other) noexcept;
  ContiguousArray& operator=(ContiguousArray&& other) noexcept;
  inline ContiguousArray Clone() const;
  // Presumably adopts prealloc_buf (holding `size` elements) without taking ownership — confirm
  inline void UseForeignBuffer(void* prealloc_buf, std::size_t size);
  inline T* Data();
  inline const T* Data() const;
  inline T* End();
  inline const T* End() const;
  inline T& Back();
  inline const T& Back() const;
  inline std::size_t Size() const;
  inline bool Empty() const;
  inline void Reserve(std::size_t newsize);
  inline void Resize(std::size_t newsize);
  inline void Resize(std::size_t newsize, T t);
  inline void Clear();
  inline void PushBack(T t);
  inline void Extend(const std::vector<T>& other);
  /* Unsafe access, no bounds checking */
  inline T& operator[](std::size_t idx);
  inline const T& operator[](std::size_t idx) const;
  /* Safe access, with bounds checking */
  inline T& at(std::size_t idx);
  inline const T& at(std::size_t idx) const;
  /* Safe access, with bounds checking + check against non-existent node (<0) */
  inline T& at(int idx);
  inline const T& at(int idx) const;
  static_assert(std::is_pod<T>::value, "T must be POD");

 private:
  T* buffer_;
  std::size_t size_;
  std::size_t capacity_;
  bool owned_buffer_;
};
91 
/*!
 * \brief Enum type representing the task type.
 */
enum class TaskType : uint8_t {
  kBinaryClfRegr = 0,
  kMultiClfGrovePerClass = 1,
  kMultiClfProbDistLeaf = 2,
  kMultiClfCategLeaf = 3
};

/*!
 * \brief Get the name of a task type (the enumerator name without its "k" prefix).
 * \param type task type
 * \return name of the task type, or an empty string if type is not a known enumerator
 */
inline std::string TaskTypeToString(TaskType type) {
  switch (type) {
    case TaskType::kBinaryClfRegr: return "BinaryClfRegr";
    case TaskType::kMultiClfGrovePerClass: return "MultiClfGrovePerClass";
    case TaskType::kMultiClfProbDistLeaf: return "MultiClfProbDistLeaf";
    case TaskType::kMultiClfCategLeaf: return "MultiClfCategLeaf";
    default: return "";
  }
}
170 
/*!
 * \brief Group of parameters that are dependent on the choice of the task type.
 */
struct TaskParam {
  enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
  /*! \brief The type of output from each leaf node. */
  OutputType output_type;
  /*!
   * \brief Whether we designate a subset of the trees to compute the prediction for each class.
   */
  bool grove_per_class;
  /*! \brief The number of classes in the target label. */
  unsigned int num_class;
  /*! \brief Dimension of the output from each leaf node. */
  unsigned int leaf_vector_size;
};

/*!
 * \brief Get the name of a leaf output type.
 * \param type output type
 * \return "float" or "int", or an empty string if type is not a known enumerator
 */
inline std::string OutputTypeToString(TaskParam::OutputType type) {
  switch (type) {
    case TaskParam::OutputType::kFloat: return "float";
    case TaskParam::OutputType::kInt: return "int";
    default: return "";
  }
}

static_assert(std::is_pod<TaskParam>::value, "TaskParameter must be POD type");
210 
212 template <typename ThresholdType, typename LeafOutputType>
213 class Tree {
214  public:
216  struct Node {
219  inline void Init();
221  union Info {
222  LeafOutputType leaf_value; // for leaf nodes
223  ThresholdType threshold; // for non-leaf nodes
224  };
226  int32_t cleft_, cright_;
231  uint32_t sindex_;
238  uint64_t data_count_;
245  double sum_hess_;
249  double gain_;
264  /* \brief whether the list given by MatchingCategories(nid) is associated with the right child
265  * node or the left child node. True if the right child, False otherwise */
266  bool categories_list_right_child_;
267  };
268 
269  static_assert(std::is_pod<Node>::value, "Node must be a POD type");
270  static_assert(std::is_same<ThresholdType, float>::value
271  || std::is_same<ThresholdType, double>::value,
272  "ThresholdType must be either float32 or float64");
273  static_assert(std::is_same<LeafOutputType, uint32_t>::value
274  || std::is_same<LeafOutputType, float>::value
275  || std::is_same<LeafOutputType, double>::value,
276  "LeafOutputType must be one of uint32_t, float32 or float64");
277  static_assert(std::is_same<ThresholdType, LeafOutputType>::value
278  || std::is_same<LeafOutputType, uint32_t>::value,
279  "Unsupported combination of ThresholdType and LeafOutputType");
280  static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
281  || (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
282  "Node size incorrect");
283 
284  Tree() = default;
285  ~Tree() = default;
286  Tree(const Tree&) = delete;
287  Tree& operator=(const Tree&) = delete;
288  Tree(Tree&&) noexcept = default;
289  Tree& operator=(Tree&&) noexcept = default;
290 
291  inline Tree<ThresholdType, LeafOutputType> Clone() const;
292 
293  inline const char* GetFormatStringForNode();
294  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
295  inline void SerializeToFile(FILE* dest_fp);
296  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
297  std::vector<PyBufferFrame>::iterator end);
298  inline void DeserializeFromFile(FILE* src_fp);
299 
300  private:
301  // vector of nodes
302  ContiguousArray<Node> nodes_;
303  ContiguousArray<LeafOutputType> leaf_vector_;
304  // Map nid to the start and end index in leaf_vector_
305  // We could use std::pair, but it is not POD, so easier to use two vectors
306  // here
307  ContiguousArray<std::size_t> leaf_vector_begin_;
308  ContiguousArray<std::size_t> leaf_vector_end_;
309  ContiguousArray<uint32_t> matching_categories_;
310  ContiguousArray<std::size_t> matching_categories_offset_;
311 
312  template <typename WriterType, typename X, typename Y>
313  friend void DumpModelAsJSON(WriterType& writer, const ModelImpl<X, Y>& model);
314  template <typename WriterType, typename X, typename Y>
315  friend void DumpTreeAsJSON(WriterType& writer, const Tree<X, Y>& tree);
316 
317  // allocate a new node
318  inline int AllocNode();
319 
320  // utility functions used for serialization, internal use only
321  template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
322  inline void
323  SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
324  CompositeArrayHandler composite_array_handler);
325  template <typename ScalarHandler, typename ArrayHandler>
326  inline void
327  DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
328 
329  public:
331  int num_nodes;
333  inline void Init();
338  inline void AddChilds(int nid);
339 
344  inline std::vector<unsigned> GetCategoricalFeatures() const;
345 
351  inline int LeftChild(int nid) const {
352  return nodes_.at(nid).cleft_;
353  }
358  inline int RightChild(int nid) const {
359  return nodes_.at(nid).cright_;
360  }
365  inline int DefaultChild(int nid) const {
366  return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
367  }
372  inline uint32_t SplitIndex(int nid) const {
373  return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
374  }
379  inline bool DefaultLeft(int nid) const {
380  return (nodes_.at(nid).sindex_ >> 31U) != 0;
381  }
386  inline bool IsLeaf(int nid) const {
387  return nodes_.at(nid).cleft_ == -1;
388  }
393  inline LeafOutputType LeafValue(int nid) const {
394  return (nodes_.at(nid).info_).leaf_value;
395  }
400  inline std::vector<LeafOutputType> LeafVector(int nid) const {
401  const std::size_t offset_begin = leaf_vector_begin_.at(nid);
402  const std::size_t offset_end = leaf_vector_end_.at(nid);
403  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
404  // Return empty vector, to indicate the lack of leaf vector
405  return std::vector<LeafOutputType>();
406  }
407  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
408  &leaf_vector_[offset_end]);
409  // Use unsafe access here, since we may need to take the address of one past the last
410  // element, to follow with the range semantic of std::vector<>.
411  }
416  inline bool HasLeafVector(int nid) const {
417  return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid);
418  }
423  inline ThresholdType Threshold(int nid) const {
424  return (nodes_.at(nid).info_).threshold;
425  }
430  inline Operator ComparisonOp(int nid) const {
431  return nodes_.at(nid).cmp_;
432  }
441  inline std::vector<uint32_t> MatchingCategories(int nid) const {
442  const std::size_t offset_begin = matching_categories_offset_.at(nid);
443  const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
444  if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
445  // Return empty vector, to indicate the lack of any matching categories
446  // The node might be a numerical split
447  return std::vector<uint32_t>();
448  }
449  return std::vector<uint32_t>(&matching_categories_[offset_begin],
450  &matching_categories_[offset_end]);
451  // Use unsafe access here, since we may need to take the address of one past the last
452  // element, to follow with the range semantic of std::vector<>.
453  }
458  inline SplitFeatureType SplitType(int nid) const {
459  return nodes_.at(nid).split_type_;
460  }
465  inline bool HasDataCount(int nid) const {
466  return nodes_.at(nid).data_count_present_;
467  }
472  inline uint64_t DataCount(int nid) const {
473  return nodes_.at(nid).data_count_;
474  }
475 
480  inline bool HasSumHess(int nid) const {
481  return nodes_.at(nid).sum_hess_present_;
482  }
487  inline double SumHess(int nid) const {
488  return nodes_.at(nid).sum_hess_;
489  }
494  inline bool HasGain(int nid) const {
495  return nodes_.at(nid).gain_present_;
496  }
501  inline double Gain(int nid) const {
502  return nodes_.at(nid).gain_;
503  }
509  inline bool CategoriesListRightChild(int nid) const {
510  return nodes_.at(nid).categories_list_right_child_;
511  }
512 
523  inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
524  bool default_left, Operator cmp);
537  inline void SetCategoricalSplit(int nid, unsigned split_index, bool default_left,
538  const std::vector<uint32_t>& categories_list,
539  bool categories_list_right_child);
545  inline void SetLeaf(int nid, LeafOutputType value);
551  inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
557  inline void SetSumHess(int nid, double sum_hess) {
558  Node& node = nodes_.at(nid);
559  node.sum_hess_ = sum_hess;
560  node.sum_hess_present_ = true;
561  }
567  inline void SetDataCount(int nid, uint64_t data_count) {
568  Node& node = nodes_.at(nid);
569  node.data_count_ = data_count;
570  node.data_count_present_ = true;
571  }
577  inline void SetGain(int nid, double gain) {
578  Node& node = nodes_.at(nid);
579  node.gain_ = gain;
580  node.gain_present_ = true;
581  }
582 };
583 
584 struct ModelParam {
606  char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
622  float ratio_c;
629  float global_bias;
632  ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
633  std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH * sizeof(char));
634  std::strncpy(pred_transform, "identity", sizeof(pred_transform));
635  }
636  ~ModelParam() = default;
637  ModelParam(const ModelParam&) = default;
638  ModelParam& operator=(const ModelParam&) = default;
639  ModelParam(ModelParam&&) = default;
640  ModelParam& operator=(ModelParam&&) = default;
641 
642  template<typename Container>
643  inline std::vector<std::pair<std::string, std::string>>
644  InitAllowUnknown(const Container &kwargs);
645  inline std::map<std::string, std::string> __DICT__() const;
646 };
647 
648 static_assert(std::is_standard_layout<ModelParam>::value,
649  "ModelParam must be in the standard layout");
650 
651 inline void InitParamAndCheck(ModelParam* param,
652  const std::vector<std::pair<std::string, std::string>>& cfg);
653 
655 class Model {
656  public:
658  Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
659  patch_ver_(TREELITE_VER_PATCH) {}
660  virtual ~Model() = default;
661  Model(const Model&) = delete;
662  Model& operator=(const Model&) = delete;
663  Model(Model&&) = default;
664  Model& operator=(Model&&) = default;
665 
666  template <typename ThresholdType, typename LeafOutputType>
667  inline static std::unique_ptr<Model> Create();
668  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
669  inline TypeInfo GetThresholdType() const {
670  return threshold_type_;
671  }
672  inline TypeInfo GetLeafOutputType() const {
673  return leaf_output_type_;
674  }
675  template <typename Func>
676  inline auto Dispatch(Func func);
677  template <typename Func>
678  inline auto Dispatch(Func func) const;
679 
680  virtual std::size_t GetNumTree() const = 0;
681  virtual void SetTreeLimit(std::size_t limit) = 0;
682  virtual void DumpAsJSON(std::ostream& fo, bool pretty_print) const = 0;
683 
684  inline std::string DumpAsJSON(bool pretty_print) const {
685  std::ostringstream oss;
686  DumpAsJSON(oss, pretty_print);
687  return oss.str();
688  }
689 
690  /* In-memory serialization, zero-copy */
691  std::vector<PyBufferFrame> GetPyBuffer();
692  static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
693 
694  /* Serialization to a file stream */
695  void SerializeToFile(FILE* dest_fp);
696  static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
697 
711 
712  private:
713  int major_ver_, minor_ver_, patch_ver_;
714  TypeInfo threshold_type_;
715  TypeInfo leaf_output_type_;
716  // Internal functions for serialization
717  virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
718  virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
719  virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
720  std::vector<PyBufferFrame>::iterator end) = 0;
721  virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
722  template <typename HeaderPrimitiveFieldHandlerFunc>
723  inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
724  template <typename HeaderPrimitiveFieldHandlerFunc>
725  inline static void DeserializeTemplate(
726  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
727  TypeInfo& threshold_type, TypeInfo& leaf_output_type);
728 };
729 
730 template <typename ThresholdType, typename LeafOutputType>
731 class ModelImpl : public Model {
732  public:
734  std::vector<Tree<ThresholdType, LeafOutputType>> trees;
735 
737  ModelImpl() = default;
738  ~ModelImpl() override = default;
739  ModelImpl(const ModelImpl&) = delete;
740  ModelImpl& operator=(const ModelImpl&) = delete;
741  ModelImpl(ModelImpl&&) noexcept = default;
742  ModelImpl& operator=(ModelImpl&&) noexcept = default;
743 
744  void DumpAsJSON(std::ostream& fo, bool pretty_print) const override;
745  inline std::size_t GetNumTree() const override {
746  return trees.size();
747  }
748  void SetTreeLimit(std::size_t limit) override {
749  return trees.resize(limit);
750  }
751 
752  inline void GetPyBuffer(std::vector<PyBufferFrame>* dest) override;
753  inline void SerializeToFileImpl(FILE* dest_fp) override;
754  inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
755  std::vector<PyBufferFrame>::iterator end) override;
756  inline void DeserializeFromFileImpl(FILE* src_fp) override;
757 
758  private:
759  template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
760  typename TreeHandlerFunc>
761  inline void SerializeTemplate(
762  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
763  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
764  TreeHandlerFunc tree_handler);
765  template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
766  inline void DeserializeTemplate(
767  size_t num_tree,
768  HeaderFieldHandlerFunc header_field_handler,
769  TreeHandlerFunc tree_handler);
770 };
771 
772 } // namespace treelite
773 
774 #include "tree_impl.h"
775 
776 #endif // TREELITE_TREE_H_
ModelParam param
extra parameters
Definition: tree.h:710
SplitFeatureType split_type_
feature split type
Definition: tree.h:251
Operator ComparisonOp(int nid) const
get comparison operator
Definition: tree.h:430
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
Definition: tree.h:263
SplitFeatureType
feature split type
Definition: base.h:22
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
Definition: tree.h:238
bool HasDataCount(int nid) const
test whether this node has data count
Definition: tree.h:465
bool HasGain(int nid) const
test whether this node has gain value
Definition: tree.h:494
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Definition: tree.h:183
Group of parameters that are dependent on the choice of the task type.
Definition: tree.h:172
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
Definition: tree.h:257
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
Definition: tree.h:400
TaskType
Enum type representing the task type.
Definition: tree.h:98
bool average_tree_output
whether to average tree outputs
Definition: tree.h:706
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
Definition: tree.h:614
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
Definition: tree.h:365
bool data_count_present_
whether data_count_ field is present
Definition: tree.h:259
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
Definition: tree.h:198
tree node
Definition: tree.h:216
int32_t cleft_
pointer to left and right children
Definition: tree.h:226
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
Definition: tree.h:557
in-memory representation of a decision tree
Definition: tree.h:213
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistics.
Definition: tree.h:245
unsigned int num_class
The number of classes in the target label.
Definition: tree.h:191
float global_bias
global bias of the model
Definition: tree.h:629
double gain_
change in loss that is attributed to a particular split
Definition: tree.h:249
TaskType task_type
Task type.
Definition: tree.h:704
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
Definition: tree.h:372
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
store either leaf value or decision threshold
Definition: tree.h:221
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
Definition: tree.h:734
double SumHess(int nid) const
get hessian sum
Definition: tree.h:487
void SetGain(int nid, double gain)
set the gain value of the node
Definition: tree.h:577
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
Definition: tree.h:567
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
Definition: tree.h:509
SplitFeatureType SplitType(int nid) const
get feature split type
Definition: tree.h:458
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Definition: tree.h:393
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Definition: tree.h:622
Model()
disable copy; use default move
Definition: tree.h:658
TaskParam task_param
Group of parameters that are specific to the particular task type.
Definition: tree.h:708
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Definition: tree.h:441
uint64_t DataCount(int nid) const
get data count
Definition: tree.h:472
double Gain(int nid) const
get gain value
Definition: tree.h:501
int RightChild(int nid) const
index of the node's right child
Definition: tree.h:358
OutputType output_type
The type of output from each leaf node.
Definition: tree.h:175
bool sum_hess_present_
whether sum_hess_ field is present
Definition: tree.h:261
thin wrapper for tree ensemble model
Definition: tree.h:655
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
Definition: tree.h:379
bool HasSumHess(int nid) const
test whether this node has hessian sum
Definition: tree.h:480
bool IsLeaf(int nid) const
whether the node is leaf node
Definition: tree.h:386
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:702
uint32_t sindex_
feature index used for the split highest bit indicates default direction for missing values ...
Definition: tree.h:231
ThresholdType Threshold(int nid) const
get threshold of the node
Definition: tree.h:423
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Definition: tree.h:416
Info info_
storage for leaf value or decision threshold
Definition: tree.h:233
Operator
comparison operators
Definition: base.h:26