7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 18 #include <type_traits> 25 #define __TREELITE_STR(x) #x 26 #define _TREELITE_STR(x) __TREELITE_STR(x) 28 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 34 float stof(
const std::string& value, std::size_t* pos);
49 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
62 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
64 inline const T* Data()
const;
66 inline const T* End()
const;
68 inline const T& Back()
const;
69 inline std::size_t Size()
const;
70 inline void Reserve(std::size_t newsize);
71 inline void Resize(std::size_t newsize);
72 inline void Resize(std::size_t newsize, T t);
74 inline void PushBack(T t);
75 inline void Extend(
const std::vector<T>& other);
77 inline T& operator[](std::size_t idx);
78 inline const T& operator[](std::size_t idx)
const;
80 inline T& at(std::size_t idx);
81 inline const T& at(std::size_t idx)
const;
83 inline T& at(
int idx);
84 inline const T& at(
int idx)
const;
85 static_assert(std::is_pod<T>::value,
"T must be POD");
90 std::size_t capacity_;
127 kMultiClfGrovePerClass = 1,
143 kMultiClfProbDistLeaf = 2,
160 kMultiClfCategLeaf = 3
/*!
 * \brief Type of the value produced by each leaf node.
 *
 * The underlying storage type is uint8_t; the enumerator values are fixed
 * explicitly.
 */
