7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 20 #include <type_traits> 27 #define __TREELITE_STR(x) #x 28 #define _TREELITE_STR(x) __TREELITE_STR(x) 30 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 36 template <
typename ThresholdType,
typename LeafOutputType>
48 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
61 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
63 inline const T* Data()
const;
65 inline const T* End()
const;
67 inline const T& Back()
const;
68 inline std::size_t Size()
const;
69 inline bool Empty()
const;
70 inline void Reserve(std::size_t newsize);
71 inline void Resize(std::size_t newsize);
72 inline void Resize(std::size_t newsize, T t);
74 inline void PushBack(T t);
75 inline void Extend(
const std::vector<T>& other);
77 inline T& operator[](std::size_t idx);
78 inline const T& operator[](std::size_t idx)
const;
80 inline T& at(std::size_t idx);
81 inline const T& at(std::size_t idx)
const;
83 inline T& at(
int idx);
84 inline const T& at(
int idx)
const;
85 static_assert(std::is_pod<T>::value,
"T must be POD");
90 std::size_t capacity_;
127 kMultiClfGrovePerClass = 1,
143 kMultiClfProbDistLeaf = 2,
160 kMultiClfCategLeaf = 3
163 inline std::string TaskTypeToString(
TaskType type) {
165 case TaskType::kBinaryClfRegr:
return "BinaryClfRegr";
166 case TaskType::kMultiClfGrovePerClass:
return "MultiClfGrovePerClass";
167 case TaskType::kMultiClfProbDistLeaf:
return "MultiClfProbDistLeaf";
168 case TaskType::kMultiClfCategLeaf:
return "MultiClfCategLeaf";
175 enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
203 inline std::string OutputTypeToString(TaskParam::OutputType type) {
205 case TaskParam::OutputType::kFloat:
return "float";
206 case TaskParam::OutputType::kInt:
return "int";
211 static_assert(std::is_pod<TaskParam>::value,
"TaskParameter must be POD type");
214 template <
typename ThresholdType,
typename LeafOutputType>
224 LeafOutputType leaf_value;
225 ThresholdType threshold;
268 bool categories_list_right_child_;
274 inline int RightChild()
const {
277 inline bool DefaultLeft()
const {
279 return (sindex_ >> 31U) != 0;
281 inline int DefaultChild()
const {
283 return ((sindex_ >> 31U) != 0) ? cleft_ : cright_;
285 inline std::uint32_t SplitIndex()
const {
287 return (sindex_ & ((1U << 31U) - 1U));
289 inline bool IsLeaf()
const {
292 inline LeafOutputType LeafValue()
const {
293 return info_.leaf_value;
295 inline ThresholdType Threshold()
const {
296 return info_.threshold;
298 inline Operator ComparisonOp()
const {
304 inline bool HasDataCount()
const {
305 return data_count_present_;
307 inline std::uint64_t DataCount()
const {
310 inline bool HasSumHess()
const {
311 return sum_hess_present_;
313 inline double SumHess()
const {
316 inline bool HasGain()
const {
317 return gain_present_;
319 inline double Gain()
const {
322 inline bool CategoriesListRightChild()
const {
323 return categories_list_right_child_;
327 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
328 static_assert(std::is_same<ThresholdType, float>::value
329 || std::is_same<ThresholdType, double>::value,
330 "ThresholdType must be either float32 or float64");
331 static_assert(std::is_same<LeafOutputType, uint32_t>::value
332 || std::is_same<LeafOutputType, float>::value
333 || std::is_same<LeafOutputType, double>::value,
334 "LeafOutputType must be one of uint32_t, float32 or float64");
335 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
336 || std::is_same<LeafOutputType, uint32_t>::value,
337 "Unsupported combination of ThresholdType and LeafOutputType");
338 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
339 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
340 "Node size incorrect");
345 Tree& operator=(
const Tree&) =
delete;
347 Tree& operator=(
Tree&&) noexcept =
default;
351 inline const char* GetFormatStringForNode();
352 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
353 inline void SerializeToFile(FILE* dest_fp);
354 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
355 std::vector<PyBufferFrame>::iterator end);
356 inline void DeserializeFromFile(FILE* src_fp);
369 bool has_categorical_split_{
false};
371 template <
typename WriterType,
typename X,
typename Y>
372 friend void DumpModelAsJSON(WriterType& writer,
const ModelImpl<X, Y>& model);
373 template <
typename WriterType,
typename X,
typename Y>
374 friend void DumpTreeAsJSON(WriterType& writer,
const Tree<X, Y>& tree);
377 inline int AllocNode();
380 template <
typename ScalarHandler,
typename PrimitiveArrayHandler,
typename CompositeArrayHandler>
382 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
383 CompositeArrayHandler composite_array_handler);
384 template <
typename ScalarHandler,
typename ArrayHandler>
386 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
399 inline void AddChilds(
int nid);
407 return nodes_[nid].LeftChild();
414 return nodes_[nid].RightChild();
421 return nodes_[nid].DefaultChild();
428 return nodes_[nid].SplitIndex();
435 return nodes_[nid].DefaultLeft();
442 return nodes_[nid].IsLeaf();
449 return nodes_[nid].LeafValue();
455 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
456 const std::size_t offset_begin = leaf_vector_begin_[nid];
457 const std::size_t offset_end = leaf_vector_end_[nid];
458 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
460 return std::vector<LeafOutputType>();
462 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
463 &leaf_vector_[offset_end]);
472 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
479 return nodes_[nid].Threshold();
486 return nodes_[nid].ComparisonOp();
497 const std::size_t offset_begin = matching_categories_offset_[nid];
498 const std::size_t offset_end = matching_categories_offset_[nid + 1];
499 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
502 return std::vector<std::uint32_t>();
504 return std::vector<std::uint32_t>(&matching_categories_[offset_begin],
505 &matching_categories_[offset_end]);
514 return nodes_[nid].SplitType();
521 return nodes_[nid].HasDataCount();
528 return nodes_[nid].DataCount();
536 return nodes_[nid].HasSumHess();
543 return nodes_[nid].SumHess();
550 return nodes_[nid].HasGain();
556 inline double Gain(
int nid)
const {
557 return nodes_[nid].Gain();
565 return nodes_[nid].CategoriesListRightChild();
572 return has_categorical_split_;
585 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
599 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
600 const std::vector<uint32_t>& categories_list,
601 bool categories_list_right_child);
607 inline void SetLeaf(
int nid, LeafOutputType value);
613 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
620 Node& node = nodes_.at(nid);
621 node.sum_hess_ = sum_hess;
622 node.sum_hess_present_ =
true;
630 Node& node = nodes_.at(nid);
631 node.data_count_ = data_count;
632 node.data_count_present_ =
true;
640 Node& node = nodes_.at(nid);
642 node.gain_present_ =
true;
668 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
694 ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
695 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
696 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
704 template<
typename Container>
705 inline std::vector<std::pair<std::string, std::string>>
706 InitAllowUnknown(
const Container &kwargs);
707 inline std::map<std::string, std::string> __DICT__()
const;
710 static_assert(std::is_standard_layout<ModelParam>::value,
711 "ModelParam must be in the standard layout");
713 inline void InitParamAndCheck(
ModelParam* param,
714 const std::vector<std::pair<std::string, std::string>>& cfg);
720 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
721 patch_ver_(TREELITE_VER_PATCH) {}
722 virtual ~
Model() =
default;
728 template <
typename ThresholdType,
typename LeafOutputType>
729 inline static std::unique_ptr<Model> Create();
730 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
731 inline TypeInfo GetThresholdType()
const {
732 return threshold_type_;
734 inline TypeInfo GetLeafOutputType()
const {
735 return leaf_output_type_;
737 template <
typename Func>
738 inline auto Dispatch(Func func);
739 template <
typename Func>
740 inline auto Dispatch(Func func)
const;
742 virtual std::size_t GetNumTree()
const = 0;
743 virtual void SetTreeLimit(std::size_t limit) = 0;
744 virtual void DumpAsJSON(std::ostream& fo,
bool pretty_print)
const = 0;
746 inline std::string DumpAsJSON(
bool pretty_print)
const {
747 std::ostringstream oss;
748 DumpAsJSON(oss, pretty_print);
753 std::vector<PyBufferFrame> GetPyBuffer();
754 static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
757 void SerializeToFile(FILE* dest_fp);
758 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
775 int major_ver_, minor_ver_, patch_ver_;
779 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
780 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
781 virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
782 std::vector<PyBufferFrame>::iterator end) = 0;
783 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
784 template <
typename HeaderPrimitiveFieldHandlerFunc>
785 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
786 template <
typename HeaderPrimitiveFieldHandlerFunc>
787 inline static void DeserializeTemplate(
788 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
792 template <
typename ThresholdType,
typename LeafOutputType>
796 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
806 void DumpAsJSON(std::ostream& fo,
bool pretty_print) const override;
807 inline std::
size_t GetNumTree()
const override {
810 void SetTreeLimit(std::size_t limit)
override {
811 return trees.resize(limit);
814 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
815 inline void SerializeToFileImpl(FILE* dest_fp)
override;
816 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
817 std::vector<PyBufferFrame>::iterator end)
override;
818 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
821 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
822 typename TreeHandlerFunc>
823 inline void SerializeTemplate(
824 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
825 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
826 TreeHandlerFunc tree_handler);
827 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
828 inline void DeserializeTemplate(
830 HeaderFieldHandlerFunc header_field_handler,
831 TreeHandlerFunc tree_handler);
838 #endif // TREELITE_TREE_H_ ModelParam param
extra parameters
SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
SplitFeatureType
feature split type
std::int32_t cleft_
pointer to left and right children
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
TaskType
Enum type representing the task type.
bool average_tree_output
whether to average tree outputs
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool data_count_present_
whether data_count_ field is present
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
TypeInfo
Types used by thresholds and leaf outputs.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Model()
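default constructor; initializes the stored major/minor/patch version fields from the TREELITE_VER_* macros (copy is disabled; default move is used)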
std::uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
int LeftChild(int nid) const
index of the node's left child
TaskParam task_param
Group of parameters that are specific to the particular task type.
defines configuration macros of Treelite
std::uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
std::uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
int LeftChild() const
index of this node's left child
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Info info_
storage for leaf value or decision threshold
Operator
comparison operators