7 #ifndef TREELITE_TREE_H_ 8 #define TREELITE_TREE_H_ 11 #include <treelite/version.h> 21 #include <type_traits> 28 #define __TREELITE_STR(x) #x 29 #define _TREELITE_STR(x) __TREELITE_STR(x) 31 #define TREELITE_MAX_PRED_TRANSFORM_LENGTH 256 35 template <
typename ThresholdType,
typename LeafOutputType>
47 static_assert(std::is_pod<PyBufferFrame>::value,
"PyBufferFrame must be a POD type");
60 inline void UseForeignBuffer(
void* prealloc_buf, std::size_t size);
62 inline const T* Data()
const;
64 inline const T* End()
const;
66 inline const T& Back()
const;
67 inline std::size_t Size()
const;
68 inline bool Empty()
const;
69 inline void Reserve(std::size_t newsize);
70 inline void Resize(std::size_t newsize);
71 inline void Resize(std::size_t newsize, T t);
73 inline void PushBack(T t);
74 inline void Extend(
const std::vector<T>& other);
76 inline T& operator[](std::size_t idx);
77 inline const T& operator[](std::size_t idx)
const;
79 inline T& at(std::size_t idx);
80 inline const T& at(std::size_t idx)
const;
82 inline T& at(
int idx);
83 inline const T& at(
int idx)
const;
84 static_assert(std::is_pod<T>::value,
"T must be POD");
89 std::size_t capacity_;
126 kMultiClfGrovePerClass = 1,
142 kMultiClfProbDistLeaf = 2,
159 kMultiClfCategLeaf = 3
162 inline std::string TaskTypeToString(
TaskType type) {
164 case TaskType::kBinaryClfRegr:
return "BinaryClfRegr";
165 case TaskType::kMultiClfGrovePerClass:
return "MultiClfGrovePerClass";
166 case TaskType::kMultiClfProbDistLeaf:
return "MultiClfProbDistLeaf";
167 case TaskType::kMultiClfCategLeaf:
return "MultiClfCategLeaf";
174 enum class OutputType : uint8_t { kFloat = 0, kInt = 1 };
202 inline std::string OutputTypeToString(TaskParam::OutputType type) {
204 case TaskParam::OutputType::kFloat:
return "float";
205 case TaskParam::OutputType::kInt:
return "int";
210 static_assert(std::is_pod<TaskParam>::value,
"TaskParameter must be POD type");
213 template <
typename ThresholdType,
typename LeafOutputType>
223 LeafOutputType leaf_value;
224 ThresholdType threshold;
267 bool categories_list_right_child_;
270 static_assert(std::is_pod<Node>::value,
"Node must be a POD type");
271 static_assert(std::is_same<ThresholdType, float>::value
272 || std::is_same<ThresholdType, double>::value,
273 "ThresholdType must be either float32 or float64");
274 static_assert(std::is_same<LeafOutputType, uint32_t>::value
275 || std::is_same<LeafOutputType, float>::value
276 || std::is_same<LeafOutputType, double>::value,
277 "LeafOutputType must be one of uint32_t, float32 or float64");
278 static_assert(std::is_same<ThresholdType, LeafOutputType>::value
279 || std::is_same<LeafOutputType, uint32_t>::value,
280 "Unsupported combination of ThresholdType and LeafOutputType");
281 static_assert((std::is_same<ThresholdType, float>::value &&
sizeof(
Node) == 48)
282 || (std::is_same<ThresholdType, double>::value &&
sizeof(
Node) == 56),
283 "Node size incorrect");
288 Tree& operator=(
const Tree&) =
delete;
290 Tree& operator=(
Tree&&) noexcept = default;
292 inline
Tree<ThresholdType, LeafOutputType> Clone() const;
294 inline const
char* GetFormatStringForNode();
296 inline
void SerializeToFile(FILE* dest_fp);
297 inline
void InitFromPyBuffer(std::vector<
PyBufferFrame>::iterator begin,
299 inline
void DeserializeFromFile(FILE* src_fp);
313 template <typename WriterType, typename X, typename Y>
314 friend
void DumpModelAsJSON(WriterType& writer, const
ModelImpl<X, Y>& model);
315 template <typename WriterType, typename X, typename Y>
316 friend
void DumpTreeAsJSON(WriterType& writer, const
Tree<X, Y>& tree);
319 inline
int AllocNode();
322 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
324 SerializeTemplate(ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
325 CompositeArrayHandler composite_array_handler);
326 template <typename ScalarHandler, typename ArrayHandler>
328 DeserializeTemplate(ScalarHandler scalar_handler, ArrayHandler array_handler);
339 inline
void AddChilds(
int nid);
345 inline std::vector<
unsigned> GetCategoricalFeatures() const;
352 inline
int LeftChild(
int nid)
const {
353 return nodes_.at(nid).cleft_;
360 return nodes_.at(nid).cright_;
367 return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
374 return (nodes_.at(nid).sindex_ & ((1U << 31U) - 1U));
381 return (nodes_.at(nid).sindex_ >> 31U) != 0;
388 return nodes_.at(nid).cleft_ == -1;
395 return (nodes_.at(nid).info_).leaf_value;
401 inline std::vector<LeafOutputType>
LeafVector(
int nid)
const {
402 const std::size_t offset_begin = leaf_vector_begin_.at(nid);
403 const std::size_t offset_end = leaf_vector_end_.at(nid);
404 if (offset_begin >= leaf_vector_.Size() || offset_end > leaf_vector_.Size()) {
406 return std::vector<LeafOutputType>();
408 return std::vector<LeafOutputType>(&leaf_vector_[offset_begin],
409 &leaf_vector_[offset_end]);
418 return leaf_vector_begin_.at(nid) != leaf_vector_end_.at(nid);
425 return (nodes_.at(nid).info_).threshold;
432 return nodes_.at(nid).cmp_;
443 const std::size_t offset_begin = matching_categories_offset_.at(nid);
444 const std::size_t offset_end = matching_categories_offset_.at(nid + 1);
445 if (offset_begin >= matching_categories_.Size() || offset_end > matching_categories_.Size()) {
448 return std::vector<uint32_t>();
450 return std::vector<uint32_t>(&matching_categories_[offset_begin],
451 &matching_categories_[offset_end]);
460 if (matching_categories_.Empty()) {
464 matching_categories_.End())};
471 return nodes_.at(nid).split_type_;
478 return nodes_.at(nid).data_count_present_;
485 return nodes_.at(nid).data_count_;
493 return nodes_.at(nid).sum_hess_present_;
500 return nodes_.at(nid).sum_hess_;
507 return nodes_.at(nid).gain_present_;
513 inline double Gain(
int nid)
const {
514 return nodes_.at(nid).gain_;
522 return nodes_.at(nid).categories_list_right_child_;
535 inline void SetNumericalSplit(
int nid,
unsigned split_index, ThresholdType threshold,
549 inline void SetCategoricalSplit(
int nid,
unsigned split_index,
bool default_left,
550 const std::vector<uint32_t>& categories_list,
551 bool categories_list_right_child);
557 inline void SetLeaf(
int nid, LeafOutputType value);
563 inline void SetLeafVector(
int nid,
const std::vector<LeafOutputType>& leaf_vector);
570 Node& node = nodes_.at(nid);
571 node.sum_hess_ = sum_hess;
572 node.sum_hess_present_ =
true;
580 Node& node = nodes_.at(nid);
581 node.data_count_ = data_count;
582 node.data_count_present_ =
true;
590 Node& node = nodes_.at(nid);
592 node.gain_present_ =
true;
618 char pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH] = {0};
644 ModelParam() : sigmoid_alpha(1.0f), ratio_c(1.0f), global_bias(0.0f) {
645 std::memset(pred_transform, 0, TREELITE_MAX_PRED_TRANSFORM_LENGTH *
sizeof(
char));
646 std::strncpy(pred_transform,
"identity",
sizeof(pred_transform));
654 template<
typename Container>
655 inline std::vector<std::pair<std::string, std::string>>
656 InitAllowUnknown(
const Container &kwargs);
657 inline std::map<std::string, std::string> __DICT__()
const;
660 static_assert(std::is_standard_layout<ModelParam>::value,
661 "ModelParam must be in the standard layout");
663 inline void InitParamAndCheck(
ModelParam* param,
664 const std::vector<std::pair<std::string, std::string>>& cfg);
670 Model() : major_ver_(TREELITE_VER_MAJOR), minor_ver_(TREELITE_VER_MINOR),
671 patch_ver_(TREELITE_VER_PATCH) {}
672 virtual ~
Model() =
default;
678 template <
typename ThresholdType,
typename LeafOutputType>
679 inline static std::unique_ptr<Model> Create();
680 inline static std::unique_ptr<Model> Create(
TypeInfo threshold_type,
TypeInfo leaf_output_type);
681 inline TypeInfo GetThresholdType()
const {
682 return threshold_type_;
684 inline TypeInfo GetLeafOutputType()
const {
685 return leaf_output_type_;
687 template <
typename Func>
688 inline auto Dispatch(Func func);
689 template <
typename Func>
690 inline auto Dispatch(Func func)
const;
692 virtual std::size_t GetNumTree()
const = 0;
693 virtual void SetTreeLimit(std::size_t limit) = 0;
694 virtual void DumpAsJSON(std::ostream& fo,
bool pretty_print)
const = 0;
696 inline std::string DumpAsJSON(
bool pretty_print)
const {
697 std::ostringstream oss;
698 DumpAsJSON(oss, pretty_print);
703 std::vector<PyBufferFrame> GetPyBuffer();
704 static std::unique_ptr<Model> CreateFromPyBuffer(std::vector<PyBufferFrame> frames);
707 void SerializeToFile(FILE* dest_fp);
708 static std::unique_ptr<Model> DeserializeFromFile(FILE* src_fp);
725 int major_ver_, minor_ver_, patch_ver_;
729 virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
730 virtual void SerializeToFileImpl(FILE* dest_fp) = 0;
731 virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
732 std::vector<PyBufferFrame>::iterator end) = 0;
733 virtual void DeserializeFromFileImpl(FILE* src_fp) = 0;
734 template <
typename HeaderPrimitiveFieldHandlerFunc>
735 inline void SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler);
736 template <
typename HeaderPrimitiveFieldHandlerFunc>
737 inline static void DeserializeTemplate(
738 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
742 template <
typename ThresholdType,
typename LeafOutputType>
746 std::vector<Tree<ThresholdType, LeafOutputType>>
trees;
756 void DumpAsJSON(std::ostream& fo,
bool pretty_print) const override;
757 inline std::
size_t GetNumTree()
const override {
760 void SetTreeLimit(std::size_t limit)
override {
761 return trees.resize(limit);
764 inline void GetPyBuffer(std::vector<PyBufferFrame>* dest)
override;
765 inline void SerializeToFileImpl(FILE* dest_fp)
override;
766 inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
767 std::vector<PyBufferFrame>::iterator end)
override;
768 inline void DeserializeFromFileImpl(FILE* src_fp)
override;
771 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
772 typename TreeHandlerFunc>
773 inline void SerializeTemplate(
774 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
775 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
776 TreeHandlerFunc tree_handler);
777 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
778 inline void DeserializeTemplate(
780 HeaderFieldHandlerFunc header_field_handler,
781 TreeHandlerFunc tree_handler);
788 #endif // TREELITE_TREE_H_ ModelParam param
extra parameters
SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
Implementation for tree.h.
bool gain_present_
whether gain_ field is present
SplitFeatureType
feature split type
Backport of std::optional from C++17.
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this statistic.
bool HasDataCount(int nid) const
test whether this node has data count
bool HasGain(int nid) const
test whether this node has gain value
bool grove_per_class
Whether we designate a subset of the trees to compute the prediction for each class.
Group of parameters that are dependent on the choice of the task type.
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< LeafOutputType > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
TaskType
Enum type representing the task type.
bool average_tree_output
whether to average tree outputs
float sigmoid_alpha
scaling parameter for sigmoid function sigmoid(x) = 1 / (1 + exp(-alpha * x))
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
optional< uint32_t > MaxCategory() const
Get the largest category value used in all categorical splits in this tree. If there are no categoric...
bool data_count_present_
whether data_count_ field is present
unsigned int leaf_vector_size
Dimension of the output from each leaf node.
int32_t cleft_
pointer to left and right children
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistic.
unsigned int num_class
The number of classes in the target label.
float global_bias
global bias of the model
double gain_
change in loss that is attributed to a particular split
TaskType task_type
Task type.
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
TypeInfo
Types used by thresholds and leaf outputs.
store either leaf value or decision threshold
std::vector< Tree< ThresholdType, LeafOutputType > > trees
member trees
double SumHess(int nid) const
get hessian sum
void SetGain(int nid, double gain)
set the gain value of the node
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
bool CategoriesListRightChild(int nid) const
test whether the list given by MatchingCategories(nid) is associated with the right child node or the left child node
SplitFeatureType SplitType(int nid) const
get feature split type
LeafOutputType LeafValue(int nid) const
get leaf value of the leaf node
float ratio_c
scaling parameter for exponential standard ratio transformation expstdratio(x) = exp2(-x / c) ...
Model()
disable copy; use default move
TaskParam task_param
Group of parameters that are specific to the particular task type.
defines configuration macros of Treelite
std::vector< uint32_t > MatchingCategories(int nid) const
Get list of all categories belonging to the left/right child node. See the categories_list_right_chil...
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
OutputType output_type
The type of output from each leaf node.
bool sum_hess_present_
whether sum_hess_ field is present
thin wrapper for tree ensemble model
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
ThresholdType Threshold(int nid) const
get threshold of the node
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
Info info_
storage for leaf value or decision threshold
Operator
comparison operators