treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
9 
11 #include <treelite/enum/operator.h>
14 #include <treelite/enum/typeinfo.h>
15 #include <treelite/logging.h>
17 #include <treelite/version.h>
18 
19 #include <algorithm>
20 #include <cstddef>
21 #include <cstdint>
22 #include <cstdio>
23 #include <iostream>
24 #include <limits>
25 #include <map>
26 #include <memory>
27 #include <sstream>
28 #include <string>
29 #include <type_traits>
30 #include <utility>
31 #include <variant>
32 #include <vector>
33 
34 /* Indicator that certain functions should be visible from a library (Windows only) */
35 #if defined(_MSC_VER) || defined(_WIN32)
36 #define TREELITE_DLL_EXPORT __declspec(dllexport)
37 #else
38 #define TREELITE_DLL_EXPORT
39 #endif
40 
41 namespace treelite {
42 
43 template <typename ThresholdType, typename LeafOutputType>
44 class ModelPreset;
45 
46 }
47 
49 
// Forward declarations of the serializer classes that Tree and Model befriend.
// NOTE(review): this listing dropped the namespace opener and the name of the second
// class template; both are restored here. `Deserializer` is inferred from the pairing
// with `Serializer` -- confirm against treelite/detail/serializer.h.
namespace treelite::detail::serializer {

template <typename MixIn>
class Serializer;
template <typename MixIn>
class Deserializer;

}  // namespace treelite::detail::serializer
56 
58 
59 template <typename ThresholdType, typename LeafOutputType>
61  ModelPreset<ThresholdType, LeafOutputType>&, std::uint64_t, std::string const&);
62 template <typename ThresholdType, typename LeafOutputType>
64  ModelPreset<ThresholdType, LeafOutputType>&, std::uint64_t, std::string const&, PyBufferFrame);
65 
66 } // namespace treelite::detail::field_accessor
67 
68 namespace treelite {
69 
/*!
 * \brief Version triple (major, minor, patch) returned from a Model object.
 *
 * Stays an aggregate, so `Version{major, minor, patch}` keeps working; the default
 * member initializers only make a default-constructed Version well-defined (all zeros)
 * instead of indeterminate.
 */
struct Version {
  std::int32_t major_ver{0};
  std::int32_t minor_ver{0};
  std::int32_t patch_ver{0};
};
76 
78 template <typename ThresholdType, typename LeafOutputType>
79 class Tree {
80  public:
81  static_assert(std::is_same_v<ThresholdType, float> || std::is_same_v<ThresholdType, double>,
82  "ThresholdType must be either float32 or float64");
83  static_assert(std::is_same_v<LeafOutputType, float> || std::is_same_v<LeafOutputType, double>,
84  "LeafOutputType must be one of uint32_t, float32 or float64");
85  static_assert(std::is_same_v<ThresholdType, LeafOutputType>,
86  "Unsupported combination of ThresholdType and LeafOutputType");
87 
88  Tree() = default;
89  ~Tree() = default;
90  Tree(Tree const&) = delete;
91  Tree& operator=(Tree const&) = delete;
92  Tree(Tree&&) noexcept = default;
93  Tree& operator=(Tree&&) noexcept = default;
94 
95  inline Tree<ThresholdType, LeafOutputType> Clone() const;
96 
97  private:
98  ContiguousArray<TreeNodeType> node_type_;
99  ContiguousArray<std::int32_t> cleft_;
100  ContiguousArray<std::int32_t> cright_;
101  ContiguousArray<std::int32_t> split_index_;
102  ContiguousArray<bool> default_left_;
103  ContiguousArray<LeafOutputType> leaf_value_;
104  ContiguousArray<ThresholdType> threshold_;
106  ContiguousArray<bool> category_list_right_child_;
107 
108  // Leaf vector
109  ContiguousArray<LeafOutputType> leaf_vector_;
110  ContiguousArray<std::uint64_t> leaf_vector_begin_;
111  ContiguousArray<std::uint64_t> leaf_vector_end_;
112 
113  // Category list
114  ContiguousArray<std::uint32_t> category_list_;
115  ContiguousArray<std::uint64_t> category_list_begin_;
116  ContiguousArray<std::uint64_t> category_list_end_;
117 
118  // Node statistics
119  ContiguousArray<std::uint64_t> data_count_;
120  ContiguousArray<double> sum_hess_;
121  ContiguousArray<double> gain_;
122  ContiguousArray<bool> data_count_present_;
123  ContiguousArray<bool> sum_hess_present_;
124  ContiguousArray<bool> gain_present_;
125 
126  bool has_categorical_split_{false};
127 
128  /* Note: the following member fields shall be re-computed at serialization time */
129 
130  // Number of optional fields in the extension slots
131  std::int32_t num_opt_field_per_tree_{0};
132  std::int32_t num_opt_field_per_node_{0};
133 
134  template <typename WriterType, typename X, typename Y>
135  friend void DumpTreeAsJSON(WriterType& writer, Tree<X, Y> const& tree);
136 
137  template <typename MixIn>
139  template <typename MixIn>
141  template <typename X, typename Y>
143  ModelPreset<X, Y>&, std::uint64_t, std::string const&);
144  template <typename X, typename Y>
146  ModelPreset<X, Y>&, std::uint64_t, std::string const&, PyBufferFrame);
147 
148  public:
150  std::int32_t num_nodes{0};
152  inline void Init();
154  inline int AllocNode();
155 
161  inline int LeftChild(int nid) const {
162  return cleft_[nid];
163  }
168  inline int RightChild(int nid) const {
169  return cright_[nid];
170  }
175  inline int DefaultChild(int nid) const {
176  return default_left_[nid] ? cleft_[nid] : cright_[nid];
177  }
182  inline std::int32_t SplitIndex(int nid) const {
183  return split_index_[nid];
184  }
189  inline bool DefaultLeft(int nid) const {
190  return default_left_[nid];
191  }
196  inline bool IsLeaf(int nid) const {
197  return cleft_[nid] == -1;
198  }
203  inline LeafOutputType LeafValue(int nid) const {
204  return leaf_value_[nid];
205  }
210  inline std::vector<LeafOutputType> LeafVector(int nid) const {
211  std::size_t const offset_begin = leaf_vector_begin_[nid];
212  std::size_t const offset_end = leaf_vector_end_[nid];
213  if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
214  // Return empty vector, to indicate the lack of leaf vector
215  return std::vector<LeafOutputType>();
216  }
217  return std::vector<LeafOutputType>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
218  // Use unsafe access here, since we may need to take the address of one past the last
219  // element, to follow with the range semantic of std::vector<>.
220  }
225  inline bool HasLeafVector(int nid) const {
226  return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
227  }
232  inline ThresholdType Threshold(int nid) const {
233  return threshold_[nid];
234  }
239  inline Operator ComparisonOp(int nid) const {
240  return cmp_[nid];
241  }
251  inline std::vector<std::uint32_t> CategoryList(int nid) const {
252  std::size_t const offset_begin = category_list_begin_[nid];
253  std::size_t const offset_end = category_list_end_[nid];
254  if (offset_begin >= category_list_.Size() || offset_end > category_list_.Size()) {
255  // Return empty vector, to indicate the lack of any category list
256  // The node might be a numerical split
257  return {};
258  }
259  return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
260  // Use unsafe access here, since we may need to take the address of one past the last
261  // element, to follow with the range semantic of std::vector<>.
262  }
267  inline TreeNodeType NodeType(int nid) const {
268  return node_type_[nid];
269  }
274  inline bool HasDataCount(int nid) const {
275  return !data_count_present_.Empty() && data_count_present_[nid];
276  }
281  inline std::uint64_t DataCount(int nid) const {
282  return data_count_[nid];
283  }
284 
289  inline bool HasSumHess(int nid) const {
290  return !sum_hess_present_.Empty() && sum_hess_present_[nid];
291  }
296  inline double SumHess(int nid) const {
297  return sum_hess_[nid];
298  }
303  inline bool HasGain(int nid) const {
304  return !gain_present_.Empty() && gain_present_[nid];
305  }
310  inline double Gain(int nid) const {
311  return gain_[nid];
312  }
318  inline bool CategoryListRightChild(int nid) const {
319  return category_list_right_child_[nid];
320  }
321 
325  inline bool HasCategoricalSplit() const {
326  return has_categorical_split_;
327  }
328 
336  inline void SetChildren(int nid, int left_child, int right_child) {
337  cleft_[nid] = left_child;
338  cright_[nid] = right_child;
339  }
349  inline void SetNumericalTest(
350  int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp);
363  inline void SetCategoricalTest(int nid, std::int32_t split_index, bool default_left,
364  std::vector<std::uint32_t> const& category_list, bool category_list_right_child);
370  inline void SetLeaf(int nid, LeafOutputType value);
376  inline void SetLeafVector(int nid, std::vector<LeafOutputType> const& leaf_vector);
382  inline void SetSumHess(int nid, double sum_hess);
388  inline void SetDataCount(int nid, std::uint64_t data_count);
394  inline void SetGain(int nid, double gain);
395 };
396 
398 template <typename ThresholdT, typename LeafOutputT>
399 class ModelPreset {
400  public:
402  std::vector<Tree<ThresholdT, LeafOutputT>> trees;
403 
404  using threshold_type = ThresholdT;
405  using leaf_output_type = LeafOutputT;
406 
408  ModelPreset() = default;
409  ~ModelPreset() = default;
410  ModelPreset(ModelPreset const&) = delete;
411  ModelPreset& operator=(ModelPreset const&) = delete;
412  ModelPreset(ModelPreset&&) noexcept = default;
413  ModelPreset& operator=(ModelPreset&&) noexcept = default;
414 
415  inline TypeInfo GetThresholdType() const {
416  return TypeInfoFromType<ThresholdT>();
417  }
418  inline TypeInfo GetLeafOutputType() const {
419  return TypeInfoFromType<LeafOutputT>();
420  }
421  inline std::size_t GetNumTree() const {
422  return trees.size();
423  }
424  void SetTreeLimit(std::size_t limit) {
425  return trees.resize(limit);
426  }
427 };
428 
429 using ModelPresetVariant = std::variant<ModelPreset<float, float>, ModelPreset<double, double>>;
430 
431 template <int variant_index>
432 ModelPresetVariant SetModelPresetVariant(int target_variant_index) {
433  ModelPresetVariant result;
434  if constexpr (variant_index != std::variant_size_v<ModelPresetVariant>) {
435  if (variant_index == target_variant_index) {
436  using ModelPresetT = std::variant_alternative_t<variant_index, ModelPresetVariant>;
437  result = ModelPresetT();
438  } else {
439  result = SetModelPresetVariant<variant_index + 1>(target_variant_index);
440  }
441  }
442  return result;
443 }
444 
446 class Model {
447  public:
450  : major_ver_(TREELITE_VER_MAJOR),
451  minor_ver_(TREELITE_VER_MINOR),
452  patch_ver_(TREELITE_VER_PATCH) {}
453  virtual ~Model() = default;
454  Model(Model const&) = delete;
455  Model& operator=(Model const&) = delete;
456  Model(Model&&) = default;
457  Model& operator=(Model&&) = default;
458 
460 
461  template <typename ThresholdType, typename LeafOutputType>
462  inline static std::unique_ptr<Model> Create();
463  inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
464  inline TypeInfo GetThresholdType() const {
465  return std::visit([](auto&& inner) { return inner.GetThresholdType(); }, variant_);
466  }
467  inline TypeInfo GetLeafOutputType() const {
468  return std::visit([](auto&& inner) { return inner.GetLeafOutputType(); }, variant_);
469  }
470 
471  inline std::size_t GetNumTree() const {
472  return std::visit([](auto&& inner) { return inner.GetNumTree(); }, variant_);
473  }
474  inline void SetTreeLimit(std::size_t limit) {
475  std::visit([=](auto&& inner) { return inner.SetTreeLimit(limit); }, variant_);
476  }
477  void DumpAsJSON(std::ostream& fo, bool pretty_print) const;
478 
479  inline std::string DumpAsJSON(bool pretty_print) const {
480  std::ostringstream oss;
481  DumpAsJSON(oss, pretty_print);
482  return oss.str();
483  }
484 
485  /* Compatibility Matrix:
486  +------------------+----------+----------+----------------+-----------+
487  | | To: =3.9 | To: =4.0 | To: >=4.1,<5.0 | To: >=5.0 |
488  +------------------+----------+----------+----------------+-----------+
489  | From: =3.9 | Yes | Yes | Yes | No |
490  | From: =4.0 | No | Yes | Yes | Yes |
491  | From: >=4.1,<5.0 | No | Yes | Yes | Yes |
492  | From: >=5.0 | No | No | No | Yes |
493  +------------------+----------+----------+----------------+-----------+ */
494 
495  /* In-memory serialization, zero-copy */
496  TREELITE_DLL_EXPORT std::vector<PyBufferFrame> SerializeToPyBuffer();
497  TREELITE_DLL_EXPORT static std::unique_ptr<Model> DeserializeFromPyBuffer(
498  std::vector<PyBufferFrame> const& frames);
499 
500  /* Serialization to a file stream */
501  void SerializeToStream(std::ostream& os);
502  static std::unique_ptr<Model> DeserializeFromStream(std::istream& is);
504  inline Version GetVersion() const {
505  return {major_ver_, minor_ver_, patch_ver_};
506  }
507 
509  PyBufferFrame GetHeaderField(std::string const& name);
511  PyBufferFrame GetTreeField(std::uint64_t tree_id, std::string const& name);
513  void SetHeaderField(std::string const& name, PyBufferFrame frame);
515  void SetTreeField(std::uint64_t tree_id, std::string const& name, PyBufferFrame frame);
516 
521  std::int32_t num_feature{0};
525  bool average_tree_output{false};
526 
527  /* Task parameters */
528  std::int32_t num_target;
531  /* Per-tree metadata */
534  /* Other model parameters */
535  std::string postprocessor;
536  float sigmoid_alpha{1.0f};
537  float ratio_c{1.0f};
539  std::string attributes;
540 
541  private:
542  /* Note: the following member fields shall be re-computed at serialization time */
543  // Number of trees
544  std::uint64_t num_tree_{0};
545  // Number of optional fields in the extension slot
546  std::int32_t num_opt_field_per_model_{0};
547  // Which Treelite version produced this model
548  std::int32_t major_ver_;
549  std::int32_t minor_ver_;
550  std::int32_t patch_ver_;
551  // Type parameters
552  TypeInfo threshold_type_{TypeInfo::kInvalid};
553  TypeInfo leaf_output_type_{TypeInfo::kInvalid};
554 
555  template <typename MixIn>
557  template <typename MixIn>
559 };
560 
/*!
 * \brief Concatenate multiple model objects into a single model object by copying all
 *        member trees.
 * \param objs Pointers to the model objects to concatenate; taken as pointers-to-const.
 * \return Newly created model object holding the result
 */
