treelite
tree.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_DETAIL_TREE_H_
8 #define TREELITE_DETAIL_TREE_H_
9 
10 #include <treelite/error.h>
11 #include <treelite/logging.h>
12 #include <treelite/version.h>
13 
14 #include <algorithm>
15 #include <cstddef>
16 #include <cstdint>
17 #include <iomanip>
18 #include <iostream>
19 #include <limits>
20 #include <map>
21 #include <memory>
22 #include <sstream>
23 #include <stdexcept>
24 #include <string>
25 #include <typeinfo>
26 #include <unordered_map>
27 #include <utility>
28 #include <vector>
29 
30 namespace treelite {
31 
// Tree::Clone() const -- builds and returns a copy of this tree.
// NOTE(review): the signature and the declaration of the local `tree` (original lines 33-34)
// are missing from this rendering; the tooltip places the definition at tree.h:33.
// Every per-node / payload array is copied through ContiguousArray::Clone(); presumably that
// yields an owning copy that shares no buffer with *this -- confirm in contiguous_array.h.
// Scalar bookkeeping (num_nodes, has_categorical_split_, num_opt_field_*) is copied by value.
32 template <typename ThresholdType, typename LeafOutputType>
35 
36  tree.num_nodes = num_nodes;
37 
// Fixed-width per-node arrays.
38  tree.node_type_ = node_type_.Clone();
39  tree.cleft_ = cleft_.Clone();
40  tree.cright_ = cright_.Clone();
41  tree.split_index_ = split_index_.Clone();
42  tree.default_left_ = default_left_.Clone();
43  tree.leaf_value_ = leaf_value_.Clone();
44  tree.threshold_ = threshold_.Clone();
45  tree.cmp_ = cmp_.Clone();
46  tree.category_list_right_child_ = category_list_right_child_.Clone();
47 
// Variable-length payloads plus the per-node [begin, end) offsets into them.
48  tree.leaf_vector_ = leaf_vector_.Clone();
49  tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
50  tree.leaf_vector_end_ = leaf_vector_end_.Clone();
51  tree.category_list_ = category_list_.Clone();
52  tree.category_list_begin_ = category_list_begin_.Clone();
53  tree.category_list_end_ = category_list_end_.Clone();
54 
// Optional node statistics and their presence flags (may be empty).
55  tree.data_count_ = data_count_.Clone();
56  tree.sum_hess_ = sum_hess_.Clone();
57  tree.gain_ = gain_.Clone();
58  tree.data_count_present_ = data_count_present_.Clone();
59  tree.sum_hess_present_ = sum_hess_present_.Clone();
60  tree.gain_present_ = gain_present_.Clone();
61 
62  tree.has_categorical_split_ = has_categorical_split_;
63  tree.num_opt_field_per_tree_ = num_opt_field_per_tree_;
64  tree.num_opt_field_per_node_ = num_opt_field_per_node_;
65 
66  return tree;
67 }
68 
// Tree::AllocNode() -- appends one node to every per-node array and returns the new node's ID.
// NOTE(review): the signature line (original line 70) is missing from this rendering.
// The new node starts as a leaf with leaf value 0, no children (-1), split_index -1,
// threshold 0, Operator::kNone, default_left false, and empty leaf-vector / category-list
// ranges (begin == end == current pool size).
// Optional stat arrays grow only if they are already in use (non-empty); the new slot is
// filled with 0 and flagged as not present.
// Returns the pre-increment value of num_nodes, i.e. the index of the node just added.
69 template <typename ThresholdType, typename LeafOutputType>
71  node_type_.PushBack(TreeNodeType::kLeafNode);
72  cleft_.PushBack(-1);
73  cright_.PushBack(-1);
74  split_index_.PushBack(-1);
75  default_left_.PushBack(false);
76  leaf_value_.PushBack(static_cast<LeafOutputType>(0));
77  threshold_.PushBack(static_cast<ThresholdType>(0));
78  cmp_.PushBack(Operator::kNone);
79  category_list_right_child_.PushBack(false);
80 
// Empty [begin, end) ranges positioned at the current end of the shared pools.
81  leaf_vector_begin_.PushBack(leaf_vector_.Size());
82  leaf_vector_end_.PushBack(leaf_vector_.Size());
83  category_list_begin_.PushBack(category_list_.Size());
84  category_list_end_.PushBack(category_list_.Size());
85 
86  // Invariant: node stat array must either be empty or have exact length of [num_nodes]
87  if (!data_count_present_.Empty()) {
88  data_count_.PushBack(0);
89  data_count_present_.PushBack(false);
90  }
91  if (!sum_hess_present_.Empty()) {
92  sum_hess_.PushBack(0);
93  sum_hess_present_.PushBack(false);
94  }
95  if (!gain_present_.Empty()) {
96  gain_.PushBack(0);
97  gain_present_.PushBack(false);
98  }
99 
100  return num_nodes++;
101 }
102 
// Tree::Init() -- resets the tree to an empty state: per-node arrays and payload pools are
// cleared, num_nodes is set to 0 and has_categorical_split_ to false.
// NOTE(review): the signature line (original line 104) is missing from this rendering.
// NOTE(review): no root node is allocated here; callers presumably add nodes with
// AllocNode() afterwards -- confirm against call sites.
// NOTE(review): the optional node-stat arrays (data_count_, sum_hess_, gain_ and their
// *_present_ flags) are NOT cleared. If they were non-empty before Init(), the invariant
// stated in AllocNode() ("either be empty or have exact length of [num_nodes]") no longer
// holds afterwards, and later AllocNode()/Set*() calls would index stale entries.
103 template <typename ThresholdType, typename LeafOutputType>
105  node_type_.Clear();
106  cleft_.Clear();
107  cright_.Clear();
108  split_index_.Clear();
109  default_left_.Clear();
110  leaf_value_.Clear();
111  threshold_.Clear();
112  cmp_.Clear();
113  category_list_right_child_.Clear();
114 
115  num_nodes = 0;
116  has_categorical_split_ = false;
117 
118  leaf_vector_.Clear();
119  leaf_vector_begin_.Clear();
120  leaf_vector_end_.Clear();
121  category_list_.Clear();
122  category_list_begin_.Clear();
123  category_list_end_.Clear();
124 }
125 
// Tree::SetNumericalTest(nid, split_index, threshold, default_left, cmp) -- turns node `nid`
// into a numerical test node on feature `split_index`, comparing against `threshold` with
// operator `cmp`; `default_left` is stored for the node (presumably the branch taken for
// missing values -- confirm).
// NOTE(review): the first signature line (original line 127) is missing from this rendering.
// Child links (cleft_/cright_) are not modified here. All writes use .at(); an out-of-range
// nid presumably throws -- confirm ContiguousArray::at semantics.
126 template <typename ThresholdType, typename LeafOutputType>
128  int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp) {
129  split_index_.at(nid) = split_index;
130  threshold_.at(nid) = threshold;
131  default_left_.at(nid) = default_left;
132  cmp_.at(nid) = cmp;
133  node_type_.at(nid) = TreeNodeType::kNumericalTestNode;
134  category_list_right_child_.at(nid) = false;
135 }
136 
// Tree::SetCategoricalTest(nid, split_index, default_left, category_list,
//                          category_list_right_child) -- turns node `nid` into a categorical
// test node on feature `split_index`.
// NOTE(review): the first signature line (original line 138) is missing from this rendering.
// `category_list` is appended to the shared category_list_ pool and its [begin, end) range
// is recorded for `nid`. Fails via TREELITE_CHECK if the node already owns a non-empty
// category list (a previously set *empty* list does not trip the check).
// Also marks the whole tree as containing a categorical split.
137 template <typename ThresholdType, typename LeafOutputType>
139  std::int32_t split_index, bool default_left, std::vector<std::uint32_t> const& category_list,
140  bool category_list_right_child) {
141  TREELITE_CHECK(CategoryList(nid).empty()) << "Cannot set categorical test twice for same node";
142 
// Record the node's slice of the shared pool before appending to it.
143  std::size_t const begin = category_list_.Size();
144  std::size_t const end = begin + category_list.size();
145  category_list_.Extend(category_list);
146  category_list_begin_.at(nid) = begin;
147  category_list_end_.at(nid) = end;
148 
149  split_index_.at(nid) = split_index;
150  default_left_.at(nid) = default_left;
151  node_type_.at(nid) = TreeNodeType::kCategoricalTestNode;
152  category_list_right_child_.at(nid) = category_list_right_child;
153 
154  has_categorical_split_ = true;
155 }
156 
157 template <typename ThresholdType, typename LeafOutputType>
158 inline void Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
159  leaf_value_.at(nid) = value;
160  cleft_.at(nid) = -1;
161  cright_.at(nid) = -1;
162  node_type_.at(nid) = TreeNodeType::kLeafNode;
163 }
164 
// Tree::SetLeafVector(nid, node_leaf_vector) -- turns node `nid` into a leaf whose output is
// a vector rather than a scalar.
// NOTE(review): the first signature line (original line 166) is missing from this rendering.
// The vector is appended to the shared leaf_vector_ pool and its [begin, end) range is
// recorded for `nid`. Fails via TREELITE_CHECK if the node already has a leaf vector.
// Child links and split_index are reset to -1 and the node type becomes kLeafNode.
165 template <typename ThresholdType, typename LeafOutputType>
167  int nid, std::vector<LeafOutputType> const& node_leaf_vector) {
168  TREELITE_CHECK(!HasLeafVector(nid)) << "Cannot set leaf vector twice for same node";
// Record the node's slice of the shared pool before appending to it.
169  std::size_t begin = leaf_vector_.Size();
170  std::size_t end = begin + node_leaf_vector.size();
171  leaf_vector_.Extend(node_leaf_vector);
172  leaf_vector_begin_.at(nid) = begin;
173  leaf_vector_end_.at(nid) = end;
174 
175  split_index_.at(nid) = -1;
176  cleft_.at(nid) = -1;
177  cright_.at(nid) = -1;
178  node_type_.at(nid) = TreeNodeType::kLeafNode;
179 }
180 
181 template <typename ThresholdType, typename LeafOutputType>
182 inline void Tree<ThresholdType, LeafOutputType>::SetSumHess(int nid, double sum_hess) {
183  if (sum_hess_present_.Empty()) {
184  sum_hess_present_.Resize(num_nodes, false);
185  sum_hess_.Resize(num_nodes);
186  }
187  sum_hess_.at(nid) = sum_hess;
188  sum_hess_present_.at(nid) = true;
189 }
190 
191 template <typename ThresholdType, typename LeafOutputType>
192 inline void Tree<ThresholdType, LeafOutputType>::SetDataCount(int nid, std::uint64_t data_count) {
193  if (data_count_present_.Empty()) {
194  data_count_present_.Resize(num_nodes, false);
195  data_count_.Resize(num_nodes);
196  }
197  data_count_.at(nid) = data_count;
198  data_count_present_.at(nid) = true;
199 }
200 
201 template <typename ThresholdType, typename LeafOutputType>
202 inline void Tree<ThresholdType, LeafOutputType>::SetGain(int nid, double gain) {
203  if (gain_present_.Empty()) {
204  gain_present_.Resize(num_nodes, false);
205  gain_.Resize(num_nodes);
206  }
207  gain_.at(nid) = gain;
208  gain_present_.at(nid) = true;
209 }
210 
211 template <typename ThresholdType, typename LeafOutputType>
212 inline std::unique_ptr<Model> Model::Create() {
213  std::unique_ptr<Model> model = std::make_unique<Model>();
214  model->variant_ = ModelPreset<ThresholdType, LeafOutputType>();
215  return model;
216 }
217 
218 inline std::unique_ptr<Model> Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
219  std::unique_ptr<Model> model = std::make_unique<Model>();
220  TREELITE_CHECK(threshold_type == TypeInfo::kFloat32 || threshold_type == TypeInfo::kFloat64)
221  << "threshold_type must be either float32 or float64";
222  TREELITE_CHECK(leaf_output_type == threshold_type)
223  << "threshold_type must be identical to leaf_output_type";
224  int const target_variant_index = threshold_type == TypeInfo::kFloat64;
225  model->variant_ = SetModelPresetVariant<0>(target_variant_index);
226  return model;
227 }
228 
229 } // namespace treelite
230 #endif // TREELITE_DETAIL_TREE_H_
ContiguousArray Clone() const
Definition: contiguous_array.h:79
Typed portion of the model class.
Definition: tree.h:399
static std::unique_ptr< Model > Create()
Definition: tree.h:212
in-memory representation of a decision tree
Definition: tree.h:79
int AllocNode()
Allocate a new node and return the node's ID.
Definition: tree.h:70
void SetLeafVector(int nid, std::vector< LeafOutputType > const &leaf_vector)
Set the leaf vector of the node; useful for multi-class random forest classifier.
Definition: tree.h:166
void SetNumericalTest(int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp)
Create a numerical test.
Definition: tree.h:127
void SetSumHess(int nid, double sum_hess)
Set the hessian sum of the node.
Definition: tree.h:182
Tree< ThresholdType, LeafOutputType > Clone() const
Definition: tree.h:33
void Init()
Reset the tree to an empty state with zero nodes (no root node is allocated; use AllocNode() to add nodes).
Definition: tree.h:104
void SetDataCount(int nid, std::uint64_t data_count)
Set the data count of the node.
Definition: tree.h:192
void SetCategoricalTest(int nid, std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child)
Create a categorical test.
Definition: tree.h:138
void SetGain(int nid, double gain)
Set the gain value of the node.
Definition: tree.h:202
std::int32_t num_nodes
Number of nodes.
Definition: tree.h:150
void SetLeaf(int nid, LeafOutputType value)
Set the leaf value of the node.
Definition: tree.h:158
Exception class used throughout the Treelite codebase.
logging facility for Treelite
#define TREELITE_CHECK(x)
Definition: logging.h:70
Definition: contiguous_array.h:14
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17