7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 20 #include <type_traits> 27 #define __TREELITE_STR(x) #x 28 #define _TREELITE_STR(x) __TREELITE_STR(x) 30 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 33 #if defined(_MSC_VER) || defined(_WIN32) 34 #define TREELITE_DLL_EXPORT __declspec(dllexport) 36 #define TREELITE_DLL_EXPORT 43 template <
typename ThresholdType,
typename LeafOutputType>
55 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
68 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
70 inline const T* Data()
const;
72 inline const T* End()
const;
74 inline const T& Back()
const;
75 inline std::size_t Size()
const;
76 inline bool Empty()
const;
77 inline void Reserve(std::size_t newsize);
78 inline void Resize(std::size_t newsize);
79 inline void Resize(std::size_t newsize, T t);
81 inline void PushBack(T t);
82 inline void Extend(
const std::vector<T>& other);
84 inline T& operator[](std::size_t idx);
85 inline const T& operator[](std::size_t idx)
const;
87 inline T& at(std::size_t idx);
88 inline const T& at(std::size_t idx)
const;
90 inline T& at(
int idx);
91 inline const T& at(
int idx)
const;
92 static_assert(std::is_pod<T>::value,
"T must be POD");
97 std::size_t capacity_;
134 kMultiClfGrovePerClass = 1,
150 kMultiClfProbDistLeaf = 2,
167 kMultiClfCategLeaf = 3
170 inline std::string TaskTypeToString(
TaskType type) {
172 case TaskType::kBinaryClfRegr:
return "BinaryClfRegr";
173 case TaskType::kMultiClfGrovePerClass:
return "MultiClfGrovePerClass";
174 case TaskType::kMultiClfProbDistLeaf:
return "MultiClfProbDistLeaf";
175 case TaskType::kMultiClfCategLeaf:
return "MultiClfCategLeaf";
182 enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
210 inline std::string OutputTypeToString(TaskParam::OutputType type) {
212 case TaskParam::OutputType::kFloat:
return "float";
213 case TaskParam::OutputType::kInt:
return "int";
218 static_assert(std::is_pod<TaskParam>::value,
"TaskParameter must be POD type");
221 template <
typename ThresholdType,
typename LeafOutputType>
231 LeafOutputType leaf_value;
232 ThresholdType threshold;
275 bool categories_list_right_child_;
281 inline int RightChild()
const {
284 inline bool DefaultLeft()
const {
286 return (sindex_ >> 31U) != 0;
288 inline int DefaultChild()
const {
290 return ((sindex_ >> 31U) != 0) ? cleft_ : cright_;
292 inline std::uint32_t SplitIndex()
const {
294 return (sindex_ & ((1U << 31U) - 1U));
296 inline bool IsLeaf()
const {
299 inline LeafOutputType LeafValue()
const {
300 return info_.leaf_value;
302 inline ThresholdType Threshold()
const {
303 return info_.threshold;
305 inline Operator ComparisonOp()
const {
311 inline bool HasDataCount()
const {
312 return data_count_present_;
314 inline std::uint64_t DataCount()
const {
317 inline bool HasSumHess()
const {
318 return sum_hess_present_;
320 inline double SumHess()
const {
323 inline bool HasGain()
const {
324 return gain_present_;
326 inline double Gain()
const {
329 inline bool CategoriesListRightChild()
const {
330 return categories_list_right_child_;
334 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
335 static_assert(std::is_same<ThresholdType, float>::value
336 || std::is_same<ThresholdType, double>::value,
337 "ThresholdType must be either float32 or float64");
338 static_assert(std::is_same<LeafOutputType, uint32_t>::value
339 || std::is_same<LeafOutputType, float>::value
340 || std::is_same<LeafOutputType, double>::value,
341 "LeafOutputType must be one of uint32_t, float32 or float64");
342 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
343 || std::is_same<LeafOutputType, uint32_t>::value,
344 "Unsupported combination of ThresholdType and LeafOutputType");
345 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
346 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
347 "Node size incorrect");
349 explicit Tree(
bool use_opt_field =
true);
352 Tree& operator=(
const Tree&) =
delete;
354 Tree& operator=(
Tree&&) noexcept =
default;
358 inline const char* GetFormatStringForNode();
359 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
360 inline void SerializeToFile(FILE* dest_fp);
363 inline std::vector<PyBufferFrame>::iterator
364 InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it);
365 inline void DeserializeFromFile(FILE* src_fp);
378 bool has_categorical_split_{
false};
382 bool use_opt_field_{
false};
384 int32_t num_opt_field_per_tree_{0};
385 int32_t num_opt_field_per_node_{0};
387 template <
typename WriterType,
typename X,
typename Y>
388 friend void DumpModelAsJSON(WriterType& writer,
const ModelImpl<X, Y>& model);
389 template <
typename WriterType,
typename X,
typename Y>
390 friend void DumpTreeAsJSON(WriterType& writer,
const Tree<X, Y>& tree);
393 inline int AllocNode();
396 template <
typename ScalarHandler,
typename PrimitiveArrayHandler,
typename CompositeArrayHandler>
398 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
399 CompositeArrayHandler composite_array_handler);
400 template <
typename ScalarHandler,
typename ArrayHandler,
typename SkipOptFieldHandlerFunc>
402 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler,
403 SkipOptFieldHandlerFunc skip_opt_field_handler);
416 inline void AddChilds(
int nid);
424 return nodes_[nid].LeftChild();
431 return nodes_[nid].RightChild();
438 return nodes_[nid].DefaultChild();
445 return nodes_[nid].SplitIndex();
452 return nodes_[nid].DefaultLeft();
459 return nodes_[nid].IsLeaf();
466 return nodes_[nid].LeafValue();
472 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
473 const std::size_t offset_begin = leaf_vector_begin_[nid];
474 const std::size_t offset_end = leaf_vector_end_[nid];
475 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
477 return std::vector<LeafOutputType>();
479 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
480 &leaf_vector_[offset_end]);
489 return leaf_vector_begin_[nid] != leaf_vector_end_[nid];
496 return nodes_[nid].Threshold();
503 return nodes_[nid].ComparisonOp();
514 const std::size_t offset_begin = matching_categories_offset_[nid];
515 const std::size_t offset_end = matching_categories_offset_[nid + 1];
516 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
519 return std::vector<std::uint32_t>();
521 return std::vector<std::uint32_t>(&matching_categories_[offset_begin],
522 &matching_categories_[offset_end]);
531 return nodes_[nid].SplitType();
538 return nodes_[nid].HasDataCount();
545 return nodes_[nid].DataCount();
553 return nodes_[nid].HasSumHess();
560 return nodes_[nid].SumHess();
567 return nodes_[nid].HasGain();
573 inline double Gain(
int nid)
const {
574 return nodes_[nid].Gain();
582 return nodes_[nid].CategoriesListRightChild();
589 return has_categorical_split_;
602 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
616 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
617 const std::vector<uint32_t>& categories_list,
618 bool categories_list_right_child);
624 inline void SetLeaf(
int nid, LeafOutputType value);
630 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
637 Node& node = nodes_.at(nid);
638 node.sum_hess_ = sum_hess;
639 node.sum_hess_present_ =
true;
647 Node& node = nodes_.at(nid);
648 node.data_count_ = data_count;
649 node.data_count_present_ =
true;
657 Node& node = nodes_.at(nid);
659 node.gain_present_ =
true;
685 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
711 ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
712 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
713 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
721 template<
typename Container>
722 inline std::vector<std::pair<std::string, std::string>>
723 InitAllowUnknown(
const Container &kwargs);
724 inline std::map<std::string, std::string> __DICT__()
const;
727 static_assert(std::is_standard_layout<ModelParam>::value,
728 "ModelParam must be in the standard layout");
730 inline void InitParamAndCheck(
ModelParam* param,
731 const std::vector<std::pair<std::string, std::string>>& cfg);
737 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
738 patch_ver_(TREELITE_VER_PATCH) {}
739 virtual ~
Model() =
default;
745 template <
typename ThresholdType,
typename LeafOutputType>
746 inline static std::unique_ptr<Model> Create();
747 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
748 inline TypeInfo GetThresholdType()
const {
749 return threshold_type_;
751 inline TypeInfo GetLeafOutputType()
const {
752 return leaf_output_type_;
754 template <
typename Func>
755 inline auto Dispatch(Func func);
756 template <
typename Func>
757 inline auto Dispatch(Func func)
const;
759 virtual std::size_t GetNumTree()
const = 0;
760 virtual void SetTreeLimit(std::size_t limit) = 0;
761 virtual void DumpAsJSON(std::ostream& fo,
bool pretty_print)
const = 0;
763 inline std::string DumpAsJSON(
bool pretty_print)
const {
764 std::ostringstream oss;
765 DumpAsJSON(oss, pretty_print);
781 TREELITE_DLL_EXPORT std::vector<PyBufferFrame> GetPyBuffer();
782 TREELITE_DLL_EXPORT
static std::unique_ptr<Model>
783 CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
786 void SerializeToFile(FILE* dest_fp);
787 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
793 int32_t num_feature{0};
797 bool average_tree_output{
false};
806 uint64_t num_tree_{0};
808 int32_t num_opt_field_per_model_{0};
815 TypeInfo threshold_type_{TypeInfo::kInvalid};
816 TypeInfo leaf_output_type_{TypeInfo::kInvalid};
818 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
819 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
822 virtual std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
823 std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) = 0;
824 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
825 template <
typename HeaderPrimitiveFieldHandlerFunc>
826 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
827 template <
typename HeaderPrimitiveFieldHandlerFunc>
828 inline static void DeserializeTemplate(
829 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
830 int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
834 template <
typename ThresholdType,
typename LeafOutputType>
838 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
848 void DumpAsJSON(std::ostream& fo,
bool pretty_print) const override;
849 inline std::
size_t GetNumTree()
const override {
852 void SetTreeLimit(std::size_t limit)
override {
853 return trees.resize(limit);
856 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
857 inline void SerializeToFileImpl(FILE* dest_fp)
override;
860 inline std::vector<PyBufferFrame>::iterator InitFromPyBuffer(
861 std::vector<PyBufferFrame>::iterator it, std::size_t num_frame)
override;
862 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
865 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
866 typename TreeHandlerFunc>
867 inline void SerializeTemplate(
868 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
869 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
870 TreeHandlerFunc tree_handler);
871 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc,
872 typename SkipOptFieldHandlerFunc>
873 inline void DeserializeTemplate(
875 HeaderFieldHandlerFunc header_field_handler,
876 TreeHandlerFunc tree_handler,
877 SkipOptFieldHandlerFunc skip_opt_field_handler);
884 #endif // TREELITE_TREE_H_ SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
SplitFeatureType
feature split type
std::int32_t cleft_
indices of the left and right child nodes (cleft_ and cright_)
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
bool HasCategoricalSplit() const
Query whether this tree contains any categorical splits.
TaskType
Enum type representing the task type.
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when the feature is missing
bool data_count_present_
whether data_count_ field is present
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of Hessian values for all data points whose traversal paths include this node. This value is generally positively correlated with the data count. XGBoost models natively store this statistic.
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
TypeInfo
Types used by thresholds and leaf outputs.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
std::uint32_t SplitIndex(int nid) const
feature index of the node's split condition
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Model()
disable copy; use default move
std::uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
int LeftChild(int nid) const
Getters. Index of the node's left child.
defines configuration macros of Treelite
std::uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
std::uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
bool DefaultLeft(int nid) const
whether to use the left child node when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
int LeftChild() const
Getters.
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
std::vector< std::uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
Info info_
storage for leaf value or decision threshold
Operator
comparison operators