7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 19 #include <type_traits> 26 #define __TREELITE_STR(x) #x 27 #define _TREELITE_STR(x) __TREELITE_STR(x) 29 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 42 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
55 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
57 inline const T* Data()
const;
59 inline const T* End()
const;
61 inline const T& Back()
const;
62 inline std::size_t Size()
const;
63 inline void Reserve(std::size_t newsize);
64 inline void Resize(std::size_t newsize);
65 inline void Resize(std::size_t newsize, T t);
67 inline void PushBack(T t);
68 inline void Extend(
const std::vector<T>& other);
70 inline T& operator[](std::size_t idx);
71 inline const T& operator[](std::size_t idx)
const;
73 inline T& at(std::size_t idx);
74 inline const T& at(std::size_t idx)
const;
76 inline T& at(
int idx);
77 inline const T& at(
int idx)
const;
/*! \brief Compile-time guard: the element type T must be a POD type, otherwise the build
 *         fails with the message below. */
78 static_assert(std::is_pod<T>::value,
"T must be POD");
83 std::size_t capacity_;
120 kMultiClfGrovePerClass = 1,
136 kMultiClfProbDistLeaf = 2,
153 kMultiClfCategLeaf = 3
/*! \brief Type of output from each leaf node. Scoped enum stored in a uint8_t, with
 *         kFloat = 0 and kInt = 1. */
158 enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
186 static_assert(std::is_pod<TaskParam>::value,
"TaskParameter must be POD type");
189 template <
typename ThresholdType,
typename LeafOutputType>
199 LeafOutputType leaf_value;
200 ThresholdType threshold;
243 bool categories_list_right_child_;
/* Compile-time checks on the Node layout and on the template arguments
 * ThresholdType / LeafOutputType. Each failure aborts the build with its own message. */