567 std::unique_ptr<Model> ConcatenateModelObjects(std::vector<Model const*> const& objs);
568 
569 } // namespace treelite
570 
571 #include <treelite/detail/tree.h>
572 
573 #endif // TREELITE_TREE_H_
Definition: contiguous_array.h:17
bool Empty() const
Definition: contiguous_array.h:143
std::size_t Size() const
Definition: contiguous_array.h:138
Typed portion of the model class.
Definition: tree.h:399
ThresholdT threshold_type
Definition: tree.h:404
ModelPreset()=default
disable copy; use default move
std::size_t GetNumTree() const
Definition: tree.h:421
void SetTreeLimit(std::size_t limit)
Definition: tree.h:424
TypeInfo GetLeafOutputType() const
Definition: tree.h:418
std::vector< Tree< ThresholdT, LeafOutputT > > trees
member trees
Definition: tree.h:402
ModelPreset(ModelPreset &&) noexcept=default
LeafOutputT leaf_output_type
Definition: tree.h:405
ModelPreset(ModelPreset const &)=delete
ModelPreset & operator=(ModelPreset const &)=delete
TypeInfo GetThresholdType() const
Definition: tree.h:415
Model class for tree ensemble model.
Definition: tree.h:446
static std::unique_ptr< Model > Create()
Definition: tree.h:212
Model(Model &&)=default
float ratio_c
Definition: tree.h:537
void DumpAsJSON(std::ostream &fo, bool pretty_print) const
std::int32_t num_target
Definition: tree.h:528
std::int32_t num_feature
Number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:521
void SetTreeLimit(std::size_t limit)
Definition: tree.h:474
virtual ~Model()=default
TREELITE_DLL_EXPORT std::vector< PyBufferFrame > SerializeToPyBuffer()
TaskType task_type
Task type.
Definition: tree.h:523
ContiguousArray< std::int32_t > class_id
Definition: tree.h:533
Model & operator=(Model &&)=default
std::string postprocessor
Definition: tree.h:535
void SerializeToStream(std::ostream &os)
std::string attributes
Definition: tree.h:539
void SetHeaderField(std::string const &name, PyBufferFrame frame)
Set a field in the header.
TypeInfo GetThresholdType() const
Definition: tree.h:464
Model(Model const &)=delete
ContiguousArray< std::int32_t > num_class
Definition: tree.h:529
PyBufferFrame GetTreeField(std::uint64_t tree_id, std::string const &name)
Get a field in a tree.
std::string DumpAsJSON(bool pretty_print) const
Definition: tree.h:479
Model()
disable copy; use default move
Definition: tree.h:449
bool average_tree_output
Whether to average tree outputs.
Definition: tree.h:525
Model & operator=(Model const &)=delete
void SetTreeField(std::uint64_t tree_id, std::string const &name, PyBufferFrame frame)
Set a field in a tree.
std::size_t GetNumTree() const
Definition: tree.h:471
static TREELITE_DLL_EXPORT std::unique_ptr< Model > DeserializeFromPyBuffer(std::vector< PyBufferFrame > const &frames)
ContiguousArray< std::int32_t > target_id
Definition: tree.h:532
ModelPresetVariant variant_
Definition: tree.h:459
PyBufferFrame GetHeaderField(std::string const &name)
Get a field in the header.
TypeInfo GetLeafOutputType() const
Definition: tree.h:467
Version GetVersion() const
Return the Treelite version that produced this Model object.
Definition: tree.h:504
ContiguousArray< double > base_scores
Definition: tree.h:538
ContiguousArray< std::int32_t > leaf_vector_shape
Definition: tree.h:530
float sigmoid_alpha
Definition: tree.h:536
static std::unique_ptr< Model > DeserializeFromStream(std::istream &is)
In-memory representation of a decision tree.
Definition: tree.h:79
int AllocNode()
Allocate a new node and return the node's ID.
Definition: tree.h:70
std::int32_t SplitIndex(int nid) const
Feature index of the node's split condition.
Definition: tree.h:182
Tree(Tree &&) noexcept=default
ThresholdType Threshold(int nid) const
Get threshold of the node.
Definition: tree.h:232
void SetLeafVector(int nid, std::vector< LeafOutputType > const &leaf_vector)
Set the leaf vector of the node; useful for multi-class random forest classifier.
Definition: tree.h:166
bool HasSumHess(int nid) const
Test whether this node has hessian sum.
Definition: tree.h:289
void SetNumericalTest(int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp)
Create a numerical test.
Definition: tree.h:127
bool HasDataCount(int nid) const
Test whether this node has data count.
Definition: tree.h:274
LeafOutputType LeafValue(int nid) const
Get leaf value of the leaf node.
Definition: tree.h:203
double Gain(int nid) const
Get gain value.
Definition: tree.h:310
~Tree()=default
Tree & operator=(Tree const &)=delete
void SetSumHess(int nid, double sum_hess)
Set the hessian sum of the node.
Definition: tree.h:182
Tree< ThresholdType, LeafOutputType > Clone() const
Definition: tree.h:33
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition: tree.h:325
void SetChildren(int nid, int left_child, int right_child)
Identify two child nodes of the node.
Definition: tree.h:336
Tree(Tree const &)=delete
bool CategoryListRightChild(int nid) const
Test whether the list given by CategoryList(nid) is associated with the right child node or the left ...
Definition: tree.h:318
void Init()
Initialize the tree with a single root node.
Definition: tree.h:104
std::vector< LeafOutputType > LeafVector(int nid) const
Get leaf vector of the leaf node; useful for multi-class random forest classifier.
Definition: tree.h:210
Operator ComparisonOp(int nid) const
Get comparison operator.
Definition: tree.h:239
Tree()=default
int LeftChild(int nid) const
Index of the node's left child.
Definition: tree.h:161
int RightChild(int nid) const
Index of the node's right child.
Definition: tree.h:168
friend void DumpTreeAsJSON(WriterType &writer, Tree< X, Y > const &tree)
void SetDataCount(int nid, std::uint64_t data_count)
Set the data count of the node.
Definition: tree.h:192
void SetCategoricalTest(int nid, std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child)
Create a categorical test.
Definition: tree.h:138
bool HasLeafVector(int nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition: tree.h:225
void SetGain(int nid, double gain)
Set the gain value of the node.
Definition: tree.h:202
TreeNodeType NodeType(int nid) const
Get the type of a node.
Definition: tree.h:267
std::vector< std::uint32_t > CategoryList(int nid) const
Get list of all categories belonging to the left/right child node. See the category_list_right_child_...
Definition: tree.h:251
bool DefaultLeft(int nid) const
Whether to use the left child node, when the feature in the split condition is missing.
Definition: tree.h:189
double SumHess(int nid) const
Get hessian sum.
Definition: tree.h:296
bool HasGain(int nid) const
Test whether this node has gain value.
Definition: tree.h:303
std::uint64_t DataCount(int nid) const
Get data count.
Definition: tree.h:281
bool IsLeaf(int nid) const
Whether the node is leaf node.
Definition: tree.h:196
int DefaultChild(int nid) const
Index of the node's "default" child, used when feature is missing.
Definition: tree.h:175
std::int32_t num_nodes
Number of nodes.
Definition: tree.h:150
void SetLeaf(int nid, LeafOutputType value)
Set the leaf value of the node.
Definition: tree.h:158
A simple array container, with owned or non-owned (externally allocated) buffer.
Implementation for treelite/tree.h.
logging facility for Treelite
void SetTreeFieldImpl(ModelPreset< ThresholdType, LeafOutputType > &, std::uint64_t, std::string const &, PyBufferFrame)
PyBufferFrame GetTreeFieldImpl(ModelPreset< ThresholdType, LeafOutputType > &, std::uint64_t, std::string const &)
Definition: serializer.h:28
Definition: contiguous_array.h:14
TreelitePyBufferFrame PyBufferFrame
Definition: pybuffer_frame.h:18
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
TaskType
Enum type representing the task type.
Definition: task_type.h:19
ModelPresetVariant SetModelPresetVariant(int target_variant_index)
Definition: tree.h:432
std::unique_ptr< Model > ConcatenateModelObjects(std::vector< Model const * > const &objs)
Concatenate multiple model objects into a single model object by copying all member trees into the de...
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17
TreeNodeType
Tree node type.
Definition: tree_node_type.h:17
std::variant< ModelPreset< float, float >, ModelPreset< double, double > > ModelPresetVariant
Definition: tree.h:429
Define enum type Operator.
Data structure to enable zero-copy exchange in Python.
Definition: tree.h:71
std::int32_t minor_ver
Definition: tree.h:73
std::int32_t major_ver
Definition: tree.h:72
std::int32_t patch_ver
Definition: tree.h:74
Define enum type TaskType.
#define TREELITE_DLL_EXPORT
Definition: tree.h:38
Define enum type NodeType.
Defines enum type TypeInfo.