enum class OutputType : uint8_t {
  kFloat = 0,  ///< floating-point leaf output
  kInt = 1     ///< integer leaf output
};
193 static_assert(std::is_pod<TaskParameter>::value,
"TaskParameter must be POD type");
196 template <
typename ThresholdType,
typename LeafOutputType>
206 LeafOutputType leaf_value;
207 ThresholdType threshold;
250 bool categories_list_right_child_;
253 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
254 static_assert(std::is_same<ThresholdType, float>::value
255 || std::is_same<ThresholdType, double>::value,
256 "ThresholdType must be either float32 or float64");
257 static_assert(std::is_same<LeafOutputType, uint32_t>::value
258 || std::is_same<LeafOutputType, float>::value
259 || std::is_same<LeafOutputType, double>::value,
260 "LeafOutputType must be one of uint32_t, float32 or float64");
261 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
262 || std::is_same<LeafOutputType, uint32_t>::value,
263 "Unsupported combination of ThresholdType and LeafOutputType");
264 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
265 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
266 "Node size incorrect");
271 Tree& operator=(
const Tree&) =
delete;
273 Tree& operator=(
Tree&&) noexcept = default;
275 inline
Tree<ThresholdType, LeafOutputType> Clone() const;
277 inline const
char* GetFormatStringForNode();
279 inline
void SerializeToFile(FILE* dest_fp);
280 inline
void InitFromPyBuffer(std::vector<
PyBufferFrame>::iterator begin,
282 inline
void DeserializeFromFile(FILE* src_fp);
293 inline
int AllocNode();
296 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
298 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
299 CompositeArrayHandler composite_array_handler);
300 template <typename ScalarHandler, typename ArrayHandler>
302 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
313 inline
void AddChilds(
int nid);
319 inline std::vector<
unsigned> GetCategoricalFeatures() const;
326 inline
int LeftChild(
int nid)
const {
327 return nodes_.at(nid).cleft_;
334 return nodes_.at(nid).cright_;
341 return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
348 return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
355 return (nodes_.at(nid).sindex_ >> 31U) != 0;
362 return nodes_.at(nid).cleft_ == -1;
369 return (nodes_.at(nid).info_).leaf_value;
375 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
376 const std::size_t offset_begin = leaf_vector_offset_.at(nid);
377 const std::size_t offset_end = leaf_vector_offset_.at(nid + 1);
378 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
380 return std::vector<LeafOutputType>();
382 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
383 &leaf_vector_[offset_end]);
392 return leaf_vector_offset_.at(nid) != leaf_vector_offset_.at(nid + 1);
399 return (nodes_.at(nid).info_).threshold;
406 return nodes_.at(nid).cmp_;
417 const std::size_t offset_begin = matching_categories_offset_.at(nid);
418 const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
419 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
422 return std::vector<uint32_t>();
424 return std::vector<uint32_t>(&matching_categories_[offset_begin],
425 &matching_categories_[offset_end]);
435 return matching_categories_offset_.at(nid) != matching_categories_offset_.at(nid + 1);
442 return nodes_.at(nid).split_type_;
449 return nodes_.at(nid).data_count_present_;
456 return nodes_.at(nid).data_count_;
464 return nodes_.at(nid).sum_hess_present_;
471 return nodes_.at(nid).sum_hess_;
478 return nodes_.at(nid).gain_present_;
484 inline double Gain(
int nid)
const {
485 return nodes_.at(nid).gain_;
493 return nodes_.at(nid).categories_list_right_child_;
506 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
520 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
521 const std::vector<uint32_t>& categories_list,
522 bool categories_list_right_child);
528 inline void SetLeaf(
int nid, LeafOutputType value);
534 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
541 Node& node = nodes_.at(nid);
542 node.sum_hess_ = sum_hess;
543 node.sum_hess_present_ =
true;
551 Node& node = nodes_.at(nid);
552 node.data_count_ = data_count;
553 node.data_count_present_ =
true;
561 Node& node = nodes_.at(nid);
563 node.gain_present_ =
true;
566 void ReferenceSerialize(dmlc::Stream* fo)
const;
591 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
609 ModelParam() : sigmoid_alpha(1.0f), global_bias(0.0f) {
610 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
611 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
619 template<
typename Container>
620 inline std::vector<std::pair<std::string, std::string>>
621 InitAllowUnknown(
const Container &kwargs);
622 inline std::map<std::string, std::string> __DICT__()
const;
625 static_assert(std::is_standard_layout<ModelParam>::value,
626 "ModelParam must be in the standard layout");
628 inline void InitParamAndCheck(
ModelParam* param,
629 const std::vector<std::pair<std::string, std::string>>& cfg);
/*! \brief default constructor; records the Treelite version macros
 *         (TREELITE_VER_MAJOR / TREELITE_VER_MINOR / TREELITE_VER_PATCH)
 *         in major_ver_ / minor_ver_ / patch_ver_ */
635 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
636 patch_ver_(TREELITE_VER_PATCH) {}
/*! \brief virtual destructor: Model declares pure virtual members and is
 *         handed out as std::unique_ptr<Model>, so deletion through the base
 *         pointer must reach the derived destructor */
637 virtual ~
Model() =
default;
643 template <
typename ThresholdType,
typename LeafOutputType>
644 inline static std::unique_ptr<Model> Create();
645 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
646 inline TypeInfo GetThresholdType()
const {
647 return threshold_type_;
649 inline TypeInfo GetLeafOutputType()
const {
650 return leaf_output_type_;
652 template <
typename Func>
653 inline auto Dispatch(Func func);
654 template <
typename Func>
655 inline auto Dispatch(Func func)
const;
657 virtual std::size_t GetNumTree()
const = 0;
658 virtual void SetTreeLimit(std::size_t limit) = 0;
659 virtual void ReferenceSerialize(dmlc::Stream* fo)
const = 0;
662 std::vector<PyBufferFrame> GetPyBuffer();
663 static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
666 void SerializeToFile(FILE* dest_fp);
667 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
684 int major_ver_, minor_ver_, patch_ver_;
688 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
689 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
690 virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
691 std::vector<PyBufferFrame>::iterator end) = 0;
692 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
693 template <
typename HeaderPrimitiveFieldHandlerFunc>
694 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
695 template <
typename HeaderPrimitiveFieldHandlerFunc>
696 inline static void DeserializeTemplate(
697 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
701 template <
typename ThresholdType,
typename LeafOutputType>
705 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
715 void ReferenceSerialize(
dmlc::Stream* fo) const override;
716 inline std::
size_t GetNumTree()
const override {
719 void SetTreeLimit(std::size_t limit)
override {
720 return trees.resize(limit);
723 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
724 inline void SerializeToFileImpl(FILE* dest_fp)
override;
725 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
726 std::vector<PyBufferFrame>::iterator end)
override;
727 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
730 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
731 typename TreeHandlerFunc>
732 inline void SerializeTemplate(
733 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
734 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
735 TreeHandlerFunc tree_handler);
736 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
737 inline void DeserializeTemplate(
739 HeaderFieldHandlerFunc header_field_handler,
740 TreeHandlerFunc tree_handler);
747 #endif // TREELITE_TREE_H_ ModelParam param
extra parameters
SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
SplitFeatureType
feature split type
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
TaskType
Enum type representing the task type.
bool average_tree_output
whether to average tree outputs
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool data_count_present_
whether data_count_ field is present
int32_t cleft_
indices of the left and right children (cleft_ == -1 marks a leaf node)
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally positively correlated with the data count. XGBoost models natively store this statistic.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
Group of parameters that are dependent on the choice of the task type.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
TaskParameter task_param
Group of parameters that are specific to the particular task type.
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
Model()
default constructor; records the Treelite version (major, minor, patch) in the model
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
unsigned int num_class
The number of classes in the target label.
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
bool HasMatchingCategories(int nid) const
tests whether the node has a non-empty list for matching categories. See MatchingCategories() for the...
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
OutputType output_type
The type of output from each leaf node.
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Info info_
storage for leaf value or decision threshold
Operator
comparison operators