// Node has to be a POD type.
246 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
// ThresholdType is restricted to float or double.
247 static_assert(std::is_same<ThresholdType, float>::value
248 || std::is_same<ThresholdType, double>::value,
249 "ThresholdType must be either float32 or float64");
// LeafOutputType is restricted to uint32_t, float or double.
250 static_assert(std::is_same<LeafOutputType, uint32_t>::value
251 || std::is_same<LeafOutputType, float>::value
252 || std::is_same<LeafOutputType, double>::value,
253 "LeafOutputType must be one of uint32_t, float32 or float64");
// Allowed pairings: both types identical, or LeafOutputType is uint32_t.
// Any other combination (e.g. float thresholds with double leaf outputs) is rejected.
254 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
255 || std::is_same<LeafOutputType, uint32_t>::value,
256 "Unsupported combination of ThresholdType and LeafOutputType");
// sizeof(Node) is pinned: 48 bytes when ThresholdType is float, 56 bytes when it is double.
// NOTE(review): presumably this guards a fixed binary layout used for serialization -- confirm.
257 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
258 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
259 "Node size incorrect");
/*! \brief Copy assignment is disabled: a Tree cannot be copy-assigned. */
264 Tree& operator=(
const Tree&) =
delete;
/*! \brief Move assignment is compiler-generated and declared noexcept. */
266 Tree& operator=(
Tree&&) noexcept = default;
268 inline
Tree<ThresholdType, LeafOutputType> Clone() const;
270 inline const
char* GetFormatStringForNode();
272 inline
void SerializeToFile(FILE* dest_fp);
273 inline
void InitFromPyBuffer(std::vector<
PyBufferFrame>::iterator begin,
275 inline
void DeserializeFromFile(FILE* src_fp);
285 template <typename WriterType, typename X, typename Y>
286 friend
void SerializeTreeToJSON(WriterType& writer, const
Tree<X, Y>& tree);
289 inline
int AllocNode();
292 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
294 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
295 CompositeArrayHandler composite_array_handler);
296 template <typename ScalarHandler, typename ArrayHandler>
298 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
309 inline
void AddChilds(
int nid);
315 inline std::vector<
unsigned> GetCategoricalFeatures() const;
322 inline
int LeftChild(
int nid)
const {
323 return nodes_.at(nid).cleft_;
330 return nodes_.at(nid).cright_;
337 return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
344 return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
351 return (nodes_.at(nid).sindex_ >> 31U) != 0;
358 return nodes_.at(nid).cleft_ == -1;
365 return (nodes_.at(nid).info_).leaf_value;
371 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
372 const std::size_t offset_begin = leaf_vector_offset_.at(nid);
373 const std::size_t offset_end = leaf_vector_offset_.at(nid + 1);
374 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
376 return std::vector<LeafOutputType>();
378 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
379 &leaf_vector_[offset_end]);
388 return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
395 return (nodes_.at(nid).info_).threshold;
402 return nodes_.at(nid).cmp_;
413 const std::size_t offset_begin = matching_categories_offset_.at(nid);
414 const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
415 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
418 return std::vector<uint32_t>();
420 return std::vector<uint32_t>(&matching_categories_[offset_begin],
421 &matching_categories_[offset_end]);
431 return matching_categories_offset_.at(nid) != matching_categories_offset_.at(nid + 1);
438 return nodes_.at(nid).split_type_;
445 return nodes_.at(nid).data_count_present_;
452 return nodes_.at(nid).data_count_;
460 return nodes_.at(nid).sum_hess_present_;
467 return nodes_.at(nid).sum_hess_;
474 return nodes_.at(nid).gain_present_;
480 inline double Gain(
int nid)
const {
481 return nodes_.at(nid).gain_;
489 return nodes_.at(nid).categories_list_right_child_;
502 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
516 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
517 const std::vector<uint32_t>& categories_list,
518 bool categories_list_right_child);
524 inline void SetLeaf(
int nid, LeafOutputType value);
530 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
537 Node& node = nodes_.at(nid);
538 node.sum_hess_ = sum_hess;
539 node.sum_hess_present_ =
true;
547 Node& node = nodes_.at(nid);
548 node.data_count_ = data_count;
549 node.data_count_present_ =
true;
557 Node& node = nodes_.at(nid);
559 node.gain_present_ =
true;
585 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
603 ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {
604 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
605 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
613 template<
typename Container>
614 inline std::vector<std::pair<std::string, std::string>>
615 InitAllowUnknown(
const Container &kwargs);
616 inline std::map<std::string, std::string> __DICT__()
const;
/*! \brief Compile-time guard: ModelParam must be a standard-layout type. */
// NOTE(review): presumably needed because the struct is exchanged as raw bytes -- confirm.
619 static_assert(std::is_standard_layout<ModelParam>::value,
620 "ModelParam must be in the standard layout");
622 inline void InitParamAndCheck(
ModelParam* param,
623 const std::vector<std::pair<std::string, std::string>>& cfg);
/*! \brief Default constructor: initializes the version fields major_ver_, minor_ver_ and
 *         patch_ver_ from the TREELITE_VER_MAJOR / TREELITE_VER_MINOR / TREELITE_VER_PATCH
 *         macros. The body is otherwise empty. */
629 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
630 patch_ver_(TREELITE_VER_PATCH) {}
/*! \brief Virtual destructor with the compiler-generated (defaulted) body. */
631 virtual ~
Model() =
default;
637 template <
typename ThresholdType,
typename LeafOutputType>
638 inline static std::unique_ptr<Model> Create();
639 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
640 inline TypeInfo GetThresholdType()
const {
641 return threshold_type_;
643 inline TypeInfo GetLeafOutputType()
const {
644 return leaf_output_type_;
646 template <
typename Func>
647 inline auto Dispatch(Func func);
648 template <
typename Func>
649 inline auto Dispatch(Func func)
const;
/* Pure virtual interface: each function below must be provided by a derived class. */
/*! \brief Returns a count as std::size_t; presumably the number of member trees -- confirm
 *         against the overriding implementation. */
651 virtual std::size_t GetNumTree()
const = 0;
/*! \brief Apply a tree-count limit.
 *  \param limit value handed to the implementation; the override visible elsewhere in this
 *               listing forwards it to trees.resize(limit). */
652 virtual void SetTreeLimit(std::size_t limit) = 0;
/*! \brief Const operation taking an output stream; presumably writes a JSON representation
 *         of the model to it -- confirm against the implementation.
 *  \param fo destination output stream */
653 virtual void SerializeToJSON(std::ostream& fo)
const = 0;
656 std::vector<PyBufferFrame> GetPyBuffer();
657 static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
660 void SerializeToFile(FILE* dest_fp);
661 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
678 int major_ver_, minor_ver_, patch_ver_;
682 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
683 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
684 virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
685 std::vector<PyBufferFrame>::iterator end) = 0;
686 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
687 template <
typename HeaderPrimitiveFieldHandlerFunc>
688 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
689 template <
typename HeaderPrimitiveFieldHandlerFunc>
690 inline static void DeserializeTemplate(
691 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
695 template <
typename ThresholdType,
typename LeafOutputType>
699 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
709 void SerializeToJSON(std::ostream& fo) const override;
710 inline std::
size_t GetNumTree()
const override {
713 void SetTreeLimit(std::size_t limit)
override {
714 return trees.resize(limit);
717 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
718 inline void SerializeToFileImpl(FILE* dest_fp)
override;
719 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
720 std::vector<PyBufferFrame>::iterator end)
override;
721 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
724 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
725 typename TreeHandlerFunc>
726 inline void SerializeTemplate(
727 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
728 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
729 TreeHandlerFunc tree_handler);
730 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
731 inline void DeserializeTemplate(
733 HeaderFieldHandlerFunc header_field_handler,
734 TreeHandlerFunc tree_handler);
741 #endif // TREELITE_TREE_H_ ModelParam param
extra parameters
SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_present_ field is present
SplitFeatureType
feature split type
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
TaskType
Enum type representing the task type.
bool average_tree_output
whether to average tree outputs
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool data_count_present_
whether data_count_ field is present
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
int32_t cleft_
pointer to left and right children
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the left child node
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Model()
disable copy; use default move
TaskParam task_param
Group of parameters that are specific to the particular task type.
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_child_ field to determine which child node the list is associated with.
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool HasMatchingCategories(int nid) const
tests whether the node has a non-empty list for matching categories. See MatchingCategories() for the...
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_feature]-1.
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Info info_
storage for leaf value or decision threshold
Operator
comparison operators