7 #ifndef TREELITE_TREE_IMPL_H_ 8 #define TREELITE_TREE_IMPL_H_ 16 #include <unordered_map> 25 inline std::string GetString(T x) {
26 return std::to_string(x);
// GetString specialization for float. The value is streamed into a
// std::ostringstream whose precision is set to
// std::numeric_limits<float>::max_digits10, i.e. enough decimal digits to
// distinguish every float value, rather than the default stream precision.
30 inline std::string GetString<float>(
float x) {
31 std::ostringstream oss;
32 oss << std::setprecision(std::numeric_limits<float>::max_digits10) << x;
// GetString specialization for double. Same approach as the float variant:
// the value is streamed with precision
// std::numeric_limits<double>::max_digits10 so that no significant digits of
// the double are dropped in the textual form.
37 inline std::string GetString<double>(
double x) {
38 std::ostringstream oss;
39 oss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
// Default constructor: starts with no storage (null buffer_, zero size_ and
// capacity_) and with owned_buffer_ set, so any buffer acquired later is
// treated as owned by this object.
48 ContiguousArray<T>::ContiguousArray()
49 : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
// Destructor. Cleanup is guarded so that it applies only to a non-null buffer
// that this object owns; a foreign (non-owned) buffer is never touched here.
52 ContiguousArray<T>::~ContiguousArray() {
53 if (buffer_ && owned_buffer_) {
// Move constructor: takes over other's buffer pointer, size, capacity and
// ownership flag without copying any elements.
59 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
60 : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
61 owned_buffer_(other.owned_buffer_) {
// Detach the source so that it no longer refers to the transferred buffer.
62 other.buffer_ =
nullptr;
63 other.size_ = other.capacity_ = 0;
// Move assignment: adopts other's buffer pointer, capacity and ownership flag,
// then leaves other with a null buffer and zero size/capacity.
// NOTE(review): no self-assignment guard is visible here -- confirm that
// `a = std::move(a)` is either handled or never performed.
68 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
// Only a buffer this object owns is eligible for disposal before the takeover.
69 if (buffer_ && owned_buffer_) {
72 buffer_ = other.buffer_;
74 capacity_ = other.capacity_;
75 owned_buffer_ = other.owned_buffer_;
// Detach the source so that it no longer refers to the transferred buffer.
76 other.buffer_ =
nullptr;
77 other.size_ = other.capacity_ = 0;
// Produce a deep copy that owns its own storage. A buffer large enough for
// capacity_ elements is obtained with std::malloc, the first size_ elements
// are copied over bytewise, and the copy is flagged as owning its buffer.
// Throws std::runtime_error ("Could not allocate memory for the clone") on the
// allocation-failure path.
// NOTE(review): std::malloc(0) may legitimately return nullptr; confirm that
// cloning an array whose capacity_ is 0 is not reported as an allocation
// failure.
82 inline ContiguousArray<T>
83 ContiguousArray<T>::Clone()
const {
84 ContiguousArray clone;
// Allocate for the full capacity (not just size_) so the clone keeps the same
// growth headroom as the source.
85 clone.buffer_ =
static_cast<T*
>(std::malloc(
sizeof(T) * capacity_));
87 throw std::runtime_error(
"Could not allocate memory for the clone");
// Raw byte copy of the live elements; assumes T is trivially copyable -- TODO confirm.
89 std::memcpy(clone.buffer_, buffer_,
sizeof(T) * size_);
91 clone.capacity_ = capacity_;
// The clone always owns its buffer, even when the source wraps a foreign one.
92 clone.owned_buffer_ =
true;
// Point this array at caller-provided memory instead of memory it owns.
// prealloc_buf is cast to T* and owned_buffer_ is cleared; the operations in
// this file that check owned_buffer_ (Reserve, Resize, Clear, PushBack,
// Extend) will then throw.
// NOTE(review): `size` is presumably the number of T elements available in
// prealloc_buf -- confirm against callers.
98 ContiguousArray<T>::UseForeignBuffer(
void* prealloc_buf,
size_t size) {
// A previously owned buffer is handled first, before the pointer is replaced.
99 if (buffer_ && owned_buffer_) {
102 buffer_ =
static_cast<T*
>(prealloc_buf);
105 owned_buffer_ =
false;
108 template <
typename T>
110 ContiguousArray<T>::Data() {
114 template <
typename T>
116 ContiguousArray<T>::Data()
const {
// Returns the address one past the last stored element, i.e. the address of
// buffer_[Size()]; suitable as the end of a [Data(), End()) style range.
120 template <
typename T>
122 ContiguousArray<T>::End() {
123 return &buffer_[Size()];
// Const overload of End(): returns the address of buffer_[Size()], one past
// the last stored element.
126 template <
typename T>
128 ContiguousArray<T>::End()
const {
129 return &buffer_[Size()];
// Returns the last stored element, buffer_[Size() - 1].
// No emptiness check is performed here; callers must ensure the array is
// non-empty before calling.
132 template <
typename T>
134 ContiguousArray<T>::Back() {
135 return buffer_[Size() - 1];
// Const overload of Back(): returns buffer_[Size() - 1].
// No emptiness check is performed here; callers must ensure the array is
// non-empty before calling.
138 template <
typename T>
140 ContiguousArray<T>::Back()
const {
141 return buffer_[Size() - 1];
144 template <
typename T>
146 ContiguousArray<T>::Size()
const {
// Change the allocated storage so it can hold `newsize` elements, using
// std::realloc on the current buffer.
// Throws std::runtime_error when the array wraps a foreign buffer (such memory
// is not owned by this object and must not be reallocated by it) and, with
// "Could not expand buffer", on the reallocation-failure path.
150 template <
typename T>
152 ContiguousArray<T>::Reserve(
size_t newsize) {
153 if (!owned_buffer_) {
154 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
// The result goes into a temporary first so buffer_ is not overwritten before
// the outcome of the reallocation is known.
156 T* newbuf =
static_cast<T*
>(std::realloc(static_cast<void*>(buffer_),
sizeof(T) * newsize));
158 throw std::runtime_error(
"Could not expand buffer");
// Resize the array to `newsize` elements, enlarging the allocation through
// Reserve() when the request exceeds the current capacity.
// Throws std::runtime_error when the array wraps a foreign buffer.
164 template <
typename T>
166 ContiguousArray<T>::Resize(
size_t newsize) {
167 if (!owned_buffer_) {
168 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
170 if (newsize > capacity_) {
// Start from the current capacity (a zero capacity is special-cased) and keep
// increasing it while it is <= newsize, so the capacity finally reserved is
// strictly greater than newsize.
171 size_t newcapacity = capacity_;
172 if (newcapacity == 0) {
175 while (newcapacity <= newsize) {
178 Reserve(newcapacity);
// Resize overload that also takes a value `t` for the newly exposed slots.
// The previous size is recorded first, then indices [oldsize, newsize) are
// visited -- presumably assigning `t` to each; confirm against the loop body.
// Throws std::runtime_error when the array wraps a foreign buffer.
183 template <
typename T>
185 ContiguousArray<T>::Resize(
size_t newsize, T t) {
186 if (!owned_buffer_) {
187 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
// Remember the size before resizing so only the new tail is visited below.
189 size_t oldsize = Size();
191 for (
size_t i = oldsize; i < newsize; ++i) {
// Clear the array. Clearing is refused for arrays that wrap a foreign buffer:
// in that case std::runtime_error is thrown and the caller is expected to
// Clone() first.
196 template <
typename T>
198 ContiguousArray<T>::Clear() {
199 if (!owned_buffer_) {
200 throw std::runtime_error(
"Cannot clear when using a foreign buffer; clone first");
205 template <
typename T>
207 ContiguousArray<T>::PushBack(T t) {
208 if (!owned_buffer_) {
209 throw std::runtime_error(
"Cannot add element when using a foreign buffer; clone first");
211 if (size_ == capacity_) {
212 Reserve(capacity_ * 2);
214 buffer_[size_++] = t;
// Append all elements of `other` to the end of the array.
// Storage is enlarged through Reserve() when size_ + other.size() exceeds the
// current capacity; the elements are then copied bytewise to the tail.
// Throws std::runtime_error when the array wraps a foreign buffer.
217 template <
typename T>
219 ContiguousArray<T>::Extend(
const std::vector<T>& other) {
220 if (!owned_buffer_) {
221 throw std::runtime_error(
"Cannot add elements when using a foreign buffer; clone first");
223 size_t newsize = size_ + other.size();
224 if (newsize > capacity_) {
// Same growth policy as Resize(): increase the capacity while it is <= newsize.
225 size_t newcapacity = capacity_;
226 if (newcapacity == 0) {
229 while (newcapacity <= newsize) {
232 Reserve(newcapacity);
// Raw byte copy starting at the current end; assumes T is trivially copyable -- TODO confirm.
234 std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()),
sizeof(T) * other.size());
238 template <
typename T>
240 ContiguousArray<T>::operator[](
size_t idx) {
244 template <
typename T>
246 ContiguousArray<T>::operator[](
size_t idx)
const {
// Populate the recognized parameter fields from a container of key/value
// string pairs (each element must expose .first and .second strings).
// Recognized keys: "pred_transform", "sigmoid_alpha", "global_bias".
// `unknowns` is declared to collect key/value pairs -- presumably the entries
// whose key is not recognized, which are then returned; confirm against the
// remaining branch.
250 template<
typename Container>
251 inline std::vector<std::pair<std::string, std::string> >
252 ModelParam::InitAllowUnknown(
const Container& kwargs) {
253 std::vector<std::pair<std::string, std::string>> unknowns;
254 for (
const auto& e : kwargs) {
255 if (e.first ==
"pred_transform") {
// Copy at most LENGTH - 1 characters and terminate explicitly: strncpy does
// not write a terminating NUL when the source fills the whole limit.
256 std::strncpy(this->pred_transform, e.second.c_str(),
257 TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
258 this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] =
'\0';
259 }
else if (e.first ==
"sigmoid_alpha") {
// Numeric parameters arrive as text and are parsed with dmlc::stof.
260 this->sigmoid_alpha = dmlc::stof(e.second,
nullptr);
261 }
else if (e.first ==
"global_bias") {
262 this->global_bias = dmlc::stof(e.second,
nullptr);
// Export the parameters as a string-to-string map with the keys
// "pred_transform", "sigmoid_alpha" and "global_bias". The two numeric fields
// are converted to text with GetString().
268 inline std::map<std::string, std::string>
269 ModelParam::__DICT__()
const {
270 std::map<std::string, std::string> ret;
271 ret.emplace(
"pred_transform", std::string(this->pred_transform));
272 ret.emplace(
"sigmoid_alpha", GetString(this->sigmoid_alpha));
273 ret.emplace(
"global_bias", GetString(this->global_bias));
// Describe a raw memory region as a PyBufferFrame.
// data: start of the region; format: format string describing one item;
// itemsize: size of one item in bytes; nitem: number of items.
// The frame only records the pointers it is given; nothing is copied here.
277 inline PyBufferFrame GetPyBufferFromArray(
void* data,
const char* format,
278 size_t itemsize,
size_t nitem) {
// PyBufferFrame stores a non-const char*, hence the const_cast on format.
279 return PyBufferFrame{data,
const_cast<char*
>(format), itemsize, nitem};
// Choose a format string for type T. The visible results are "=B"/"=b",
// "=H"/"=h", "=L"/"=l" and "=Q"/"=q", where the upper-case code of each pair
// is returned for unsigned T and the lower-case code otherwise.
// NOTE(review): the pair is presumably selected by sizeof(T) (1, 2, 4, 8) and
// floating-point types presumably get their own codes -- confirm.
// Throws std::runtime_error ("Could not infer format string") for a type that
// is neither integral nor floating point, and ("Unrecognized type") on the
// fall-through path.
283 template <
typename T>
284 inline const char* InferFormatString() {
287 return (std::is_unsigned<T>::value ?
"=B" :
"=b");
289 return (std::is_unsigned<T>::value ?
"=H" :
"=h");
291 if (std::is_integral<T>::value) {
292 return (std::is_unsigned<T>::value ?
"=L" :
"=l");
294 if (!std::is_floating_point<T>::value) {
295 throw std::runtime_error(
"Could not infer format string");
300 if (std::is_integral<T>::value) {
301 return (std::is_unsigned<T>::value ?
"=Q" :
"=q");
303 if (!std::is_floating_point<T>::value) {
304 throw std::runtime_error(
"Could not infer format string");
309 throw std::runtime_error(
"Unrecognized type");
// Describe the contents of a ContiguousArray<T> as a PyBufferFrame using a
// caller-supplied format string. The frame refers to vec's own storage
// (vec->Data()), with itemsize sizeof(T) and vec->Size() items.
314 template <
typename T>
315 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec,
const char* format) {
316 return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format,
sizeof(T), vec->Size());
// Describe the contents of a ContiguousArray<T> as a PyBufferFrame, deriving
// the format string from T via InferFormatString<T>().
// Restricted to arithmetic T at compile time; composite element types must use
// the overload that takes an explicit format string.
319 template <
typename T>
320 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
321 static_assert(std::is_arithmetic<T>::value,
322 "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
323 return GetPyBufferFromArray(vec, InferFormatString<T>());
// Describe a single value in memory as a PyBufferFrame: identical to the
// array form with the item count fixed to 1.
326 inline PyBufferFrame GetPyBufferFromScalar(
void* data,
const char* format,
size_t itemsize) {
327 return GetPyBufferFromArray(data, format, itemsize, 1);
// Describe the object pointed to by `scalar` as a one-item PyBufferFrame with
// a caller-supplied format string; the item size is sizeof(T).
330 template <
typename T>
331 inline PyBufferFrame GetPyBufferFromScalar(T* scalar,
const char* format) {
332 return GetPyBufferFromScalar(static_cast<void*>(scalar), format,
sizeof(T));
// Describe the object pointed to by `scalar` as a one-item PyBufferFrame,
// deriving the format string from T via InferFormatString<T>().
// Restricted to arithmetic T at compile time; composite types must use the
// overload that takes an explicit format string.
335 template <
typename T>
336 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
337 static_assert(std::is_arithmetic<T>::value,
338 "Use GetPyBufferFromScalar(scalar, format) for composite types; " 339 "specify format string manually");
340 return GetPyBufferFromScalar(scalar, InferFormatString<T>());
// Make `vec` refer to the memory described by `buffer` without copying it:
// the frame's pointer and item count are handed to UseForeignBuffer().
// Throws std::runtime_error ("Incorrect itemsize") when the frame's item size
// differs from sizeof(T).
343 template <
typename T>
344 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame buffer) {
345 if (
sizeof(T) != buffer.itemsize) {
346 throw std::runtime_error(
"Incorrect itemsize");
348 vec->UseForeignBuffer(buffer.buf, buffer.nitem);
// Load a single value of type T from the memory described by `buffer`.
// Throws std::runtime_error when the frame's item size differs from sizeof(T)
// ("Incorrect itemsize") or when the frame does not hold exactly one item
// ("nitem must be 1 for a scalar").
351 template <
typename T>
352 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
353 if (
sizeof(T) != buffer.itemsize) {
354 throw std::runtime_error(
"Incorrect itemsize");
356 if (buffer.nitem != 1) {
357 throw std::runtime_error(
"nitem must be 1 for a scalar");
// View the frame's memory as a T; presumably the pointee is then copied into
// *scalar -- confirm.
359 T* t =
static_cast<T*
>(buffer.buf);
// Number of buffer frames that describe one Tree: num_nodes, nodes_,
// leaf_vector_, leaf_vector_offset_, left_categories_ and
// left_categories_offset_.
363 constexpr
size_t kNumFramePerTree = 6;
// Expose the tree's state as a list of buffer frames that refer to the tree's
// own members: the num_nodes scalar, the node array, and the four auxiliary
// arrays. The order matches the order in which Tree::InitFromPyBuffer consumes
// the frames.
365 inline std::vector<PyBufferFrame>
366 Tree::GetPyBuffer() {
368 GetPyBufferFromScalar(&num_nodes),
// Node is a composite type, so its layout is given as an explicit format
// string instead of being inferred.
369 GetPyBufferFromArray(&nodes_,
"T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}"),
370 GetPyBufferFromArray(&leaf_vector_),
371 GetPyBufferFromArray(&leaf_vector_offset_),
372 GetPyBufferFromArray(&left_categories_),
373 GetPyBufferFromArray(&left_categories_offset_)
// Rebuild the tree from buffer frames, consuming them in the order
// num_nodes, nodes_, leaf_vector_, leaf_vector_offset_, left_categories_,
// left_categories_offset_. The arrays are attached to the frames' memory via
// InitArrayFromPyBuffer rather than copied.
// Throws std::runtime_error when the node array's length disagrees with
// num_nodes or when the number of consumed frames is not kNumFramePerTree.
// NOTE(review): frames are indexed before any check on frames.size(); confirm
// that callers always pass at least kNumFramePerTree frames.
378 Tree::InitFromPyBuffer(std::vector<PyBufferFrame> frames) {
380 InitScalarFromPyBuffer(&num_nodes, frames[frame_id++]);
381 InitArrayFromPyBuffer(&nodes_, frames[frame_id++]);
// Cross-check the scalar against the array that was just attached.
382 if (num_nodes != nodes_.Size()) {
383 throw std::runtime_error(
"Could not load the correct number of nodes");
385 InitArrayFromPyBuffer(&leaf_vector_, frames[frame_id++]);
386 InitArrayFromPyBuffer(&leaf_vector_offset_, frames[frame_id++]);
387 InitArrayFromPyBuffer(&left_categories_, frames[frame_id++]);
388 InitArrayFromPyBuffer(&left_categories_offset_, frames[frame_id++]);
389 if (frame_id != kNumFramePerTree) {
390 throw std::runtime_error(
"Wrong number of frames loaded");
// Expose the model as a flat list of buffer frames: model-level frames first,
// followed by the frames of every tree in `trees`, in order.
394 inline std::vector<PyBufferFrame>
395 Model::GetPyBuffer() {
397 std::vector<PyBufferFrame> frames{
// param is a composite type, so its format string is spelled out; it embeds
// TREELITE_MAX_PRED_TRANSFORM_LENGTH, presumably as the length of the
// pred_transform character field followed by the two float fields -- confirm.
401 GetPyBufferFromScalar(&
param,
"T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH)
"s=f=f}")
405 for (
auto& tree :
trees) {
406 auto tree_frames = tree.GetPyBuffer();
407 frames.insert(frames.end(), tree_frames.begin(), tree_frames.end());
// Rebuild the model from a flat list of buffer frames: model-level scalars are
// read first, then the remaining frames are consumed in groups of
// kNumFramePerTree, one group per tree appended to `trees`.
// Throws std::runtime_error ("Wrong number of frames") when the remaining
// frame count is not a multiple of kNumFramePerTree.
// NOTE(review): the leading frames are indexed before frames.size() is
// examined -- confirm that callers never pass fewer frames than the
// model-level header needs.
413 Model::InitFromPyBuffer(std::vector<PyBufferFrame> frames) {
416 InitScalarFromPyBuffer(&
num_feature, frames[frame_id++]);
419 InitScalarFromPyBuffer(&
param, frames[frame_id++]);
421 const size_t num_frame = frames.size();
422 if ((num_frame - frame_id) % kNumFramePerTree != 0) {
423 throw std::runtime_error(
"Wrong number of frames");
426 for (; frame_id < num_frame; frame_id += kNumFramePerTree) {
// Slice out this tree's frames and let a freshly appended Tree load itself.
427 std::vector<PyBufferFrame> tree_frames(frames.begin() + frame_id,
428 frames.begin() + frame_id + kNumFramePerTree);
429 trees.emplace_back();
430 trees.back().InitFromPyBuffer(tree_frames);
435 cleft_ = cright_ = -1;
437 info_.leaf_value = 0.0f;
438 info_.threshold = 0.0f;
440 sum_hess_ = gain_ = 0.0;
441 missing_category_to_zero_ =
false;
442 data_count_present_ = sum_hess_present_ = gain_present_ =
false;
443 split_type_ = SplitFeatureType::kNone;
444 cmp_ = Operator::kNone;
450 int nd = num_nodes++;
451 if (nodes_.Size() !=
static_cast<size_t>(nd)) {
452 throw std::runtime_error(
"Invariant violated: nodes_ contains incorrect number of nodes");
454 for (
int nid = nd; nid < num_nodes; ++nid) {
455 leaf_vector_offset_.PushBack(leaf_vector_offset_.Back());
456 left_categories_offset_.PushBack(left_categories_offset_.Back());
457 nodes_.Resize(nodes_.Size() + 1);
458 nodes_.Back().Init();
// Produce a copy of this tree in which every array member is an independent
// clone (ContiguousArray::Clone), so the copy owns its storage even when this
// tree's arrays wrap foreign buffers.
464 Tree::Clone()
const {
467 tree.nodes_ = nodes_.Clone();
468 tree.leaf_vector_ = leaf_vector_.Clone();
469 tree.leaf_vector_offset_ = leaf_vector_offset_.Clone();
470 tree.left_categories_ = left_categories_.Clone();
471 tree.left_categories_offset_ = left_categories_offset_.Clone();
478 leaf_vector_.Clear();
479 leaf_vector_offset_.Resize(2, 0);
480 left_categories_.Clear();
481 left_categories_offset_.Resize(2, 0);
489 const int cleft = this->AllocNode();
490 const int cright = this->AllocNode();
491 nodes_[nid].cleft_ = cleft;
492 nodes_[nid].cright_ = cright;
495 inline std::vector<unsigned>
497 std::unordered_map<unsigned, bool> tmp;
498 for (
int nid = 0; nid < num_nodes; ++nid) {
500 if (type != SplitFeatureType::kNone) {
501 const bool flag = (type == SplitFeatureType::kCategorical);
502 const uint32_t split_index = SplitIndex(nid);
503 if (tmp.count(split_index) == 0) {
504 tmp[split_index] = flag;
506 if (tmp[split_index] != flag) {
507 throw std::runtime_error(
"Feature " + std::to_string(split_index) +
508 " cannot be simultaneously be categorical and numerical.");
513 std::vector<unsigned> result;
514 for (
const auto& kv : tmp) {
516 result.push_back(kv.first);
519 std::sort(result.begin(), result.end());
525 return nodes_[nid].cleft_;
530 return nodes_[nid].cright_;
535 return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid);
540 return (nodes_[nid].sindex_ & ((1U << 31U) - 1U));
545 return (nodes_[nid].sindex_ >> 31U) != 0;
550 return nodes_[nid].cleft_ == -1;
555 return (nodes_[nid].info_).leaf_value;
558 inline std::vector<tl_float>
560 if (nid > leaf_vector_offset_.Size()) {
561 throw std::runtime_error(
"nid too large");
563 return std::vector<tl_float>(&leaf_vector_[leaf_vector_offset_[nid]],
564 &leaf_vector_[leaf_vector_offset_[nid + 1]]);
569 if (nid > leaf_vector_offset_.Size()) {
570 throw std::runtime_error(
"nid too large");
572 return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1];
577 return (nodes_[nid].info_).threshold;
582 return nodes_[nid].cmp_;
585 inline std::vector<uint32_t>
587 if (nid > left_categories_offset_.Size()) {
588 throw std::runtime_error(
"nid too large");
590 return std::vector<uint32_t>(&left_categories_[left_categories_offset_[nid]],
591 &left_categories_[left_categories_offset_[nid + 1]]);
596 return nodes_[nid].split_type_;
601 return nodes_[nid].data_count_present_;
606 return nodes_[nid].data_count_;
611 return nodes_[nid].sum_hess_present_;
616 return nodes_[nid].sum_hess_;
621 return nodes_[nid].gain_present_;
626 return nodes_[nid].gain_;
631 return nodes_[nid].missing_category_to_zero_;
637 Node& node = nodes_[nid];
638 if (split_index >= ((1U << 31U) - 1)) {
639 throw std::runtime_error(
"split_index too big");
641 if (default_left) split_index |= (1U << 31U);
643 (node.
info_).threshold = threshold;
650 bool missing_category_to_zero,
651 const std::vector<uint32_t>& node_left_categories) {
652 if (split_index >= ((1U << 31U) - 1)) {
653 throw std::runtime_error(
"split_index too big");
656 const size_t end_oft = left_categories_offset_.Back();
657 const size_t new_end_oft = end_oft + node_left_categories.size();
658 if (end_oft != left_categories_.Size()) {
659 throw std::runtime_error(
"Invariant violated");
661 if (!std::all_of(&left_categories_offset_[nid + 1], left_categories_offset_.End(),
662 [end_oft](
size_t x) {
return (x == end_oft); })) {
663 throw std::runtime_error(
"Invariant violated");
666 left_categories_.Extend(node_left_categories);
667 if (new_end_oft != left_categories_.Size()) {
668 throw std::runtime_error(
"Invariant violated");
670 std::for_each(&left_categories_offset_[nid + 1], left_categories_offset_.End(),
671 [new_end_oft](
size_t& x) { x = new_end_oft; });
672 std::sort(&left_categories_[end_oft], left_categories_.End());
674 Node& node = nodes_[nid];
675 if (default_left) split_index |= (1U << 31U);
677 node.split_type_ = SplitFeatureType::kCategorical;
678 node.missing_category_to_zero_ = missing_category_to_zero;
683 Node& node = nodes_[nid];
684 (node.
info_).leaf_value = value;
692 const size_t end_oft = leaf_vector_offset_.Back();
693 const size_t new_end_oft = end_oft + node_leaf_vector.size();
694 if (end_oft != leaf_vector_.Size()) {
695 throw std::runtime_error(
"Invariant violated");
697 if (!std::all_of(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
698 [end_oft](
size_t x) {
return (x == end_oft); })) {
699 throw std::runtime_error(
"Invariant violated");
702 leaf_vector_.Extend(node_leaf_vector);
703 if (new_end_oft != leaf_vector_.Size()) {
704 throw std::runtime_error(
"Invariant violated");
706 std::for_each(&leaf_vector_offset_[nid + 1], leaf_vector_offset_.End(),
707 [new_end_oft](
size_t& x) { x = new_end_oft; });
709 Node& node = nodes_[nid];
712 node.split_type_ = SplitFeatureType::kNone;
717 Node& node = nodes_[nid];
724 Node& node = nodes_[nid];
731 Node& node = nodes_[nid];
// Produce a copy of this model. Every member tree is duplicated with
// Tree::Clone() and appended to the copy's `trees`, preserving order.
737 Model::Clone()
const {
739 for (
const Tree& t : trees) {
740 model.
trees.push_back(t.Clone());
750 const std::vector<std::pair<std::string, std::string>>& cfg) {
751 auto unknown = param->InitAllowUnknown(cfg);
752 if (!unknown.empty()) {
753 std::ostringstream oss;
754 for (
const auto& kv : unknown) {
755 oss << kv.first <<
", ";
757 std::cerr <<
"\033[1;31mWarning: Unknown parameters found; " 758 <<
"they have been ignored\u001B[0m: " << oss.str() << std::endl;
763 #endif // TREELITE_TREE_IMPL_H_ SplitFeatureType split_type_
feature split type
Operator ComparisonOp(int nid) const
get comparison operator
bool gain_present_
whether gain_present_ field is present
int num_output_group
number of output groups – for multi-class classification. Set to 1 for everything else ...
SplitFeatureType
feature split type
void Init()
initialize the model with a single root node
uint64_t data_count_
number of data points whose traversal paths include this node. LightGBM models natively store this st...
thin wrapper for tree ensemble model
float tl_float
float type to be used internally
bool HasDataCount(int nid) const
test whether this node has data count
std::vector< tl_float > LeafVector(int nid) const
get leaf vector of the leaf node; useful for multi-class random forest classifier ...
bool HasGain(int nid) const
test whether this node has gain value
Operator cmp_
operator to use for expression of form [fval] OP [threshold]. If the expression evaluates to true...
std::vector< Tree > trees
member trees
tl_float Threshold(int nid) const
get threshold of the node
int DefaultChild(int nid) const
index of the node's "default" child, used when feature is missing
ModelParam param
extra parameters
bool data_count_present_
whether data_count_ field is present
int32_t cleft_
pointer to left and right children
std::vector< uint32_t > LeftCategories(int nid) const
Get list of all categories belonging to the left child node. Categories not in this list will belong ...
void SetSumHess(int nid, double sum_hess)
set the hessian sum of the node
in-memory representation of a decision tree
double sum_hess_
sum of hessian values for all data points whose traversal paths include this node. This value is generally correlated positively with the data count. XGBoost models natively store this statistics.
double gain_
change in loss that is attributed to a particular split
uint32_t SplitIndex(int nid) const
feature index of the node's split condition
std::vector< unsigned > GetCategoricalFeatures() const
get list of all categorical features that have appeared anywhere in tree
bool MissingCategoryToZero(int nid) const
test whether missing values should be converted into zero; only applicable for categorical splits ...
double SumHess(int nid) const
get hessian sum
void SetLeafVector(int nid, const std::vector< tl_float > &leaf_vector)
set the leaf vector of the node; useful for multi-class random forest classifier
void SetCategoricalSplit(int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, const std::vector< uint32_t > &left_categories)
create a categorical split
void SetGain(int nid, double gain)
set the gain value of the node
bool random_forest_flag
flag for random forest; True for random forests and False for gradient boosted trees ...
void SetDataCount(int nid, uint64_t data_count)
set the data count of the node
void Init()
Initialization method. Use this in lieu of constructor (POD types cannot have a non-trivial construct...
tl_float LeafValue(int nid) const
get leaf value of the leaf node
int num_nodes
number of nodes
SplitFeatureType SplitType(int nid) const
get feature split type
void SetLeaf(int nid, tl_float value)
set the leaf value of the node
void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp)
Setters.
int LeftChild(int nid) const
Getters.
uint64_t DataCount(int nid) const
get data count
double Gain(int nid) const
get gain value
int RightChild(int nid) const
index of the node's right child
void AddChilds(int nid)
add child nodes to node
bool sum_hess_present_
whether sum_hess_ field is present
bool DefaultLeft(int nid) const
whether to use the left child node, when the feature in the split condition is missing ...
bool HasSumHess(int nid) const
test whether this node has hessian sum
bool IsLeaf(int nid) const
whether the node is leaf node
uint32_t sindex_
feature index used for the split; the highest bit indicates the default direction for missing values ...
bool HasLeafVector(int nid) const
tests whether the leaf node has a non-empty leaf vector
int num_feature
number of features used for the model. It is assumed that all feature indices are between 0 and [num_...
Info info_
storage for leaf value or decision threshold
Operator
comparison operators