7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 20 #include <type_traits> 27 #define __TREELITE_STR(x) #x 28 #define _TREELITE_STR(x) __TREELITE_STR(x) 30 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 34 template <
typename ThresholdType,
typename LeafOutputType>
46 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
59 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
61 inline const T* Data()
const;
63 inline const T* End()
const;
65 inline const T& Back()
const;
66 inline std::size_t Size()
const;
67 inline bool Empty()
const;
68 inline void Reserve(std::size_t newsize);
69 inline void Resize(std::size_t newsize);
70 inline void Resize(std::size_t newsize, T t);
72 inline void PushBack(T t);
73 inline void Extend(
const std::vector<T>& other);
75 inline T& operator[](std::size_t idx);
76 inline const T& operator[](std::size_t idx)
const;
78 inline T& at(std::size_t idx);
79 inline const T& at(std::size_t idx)
const;
81 inline T& at(
int idx);
82 inline const T& at(
int idx)
const;
83 static_assert(std::is_pod<T>::value,
"T must be POD");
88 std::size_t capacity_;
125 kMultiClfGrovePerClass = 1,
141 kMultiClfProbDistLeaf = 2,
158 kMultiClfCategLeaf = 3
161 inline std::string TaskTypeToString(
TaskType type) {
163 case TaskType::kBinaryClfRegr:
return "BinaryClfRegr";
164 case TaskType::kMultiClfGrovePerClass:
return "MultiClfGrovePerClass";
165 case TaskType::kMultiClfProbDistLeaf:
return "MultiClfProbDistLeaf";
166 case TaskType::kMultiClfCategLeaf:
return "MultiClfCategLeaf";
173 enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
201 inline std::string OutputTypeToString(TaskParam::OutputType type) {
203 case TaskParam::OutputType::kFloat:
return "float";
204 case TaskParam::OutputType::kInt:
return "int";
209 static_assert(std::is_pod<TaskParam>::value,
"TaskParameter must be POD type");
212 template <
typename ThresholdType,
typename LeafOutputType>
222 LeafOutputType leaf_value;
223 ThresholdType threshold;
266 bool categories_list_right_child_;
269 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
270 static_assert(std::is_same<ThresholdType, float>::value
271 || std::is_same<ThresholdType, double>::value,
272 "ThresholdType must be either float32 or float64");
273 static_assert(std::is_same<LeafOutputType, uint32_t>::value
274 || std::is_same<LeafOutputType, float>::value
275 || std::is_same<LeafOutputType, double>::value,
276 "LeafOutputType must be one of uint32_t, float32 or float64");
277 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
278 || std::is_same<LeafOutputType, uint32_t>::value,
279 "Unsupported combination of ThresholdType and LeafOutputType");
280 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
281 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
282 "Node size incorrect");
287 Tree& operator=(
const Tree&) =
delete;
289 Tree& operator=(
Tree&&) noexcept = default;
291 inline
Tree<ThresholdType, LeafOutputType> Clone() const;
293 inline const
char* GetFormatStringForNode();
295 inline
void SerializeToFile(FILE* dest_fp);
296 inline
void InitFromPyBuffer(std::vector<
PyBufferFrame>::iterator begin,
298 inline
void DeserializeFromFile(FILE* src_fp);
312 template <typename WriterType, typename X, typename Y>
313 friend
void DumpModelAsJSON(WriterType& writer, const
ModelImpl<X, Y>& model);
314 template <typename WriterType, typename X, typename Y>
315 friend
void DumpTreeAsJSON(WriterType& writer, const
Tree<X, Y>& tree);
318 inline
int AllocNode();
321 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
323 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
324 CompositeArrayHandler composite_array_handler);
325 template <typename ScalarHandler, typename ArrayHandler>
327 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
338 inline
void AddChilds(
int nid);
344 inline std::vector<
unsigned> GetCategoricalFeatures() const;
351 inline
int LeftChild(
int nid)
const {
352 return nodes_.at(nid).cleft_;
359 return nodes_.at(nid).cright_;
366 return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
373 return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
380 return (nodes_.at(nid).sindex_ >> 31U) != 0;
387 return nodes_.at(nid).cleft_ == -1;
394 return (nodes_.at(nid).info_).leaf_value;
400 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
401 const std::size_t offset_begin = leaf_vector_begin_.at(nid);
402 const std::size_t offset_end = leaf_vector_end_.at(nid);
403 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
405 return std::vector<LeafOutputType>();
407 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
408 &leaf_vector_[offset_end]);
417 return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid);
424 return (nodes_.at(nid).info_).threshold;
431 return nodes_.at(nid).cmp_;
442 const std::size_t offset_begin = matching_categories_offset_.at(nid);
443 const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
444 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
447 return std::vector<uint32_t>();
449 return std::vector<uint32_t>(&matching_categories_[offset_begin],
450 &matching_categories_[offset_end]);
459 return nodes_.at(nid).split_type_;
466 return nodes_.at(nid).data_count_present_;
473 return nodes_.at(nid).data_count_;
481 return nodes_.at(nid).sum_hess_present_;
488 return nodes_.at(nid).sum_hess_;
495 return nodes_.at(nid).gain_present_;
501 inline double Gain(
int nid)
const {
502 return nodes_.at(nid).gain_;
510 return nodes_.at(nid).categories_list_right_child_;
523 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
537 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
538 const std::vector<uint32_t>& categories_list,
539 bool categories_list_right_child);
545 inline void SetLeaf(
int nid, LeafOutputType value);
551 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
558 Node& node = nodes_.at(nid);
559 node.sum_hess_ = sum_hess;
560 node.sum_hess_present_ =
true;
568 Node& node = nodes_.at(nid);
569 node.data_count_ = data_count;
570 node.data_count_present_ =
true;
578 Node& node = nodes_.at(nid);
580 node.gain_present_ =
true;
606 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
632 ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
633 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
634 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
642 template<
typename Container>
643 inline std::vector<std::pair<std::string, std::string>>
644 InitAllowUnknown(
const Container &kwargs);
645 inline std::map<std::string, std::string> __DICT__()
const;
648 static_assert(std::is_standard_layout<ModelParam>::value,
649 "ModelParam must be in the standard layout");
651 inline void InitParamAndCheck(
ModelParam* param,
652 const std::vector<std::pair<std::string, std::string>>& cfg);
658 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
659 patch_ver_(TREELITE_VER_PATCH) {}
660 virtual ~
Model() =
default;
666 template <
typename ThresholdType,
typename LeafOutputType>
667 inline static std::unique_ptr<Model> Create();
668 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
669 inline TypeInfo GetThresholdType()
const {
670 return threshold_type_;
672 inline TypeInfo GetLeafOutputType()
const {
673 return leaf_output_type_;
675 template <
typename Func>
676 inline auto Dispatch(Func func);
677 template <
typename Func>
678 inline auto Dispatch(Func func)
const;
680 virtual std::size_t GetNumTree()
const = 0;
681 virtual void SetTreeLimit(std::size_t limit) = 0;
682 virtual void DumpAsJSON(std::ostream& fo,
bool pretty_print)
const = 0;
684 inline std::string DumpAsJSON(
bool pretty_print)
const {
685 std::ostringstream oss;
686 DumpAsJSON(oss, pretty_print);
691 std::vector<PyBufferFrame> GetPyBuffer();
692 static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
695 void SerializeToFile(FILE* dest_fp);
696 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
713 int major_ver_, minor_ver_, patch_ver_;
717 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
718 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
719 virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
720 std::vector<PyBufferFrame>::iterator end) = 0;
721 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
722 template <
typename HeaderPrimitiveFieldHandlerFunc>
723 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
724 template <
typename HeaderPrimitiveFieldHandlerFunc>
725 inline static void DeserializeTemplate(
726 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
730 template <
typename ThresholdType,
typename LeafOutputType>
734 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
744 void DumpAsJSON(std::ostream& fo,
bool pretty_print) const override;
745 inline std::
size_t GetNumTree()
const override {
748 void SetTreeLimit(std::size_t limit)
override {
749 return trees.resize(limit);
752 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
753 inline void SerializeToFileImpl(FILE* dest_fp)
override;
754 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
755 std::vector<PyBufferFrame>::iterator end)
override;
756 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
759 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
760 typename TreeHandlerFunc>
761 inline void SerializeTemplate(
762 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
763 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
764 TreeHandlerFunc tree_handler);
765 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
766 inline void DeserializeTemplate(
768 HeaderFieldHandlerFunc header_field_handler,
769 TreeHandlerFunc tree_handler);
776 #endif // TREELITE_TREE_H_ ModelParam param
extra parameters
SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
SplitFeatureType
feature split type
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
TaskType
Enum type representing the task type.
bool average_tree_output
whether to average tree outputs
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
bool data_count_present_
whether data_count_ field is present
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
int32_t cleft_
pointer to left and right children
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the...
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Model()
disable copy; use default move
TaskParam task_param
Group of parameters that are specific to the particular task type.
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Info info_
storage for leaf value or decision threshold
Operator
comparison operators