7 #ifndef TREELITE_DETAIL_TREE_H_
8 #define TREELITE_DETAIL_TREE_H_
12 #include <treelite/version.h>
26 #include <unordered_map>
32 template <
typename ThresholdType,
typename LeafOutputType>
38 tree.node_type_ = node_type_.
Clone();
39 tree.cleft_ = cleft_.
Clone();
40 tree.cright_ = cright_.
Clone();
41 tree.split_index_ = split_index_.
Clone();
42 tree.default_left_ = default_left_.
Clone();
43 tree.leaf_value_ = leaf_value_.
Clone();
44 tree.threshold_ = threshold_.
Clone();
45 tree.cmp_ = cmp_.
Clone();
46 tree.category_list_right_child_ = category_list_right_child_.
Clone();
48 tree.leaf_vector_ = leaf_vector_.
Clone();
49 tree.leaf_vector_begin_ = leaf_vector_begin_.
Clone();
50 tree.leaf_vector_end_ = leaf_vector_end_.
Clone();
51 tree.category_list_ = category_list_.
Clone();
52 tree.category_list_begin_ = category_list_begin_.
Clone();
53 tree.category_list_end_ = category_list_end_.
Clone();
55 tree.data_count_ = data_count_.
Clone();
56 tree.sum_hess_ = sum_hess_.
Clone();
57 tree.gain_ = gain_.
Clone();
58 tree.data_count_present_ = data_count_present_.
Clone();
59 tree.sum_hess_present_ = sum_hess_present_.
Clone();
60 tree.gain_present_ = gain_present_.
Clone();
62 tree.has_categorical_split_ = has_categorical_split_;
63 tree.num_opt_field_per_tree_ = num_opt_field_per_tree_;
64 tree.num_opt_field_per_node_ = num_opt_field_per_node_;
69 template <
typename ThresholdType,
typename LeafOutputType>
74 split_index_.PushBack(-1);
75 default_left_.PushBack(
false);
76 leaf_value_.PushBack(
static_cast<LeafOutputType
>(0));
77 threshold_.PushBack(
static_cast<ThresholdType
>(0));
79 category_list_right_child_.PushBack(
false);
81 leaf_vector_begin_.PushBack(leaf_vector_.Size());
82 leaf_vector_end_.PushBack(leaf_vector_.Size());
83 category_list_begin_.PushBack(category_list_.Size());
84 category_list_end_.PushBack(category_list_.Size());
87 if (!data_count_present_.Empty()) {
88 data_count_.PushBack(0);
89 data_count_present_.PushBack(
false);
91 if (!sum_hess_present_.Empty()) {
92 sum_hess_.PushBack(0);
93 sum_hess_present_.PushBack(
false);
95 if (!gain_present_.Empty()) {
97 gain_present_.PushBack(
false);
103 template <
typename ThresholdType,
typename LeafOutputType>
108 split_index_.Clear();
109 default_left_.Clear();
113 category_list_right_child_.Clear();
116 has_categorical_split_ =
false;
118 leaf_vector_.Clear();
119 leaf_vector_begin_.Clear();
120 leaf_vector_end_.Clear();
121 category_list_.Clear();
122 category_list_begin_.Clear();
123 category_list_end_.Clear();
126 template <
typename ThresholdType,
typename LeafOutputType>
128 int nid, std::int32_t split_index, ThresholdType threshold,
bool default_left,
Operator cmp) {
129 split_index_.at(nid) = split_index;
130 threshold_.at(nid) = threshold;
131 default_left_.at(nid) = default_left;
134 category_list_right_child_.at(nid) =
false;
137 template <
typename ThresholdType,
typename LeafOutputType>
139 std::int32_t split_index,
bool default_left, std::vector<std::uint32_t>
const& category_list,
140 bool category_list_right_child) {
141 TREELITE_CHECK(CategoryList(nid).empty()) <<
"Cannot set categorical test twice for same node";
143 std::size_t
const begin = category_list_.Size();
144 std::size_t
const end = begin + category_list.size();
145 category_list_.Extend(category_list);
146 category_list_begin_.at(nid) = begin;
147 category_list_end_.at(nid) = end;
149 split_index_.at(nid) = split_index;
150 default_left_.at(nid) = default_left;
152 category_list_right_child_.at(nid) = category_list_right_child;
154 has_categorical_split_ =
true;
157 template <
typename ThresholdType,
typename LeafOutputType>
159 leaf_value_.at(nid) = value;
161 cright_.at(nid) = -1;
165 template <
typename ThresholdType,
typename LeafOutputType>
167 int nid, std::vector<LeafOutputType>
const& node_leaf_vector) {
168 TREELITE_CHECK(!HasLeafVector(nid)) <<
"Cannot set leaf vector twice for same node";
169 std::size_t begin = leaf_vector_.Size();
170 std::size_t end = begin + node_leaf_vector.size();
171 leaf_vector_.Extend(node_leaf_vector);
172 leaf_vector_begin_.at(nid) = begin;
173 leaf_vector_end_.at(nid) = end;
175 split_index_.at(nid) = -1;
177 cright_.at(nid) = -1;
181 template <
typename ThresholdType,
typename LeafOutputType>
183 if (sum_hess_present_.Empty()) {
184 sum_hess_present_.Resize(num_nodes,
false);
185 sum_hess_.Resize(num_nodes);
187 sum_hess_.at(nid) = sum_hess;
188 sum_hess_present_.at(nid) =
true;
191 template <
typename ThresholdType,
typename LeafOutputType>
193 if (data_count_present_.Empty()) {
194 data_count_present_.Resize(num_nodes,
false);
195 data_count_.Resize(num_nodes);
197 data_count_.at(nid) = data_count;
198 data_count_present_.at(nid) =
true;
201 template <
typename ThresholdType,
typename LeafOutputType>
203 if (gain_present_.Empty()) {
204 gain_present_.Resize(num_nodes,
false);
205 gain_.Resize(num_nodes);
207 gain_.at(nid) = gain;
208 gain_present_.at(nid) =
true;
211 template <
typename ThresholdType,
typename LeafOutputType>
213 std::unique_ptr<Model> model = std::make_unique<Model>();
219 std::unique_ptr<Model> model = std::make_unique<Model>();
221 <<
"threshold_type must be either float32 or float64";
223 <<
"threshold_type must be identical to leaf_output_type";
225 model->variant_ = SetModelPresetVariant<0>(target_variant_index);
ContiguousArray Clone() const
Definition: contiguous_array.h:79
Typed portion of the model class.
Definition: tree.h:399
static std::unique_ptr< Model > Create()
Definition: tree.h:212
In-memory representation of a decision tree.
Definition: tree.h:79
int AllocNode()
Allocate a new node and return the node's ID.
Definition: tree.h:70
void SetLeafVector(int nid, std::vector< LeafOutputType > const &leaf_vector)
Set the leaf vector of the node; useful for multi-class random forest classifiers.
Definition: tree.h:166
void SetNumericalTest(int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp)
Create a numerical test.
Definition: tree.h:127
void SetSumHess(int nid, double sum_hess)
Set the hessian sum of the node.
Definition: tree.h:182
Tree< ThresholdType, LeafOutputType > Clone() const
Definition: tree.h:33
void Init()
Initialize the tree with a single root node.
Definition: tree.h:104
void SetDataCount(int nid, std::uint64_t data_count)
Set the data count of the node.
Definition: tree.h:192
void SetCategoricalTest(int nid, std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child)
Create a categorical test.
Definition: tree.h:138
void SetGain(int nid, double gain)
Set the gain value of the node.
Definition: tree.h:202
std::int32_t num_nodes
Number of nodes.
Definition: tree.h:150
void SetLeaf(int nid, LeafOutputType value)
Set the leaf value of the node.
Definition: tree.h:158
Exception class used throughout the Treelite codebase.
Logging facility for Treelite.
#define TREELITE_CHECK(x)
Definition: logging.h:70
Definition: contiguous_array.h:14
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17