7 #ifndef TREELITE_TREE_H_
8 #define TREELITE_TREE_H_
17 #include <treelite/version.h>
29 #include <type_traits>
35 #if defined(_MSC_VER) || defined(_WIN32)
36 #define TREELITE_DLL_EXPORT __declspec(dllexport)
38 #define TREELITE_DLL_EXPORT
43 template <
typename ThresholdType,
typename LeafOutputType>
50 template <
typename MixIn>
52 template <
typename MixIn>
59 template <
typename ThresholdType,
typename LeafOutputType>
62 template <
typename ThresholdType,
typename LeafOutputType>
78 template <
typename ThresholdType,
typename LeafOutputType>
81 static_assert(std::is_same_v<ThresholdType, float> || std::is_same_v<ThresholdType, double>,
82 "ThresholdType must be either float32 or float64");
83 static_assert(std::is_same_v<LeafOutputType, float> || std::is_same_v<LeafOutputType, double>,
84 "LeafOutputType must be one of uint32_t, float32 or float64");
85 static_assert(std::is_same_v<ThresholdType, LeafOutputType>,
86 "Unsupported combination of ThresholdType and LeafOutputType");
95 inline
Tree<ThresholdType, LeafOutputType>
Clone() const;
126 bool has_categorical_split_{
false};
131 std::int32_t num_opt_field_per_tree_{0};
132 std::int32_t num_opt_field_per_node_{0};
134 template <
typename WriterType,
typename X,
typename Y>
137 template <
typename MixIn>
139 template <
typename MixIn>
141 template <
typename X,
typename Y>
144 template <
typename X,
typename Y>
176 return default_left_[nid] ? cleft_[nid] : cright_[nid];
183 return split_index_[nid];
190 return default_left_[nid];
197 return cleft_[nid] == -1;
204 return leaf_value_[nid];
210 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
211 std::size_t
const offset_begin = leaf_vector_begin_[nid];
212 std::size_t
const offset_end = leaf_vector_end_[nid];
213 if (offset_begin >= leaf_vector_.
Size() || offset_end > leaf_vector_.
Size()) {
215 return std::vector<LeafOutputType>();
217 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin], &leaf_vector_[offset_end]);
226 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
233 return threshold_[nid];
252 std::size_t
const offset_begin = category_list_begin_[nid];
253 std::size_t
const offset_end = category_list_end_[nid];
254 if (offset_begin >= category_list_.
Size() || offset_end > category_list_.
Size()) {
259 return std::vector<std::uint32_t>(&category_list_[offset_begin], &category_list_[offset_end]);
268 return node_type_[nid];
275 return !data_count_present_.
Empty() && data_count_present_[nid];
282 return data_count_[nid];
290 return !sum_hess_present_.
Empty() && sum_hess_present_[nid];
297 return sum_hess_[nid];
304 return !gain_present_.
Empty() && gain_present_[nid];
310 inline double Gain(
int nid)
const {
319 return category_list_right_child_[nid];
326 return has_categorical_split_;
336 inline void SetChildren(
int nid,
int left_child,
int right_child) {
337 cleft_[nid] = left_child;
338 cright_[nid] = right_child;
350 int nid, std::int32_t split_index, ThresholdType threshold,
bool default_left,
Operator cmp);
364 std::vector<std::uint32_t>
const& category_list,
bool category_list_right_child);
370 inline void SetLeaf(
int nid, LeafOutputType value);
376 inline void SetLeafVector(
int nid, std::vector<LeafOutputType>
const& leaf_vector);
382 inline void SetSumHess(
int nid,
double sum_hess);
388 inline void SetDataCount(
int nid, std::uint64_t data_count);
394 inline void SetGain(
int nid,
double gain);
398 template <
typename ThresholdT,
typename LeafOutputT>
402 std::vector<Tree<ThresholdT, LeafOutputT>>
trees;
416 return TypeInfoFromType<ThresholdT>();
419 return TypeInfoFromType<LeafOutputT>();
425 return trees.resize(limit);
431 template <
int variant_index>
434 if constexpr (variant_index != std::variant_size_v<ModelPresetVariant>) {
435 if (variant_index == target_variant_index) {
436 using ModelPresetT = std::variant_alternative_t<variant_index, ModelPresetVariant>;
437 result = ModelPresetT();
439 result = SetModelPresetVariant<variant_index + 1>(target_variant_index);
450 : major_ver_(TREELITE_VER_MAJOR),
451 minor_ver_(TREELITE_VER_MINOR),
452 patch_ver_(TREELITE_VER_PATCH) {}
461 template <
typename ThresholdType,
typename LeafOutputType>
462 inline static std::unique_ptr<Model>
Create();
465 return std::visit([](
auto&& inner) {
return inner.GetThresholdType(); },
variant_);
468 return std::visit([](
auto&& inner) {
return inner.GetLeafOutputType(); },
variant_);
472 return std::visit([](
auto&& inner) {
return inner.GetNumTree(); },
variant_);
475 std::visit([=](
auto&& inner) {
return inner.SetTreeLimit(limit); },
variant_);
480 std::ostringstream oss;
498 std::vector<PyBufferFrame>
const& frames);
505 return {major_ver_, minor_ver_, patch_ver_};
544 std::uint64_t num_tree_{0};
546 std::int32_t num_opt_field_per_model_{0};
548 std::int32_t major_ver_;
549 std::int32_t minor_ver_;
550 std::int32_t patch_ver_;
555 template <
typename MixIn>
557 template <
typename MixIn>
Definition: contiguous_array.h:17
bool Empty() const
Definition: contiguous_array.h:143
std::size_t Size() const
Definition: contiguous_array.h:138
Typed portion of the model class.
Definition: tree.h:399
ThresholdT threshold_type
Definition: tree.h:404
ModelPreset()=default
disable copy; use default move
std::size_t GetNumTree() const
Definition: tree.h:421
void SetTreeLimit(std::size_t limit)
Definition: tree.h:424
TypeInfo GetLeafOutputType() const
Definition: tree.h:418
std::vector< Tree< ThresholdT, LeafOutputT > > trees
member trees
Definition: tree.h:402
ModelPreset(ModelPreset &&) noexcept=default
LeafOutputT leaf_output_type
Definition: tree.h:405
ModelPreset(ModelPreset const &)=delete
ModelPreset & operator=(ModelPreset const &)=delete
TypeInfo GetThresholdType() const
Definition: tree.h:415
Model class for tree ensemble model.
Definition: tree.h:446
static std::unique_ptr< Model > Create()
Definition: tree.h:212
float ratio_c
Definition: tree.h:537
void DumpAsJSON(std::ostream &fo, bool pretty_print) const
std::int32_t num_target
Definition: tree.h:528
std::int32_t num_feature
Number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Definition: tree.h:521
void SetTreeLimit(std::size_t limit)
Definition: tree.h:474
TREELITE_DLL_EXPORT std::vector< PyBufferFrame > SerializeToPyBuffer()
TaskType task_type
Task type.
Definition: tree.h:523
ContiguousArray< std::int32_t > class_id
Definition: tree.h:533
Model & operator=(Model &&)=default
std::string postprocessor
Definition: tree.h:535
void SerializeToStream(std::ostream &os)
std::string attributes
Definition: tree.h:539
void SetHeaderField(std::string const &name, PyBufferFrame frame)
Set a field in the header.
TypeInfo GetThresholdType() const
Definition: tree.h:464
Model(Model const &)=delete
ContiguousArray< std::int32_t > num_class
Definition: tree.h:529
PyBufferFrame GetTreeField(std::uint64_t tree_id, std::string const &name)
Get a field in a tree.
std::string DumpAsJSON(bool pretty_print) const
Definition: tree.h:479
Model()
disable copy; use default move
Definition: tree.h:449
bool average_tree_output
whether to average tree outputs
Definition: tree.h:525
Model & operator=(Model const &)=delete
void SetTreeField(std::uint64_t tree_id, std::string const &name, PyBufferFrame frame)
Set a field in a tree.
std::size_t GetNumTree() const
Definition: tree.h:471
static TREELITE_DLL_EXPORT std::unique_ptr< Model > DeserializeFromPyBuffer(std::vector< PyBufferFrame > const &frames)
ContiguousArray< std::int32_t > target_id
Definition: tree.h:532
ModelPresetVariant variant_
Definition: tree.h:459
PyBufferFrame GetHeaderField(std::string const &name)
Get a field in the header.
TypeInfo GetLeafOutputType() const
Definition: tree.h:467
Version GetVersion() const
Return the Treelite version that produced this Model object.
Definition: tree.h:504
ContiguousArray< double > base_scores
Definition: tree.h:538
ContiguousArray< std::int32_t > leaf_vector_shape
Definition: tree.h:530
float sigmoid_alpha
Definition: tree.h:536
static std::unique_ptr< Model > DeserializeFromStream(std::istream &is)
in-memory representation of a decision tree
Definition: tree.h:79
int AllocNode()
Allocate a new node and return the node's ID.
Definition: tree.h:70
std::int32_t SplitIndex(int nid) const
Feature index of the node's split condition.
Definition: tree.h:182
Tree(Tree &&) noexcept=default
ThresholdType Threshold(int nid) const
Get threshold of the node.
Definition: tree.h:232
void SetLeafVector(int nid, std::vector< LeafOutputType > const &leaf_vector)
Set the leaf vector of the node; useful for multi-class random forest classifier.
Definition: tree.h:166
bool HasSumHess(int nid) const
Test whether this node has hessian sum.
Definition: tree.h:289
void SetNumericalTest(int nid, std::int32_t split_index, ThresholdType threshold, bool default_left, Operator cmp)
Create a numerical test.
Definition: tree.h:127
bool HasDataCount(int nid) const
Test whether this node has data count.
Definition: tree.h:274
LeafOutputType LeafValue(int nid) const
Get leaf value of the leaf node.
Definition: tree.h:203
double Gain(int nid) const
Get gain value.
Definition: tree.h:310
Tree & operator=(Tree const &)=delete
void SetSumHess(int nid, double sum_hess)
Set the hessian sum of the node.
Definition: tree.h:182
Tree< ThresholdType, LeafOutputType > Clone() const
Definition: tree.h:33
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
Definition: tree.h:325
void SetChildren(int nid, int left_child, int right_child)
Identify two child nodes of the node.
Definition: tree.h:336
Tree(Tree const &)=delete
bool CategoryListRightChild(int nid) const
Test whether the list given by CategoryList(nid) is associated with the right child node or the left ...
Definition: tree.h:318
void Init()
Initialize the tree with a single root node.
Definition: tree.h:104
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier
Definition: tree.h:210
Operator ComparisonOp(int nid) const
Get comparison operator.
Definition: tree.h:239
int LeftChild(int nid) const
Index of the node's left child.
Definition: tree.h:161
int RightChild(int nid) const
Index of the node's right child.
Definition: tree.h:168
friend void DumpTreeAsJSON(WriterType &writer, Tree< X, Y > const &tree)
void SetDataCount(int nid, std::uint64_t data_count)
Set the data count of the node.
Definition: tree.h:192
void SetCategoricalTest(int nid, std::int32_t split_index, bool default_left, std::vector< std::uint32_t > const &category_list, bool category_list_right_child)
Create a categorical test.
Definition: tree.h:138
bool HasLeafVector(int nid) const
Tests whether the leaf node has a non-empty leaf vector.
Definition: tree.h:225
void SetGain(int nid, double gain)
Set the gain value of the node.
Definition: tree.h:202
TreeNodeType NodeType(int nid) const
Get the type of a node.
Definition: tree.h:267
std::vector< std::uint32_t > CategoryList(int nid) const
Get list of all categories belonging to the left/right child node. See the category_list_right_child_...
Definition: tree.h:251
bool DefaultLeft(int nid) const
Whether to use the left child node, when the feature in the split condition is missing.
Definition: tree.h:189
double SumHess(int nid) const
Get hessian sum.
Definition: tree.h:296
bool HasGain(int nid) const
Test whether this node has gain value.
Definition: tree.h:303
std::uint64_t DataCount(int nid) const
Get data count.
Definition: tree.h:281
bool IsLeaf(int nid) const
Whether the node is leaf node.
Definition: tree.h:196
int DefaultChild(int nid) const
Index of the node's "default" child, used when feature is missing.
Definition: tree.h:175
std::int32_t num_nodes
Number of nodes.
Definition: tree.h:150
void SetLeaf(int nid, LeafOutputType value)
Set the leaf value of the node.
Definition: tree.h:158
A simple array container, with owned or non-owned (externally allocated) buffer.
Implementation for treelite/tree.h.
logging facility for Treelite
void SetTreeFieldImpl(ModelPreset< ThresholdType, LeafOutputType > &, std::uint64_t, std::string const &, PyBufferFrame)
PyBufferFrame GetTreeFieldImpl(ModelPreset< ThresholdType, LeafOutputType > &, std::uint64_t, std::string const &)
Definition: serializer.h:28
Definition: contiguous_array.h:14
TreelitePyBufferFrame PyBufferFrame
Definition: pybuffer_frame.h:18
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
TaskType
Enum type representing the task type.
Definition: task_type.h:19
ModelPresetVariant SetModelPresetVariant(int target_variant_index)
Definition: tree.h:432
std::unique_ptr< Model > ConcatenateModelObjects(std::vector< Model const * > const &objs)
Concatenate multiple model objects into a single model object by copying all member trees into the de...
Operator
Type of comparison operators used in numerical test nodes.
Definition: operator.h:17
TreeNodeType
Tree node type.
Definition: tree_node_type.h:17
std::variant< ModelPreset< float, float >, ModelPreset< double, double > > ModelPresetVariant
Definition: tree.h:429
Define enum type Operator.
Data structure to enable zero-copy exchange in Python.
std::int32_t minor_ver
Definition: tree.h:73
std::int32_t major_ver
Definition: tree.h:72
std::int32_t patch_ver
Definition: tree.h:74
Define enum type TaskType.
#define TREELITE_DLL_EXPORT
Definition: tree.h:38
Define enum type NodeType.
Defines enum type TypeInfo.