7 #ifndef TREELITE_TREE_IMPL_H_ 8 #define TREELITE_TREE_IMPL_H_ 11 #include <treelite/version.h> 20 #include <unordered_map> 31 inline std::string GetString(T x) {
32 return std::to_string(x);
36 inline std::string GetString<float>(
float x) {
37 std::ostringstream oss;
38 oss << std::setprecision(std::numeric_limits<float>::max_digits10) << x;
43 inline std::string GetString<double>(
double x) {
44 std::ostringstream oss;
45 oss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
54 ContiguousArray<T>::ContiguousArray()
55 : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
58 ContiguousArray<T>::~ContiguousArray() {
59 if (buffer_ && owned_buffer_) {
65 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
66 : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
67 owned_buffer_(other.owned_buffer_) {
68 other.buffer_ =
nullptr;
69 other.size_ = other.capacity_ = 0;
74 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
75 if (buffer_ && owned_buffer_) {
78 buffer_ = other.buffer_;
80 capacity_ = other.capacity_;
81 owned_buffer_ = other.owned_buffer_;
82 other.buffer_ =
nullptr;
83 other.size_ = other.capacity_ = 0;
88 inline ContiguousArray<T>
89 ContiguousArray<T>::Clone()
const {
90 ContiguousArray clone;
91 clone.buffer_ =
static_cast<T*
>(std::malloc(
sizeof(T) * capacity_));
93 throw Error(
"Could not allocate memory for the clone");
95 std::memcpy(clone.buffer_, buffer_,
sizeof(T) * size_);
97 clone.capacity_ = capacity_;
98 clone.owned_buffer_ =
true;
102 template <
typename T>
104 ContiguousArray<T>::UseForeignBuffer(
void* prealloc_buf, std::size_t size) {
105 if (buffer_ && owned_buffer_) {
108 buffer_ =
static_cast<T*
>(prealloc_buf);
111 owned_buffer_ =
false;
114 template <
typename T>
116 ContiguousArray<T>::Data() {
120 template <
typename T>
122 ContiguousArray<T>::Data()
const {
126 template <
typename T>
128 ContiguousArray<T>::End() {
129 return &buffer_[Size()];
132 template <
typename T>
134 ContiguousArray<T>::End()
const {
135 return &buffer_[Size()];
138 template <
typename T>
140 ContiguousArray<T>::Back() {
141 return buffer_[Size() - 1];
144 template <
typename T>
146 ContiguousArray<T>::Back()
const {
147 return buffer_[Size() - 1];
150 template <
typename T>
152 ContiguousArray<T>::Size()
const {
156 template <
typename T>
158 ContiguousArray<T>::Empty()
const {
159 return (Size() == 0);
162 template <
typename T>
164 ContiguousArray<T>::Reserve(std::size_t newsize) {
165 if (!owned_buffer_) {
166 throw Error(
"Cannot resize when using a foreign buffer; clone first");
168 T* newbuf =
static_cast<T*
>(std::realloc(static_cast<void*>(buffer_),
sizeof(T) * newsize));
170 throw Error(
"Could not expand buffer");
176 template <
typename T>
178 ContiguousArray<T>::Resize(std::size_t newsize) {
179 if (!owned_buffer_) {
180 throw Error(
"Cannot resize when using a foreign buffer; clone first");
182 if (newsize > capacity_) {
183 std::size_t newcapacity = capacity_;
184 if (newcapacity == 0) {
187 while (newcapacity <= newsize) {
190 Reserve(newcapacity);
195 template <
typename T>
197 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
198 if (!owned_buffer_) {
199 throw Error(
"Cannot resize when using a foreign buffer; clone first");
201 std::size_t oldsize = Size();
203 for (std::size_t i = oldsize; i < newsize; ++i) {
208 template <
typename T>
210 ContiguousArray<T>::Clear() {
211 if (!owned_buffer_) {
212 throw Error(
"Cannot clear when using a foreign buffer; clone first");
217 template <
typename T>
219 ContiguousArray<T>::PushBack(T t) {
220 if (!owned_buffer_) {
221 throw Error(
"Cannot add element when using a foreign buffer; clone first");
223 if (size_ == capacity_) {
224 Reserve(capacity_ * 2);
226 buffer_[size_++] = t;
229 template <
typename T>
231 ContiguousArray<T>::Extend(
const std::vector<T>& other) {
232 if (!owned_buffer_) {
233 throw Error(
"Cannot add elements when using a foreign buffer; clone first");
238 std::size_t newsize = size_ + other.size();
239 if (newsize > capacity_) {
240 std::size_t newcapacity = capacity_;
241 if (newcapacity == 0) {
244 while (newcapacity <= newsize) {
247 Reserve(newcapacity);
249 std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()),
sizeof(T) * other.size());
253 template <
typename T>
255 ContiguousArray<T>::operator[](std::size_t idx) {
259 template <
typename T>
261 ContiguousArray<T>::operator[](std::size_t idx)
const {
265 template <
typename T>
267 ContiguousArray<T>::at(std::size_t idx) {
269 throw Error(
"nid out of range");
274 template <
typename T>
276 ContiguousArray<T>::at(std::size_t idx)
const {
278 throw Error(
"nid out of range");
283 template <
typename T>
285 ContiguousArray<T>::at(
int idx) {
286 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
287 throw Error(
"nid out of range");
289 return buffer_[
static_cast<std::size_t
>(idx)];
292 template <
typename T>
294 ContiguousArray<T>::at(
int idx)
const {
295 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
296 throw Error(
"nid out of range");
298 return buffer_[
static_cast<std::size_t
>(idx)];
301 template<
typename Container>
302 inline std::vector<std::pair<std::string, std::string> >
303 ModelParam::InitAllowUnknown(
const Container& kwargs) {
304 std::vector<std::pair<std::string, std::string>> unknowns;
305 for (
const auto& e : kwargs) {
306 if (e.first ==
"pred_transform") {
307 std::strncpy(this->pred_transform, e.second.c_str(),
308 TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
309 this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] =
'\0';
310 }
else if (e.first ==
"sigmoid_alpha") {
311 this->sigmoid_alpha = std::stof(e.second,
nullptr);
312 }
else if (e.first ==
"ratio_c") {
313 this->ratio_c = std::stof(e.second,
nullptr);
314 }
else if (e.first ==
"global_bias") {
315 this->global_bias = std::stof(e.second,
nullptr);
321 inline std::map<std::string, std::string>
322 ModelParam::__DICT__()
const {
323 std::map<std::string, std::string> ret;
324 ret.emplace(
"pred_transform", std::string(this->pred_transform));
325 ret.emplace(
"sigmoid_alpha", GetString(this->sigmoid_alpha));
326 ret.emplace(
"ratio_c", GetString(this->ratio_c));
327 ret.emplace(
"global_bias", GetString(this->global_bias));
331 inline PyBufferFrame GetPyBufferFromArray(
void* data,
const char* format,
332 std::size_t itemsize, std::size_t nitem) {
333 return PyBufferFrame{data,
const_cast<char*
>(format), itemsize, nitem};
337 template <
typename T>
338 inline const char* InferFormatString() {
341 return (std::is_unsigned<T>::value ?
"=B" :
"=b");
343 return (std::is_unsigned<T>::value ?
"=H" :
"=h");
345 if (std::is_integral<T>::value) {
346 return (std::is_unsigned<T>::value ?
"=L" :
"=l");
348 if (!std::is_floating_point<T>::value) {
349 throw Error(
"Could not infer format string");
354 if (std::is_integral<T>::value) {
355 return (std::is_unsigned<T>::value ?
"=Q" :
"=q");
357 if (!std::is_floating_point<T>::value) {
358 throw Error(
"Could not infer format string");
363 throw Error(
"Unrecognized type");
368 template <
typename T>
369 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec,
const char* format) {
370 return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format,
sizeof(T), vec->Size());
373 template <
typename T>
374 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
375 static_assert(std::is_arithmetic<T>::value,
376 "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
377 return GetPyBufferFromArray(vec, InferFormatString<T>());
380 inline PyBufferFrame GetPyBufferFromScalar(
void* data,
const char* format, std::size_t itemsize) {
381 return GetPyBufferFromArray(data, format, itemsize, 1);
384 template <
typename T>
385 inline PyBufferFrame GetPyBufferFromScalar(T* scalar,
const char* format) {
386 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
387 return GetPyBufferFromScalar(static_cast<void*>(scalar), format,
sizeof(T));
390 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
391 using T = std::underlying_type<TypeInfo>::type;
392 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
395 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
396 using T = std::underlying_type<TaskType>::type;
397 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
400 template <
typename T>
401 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
402 static_assert(std::is_arithmetic<T>::value,
403 "Use GetPyBufferFromScalar(scalar, format) for composite types; " 404 "specify format string manually");
405 return GetPyBufferFromScalar(scalar, InferFormatString<T>());
408 template <
typename T>
409 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
410 if (
sizeof(T) != frame.itemsize) {
411 throw Error(
"Incorrect itemsize");
413 vec->UseForeignBuffer(frame.buf, frame.nitem);
416 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
417 using T = std::underlying_type<TypeInfo>::type;
418 if (
sizeof(T) != buffer.itemsize) {
419 throw Error(
"Incorrect itemsize");
421 if (buffer.nitem != 1) {
422 throw Error(
"nitem must be 1 for a scalar");
424 T* t =
static_cast<T*
>(buffer.buf);
425 *scalar =
static_cast<TypeInfo>(*t);
428 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
429 using T = std::underlying_type<TaskType>::type;
430 if (
sizeof(T) != buffer.itemsize) {
431 throw Error(
"Incorrect itemsize");
433 if (buffer.nitem != 1) {
434 throw Error(
"nitem must be 1 for a scalar");
436 T* t =
static_cast<T*
>(buffer.buf);
437 *scalar =
static_cast<TaskType>(*t);
440 template <
typename T>
441 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
442 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
443 if (
sizeof(T) != buffer.itemsize) {
444 throw Error(
"Incorrect itemsize");
446 if (buffer.nitem != 1) {
447 throw Error(
"nitem must be 1 for a scalar");
449 T* t =
static_cast<T*
>(buffer.buf);
453 template <
typename T>
454 inline void ReadScalarFromFile(T* scalar, FILE* fp) {
455 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
456 if (std::fread(scalar,
sizeof(T), 1, fp) < 1) {
457 throw Error(
"Could not read a scalar");
461 template <
typename T>
462 inline void WriteScalarToFile(T* scalar, FILE* fp) {
463 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
464 if (std::fwrite(scalar,
sizeof(T), 1, fp) < 1) {
465 throw Error(
"Could not write a scalar");
469 template <
typename T>
470 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
472 if (std::fread(&nelem,
sizeof(nelem), 1, fp) < 1) {
473 throw Error(
"Could not read the number of elements");
480 const auto nelem_size_t =
static_cast<std::size_t
>(nelem);
481 if (std::fread(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
482 throw Error(
"Could not read an array");
486 template <
typename T>
487 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
488 static_assert(
sizeof(uint64_t) >=
sizeof(
size_t),
"size_t too large");
489 const auto nelem =
static_cast<uint64_t
>(vec->Size());
490 if (std::fwrite(&nelem,
sizeof(nelem), 1, fp) < 1) {
491 throw Error(
"Could not write the number of elements");
496 const auto nelem_size_t = vec->Size();
497 if (std::fwrite(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
498 throw Error(
"Could not write an array");
502 inline void SkipOptFieldInFile(FILE* fp) {
505 ReadScalarFromFile(&elem_size, fp);
506 ReadScalarFromFile(&nelem, fp);
508 const uint64_t nbytes = elem_size * nelem;
509 TREELITE_CHECK_LE(nbytes, std::numeric_limits<long>::max());
510 if (std::fseek(fp, static_cast<long>(nbytes), SEEK_CUR) != 0) {
511 throw Error(
"Reached end of file");
515 template <
typename ThresholdType,
typename LeafOutputType>
516 Tree<ThresholdType, LeafOutputType>::Tree(
bool use_opt_field)
517 : use_opt_field_(use_opt_field)
520 template <
typename ThresholdType,
typename LeafOutputType>
521 inline Tree<ThresholdType, LeafOutputType>
522 Tree<ThresholdType, LeafOutputType>::Clone()
const {
523 Tree<ThresholdType, LeafOutputType> tree;
524 tree.num_nodes = num_nodes;
525 tree.nodes_ = nodes_.Clone();
526 tree.leaf_vector_ = leaf_vector_.Clone();
527 tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
528 tree.leaf_vector_end_ = leaf_vector_end_.Clone();
529 tree.matching_categories_ = matching_categories_.Clone();
530 tree.matching_categories_offset_ = matching_categories_offset_.Clone();
534 template <
typename ThresholdType,
typename LeafOutputType>
536 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
537 if (std::is_same<ThresholdType, float>::value) {
538 return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
540 return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
544 template <
typename ThresholdType,
typename LeafOutputType>
545 template <
typename ScalarHandler,
typename PrimitiveArrayHandler,
typename CompositeArrayHandler>
547 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
548 ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
549 CompositeArrayHandler composite_array_handler) {
550 scalar_handler(&num_nodes);
551 scalar_handler(&has_categorical_split_);
552 composite_array_handler(&nodes_, GetFormatStringForNode());
553 primitive_array_handler(&leaf_vector_);
554 primitive_array_handler(&leaf_vector_begin_);
555 primitive_array_handler(&leaf_vector_end_);
556 primitive_array_handler(&matching_categories_);
557 primitive_array_handler(&matching_categories_offset_);
560 num_opt_field_per_tree_ = 0;
561 scalar_handler(&num_opt_field_per_tree_);
564 num_opt_field_per_node_ = 0;
565 scalar_handler(&num_opt_field_per_node_);
568 template <
typename ThresholdType,
typename LeafOutputType>
569 template <
typename ScalarHandler,
typename ArrayHandler,
typename SkipOptFieldHandlerFunc>
571 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
572 ScalarHandler scalar_handler,
573 ArrayHandler array_handler,
574 SkipOptFieldHandlerFunc skip_opt_field_handler) {
575 scalar_handler(&num_nodes);
576 scalar_handler(&has_categorical_split_);
577 array_handler(&nodes_);
578 if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
579 throw Error(
"Could not load the correct number of nodes");
581 array_handler(&leaf_vector_);
582 array_handler(&leaf_vector_begin_);
583 array_handler(&leaf_vector_end_);
584 array_handler(&matching_categories_);
585 array_handler(&matching_categories_offset_);
587 if (use_opt_field_) {
589 scalar_handler(&num_opt_field_per_tree_);
591 for (int32_t i = 0; i < num_opt_field_per_tree_; ++i) {
592 skip_opt_field_handler();
596 scalar_handler(&num_opt_field_per_node_);
598 for (int32_t i = 0; i < num_opt_field_per_node_; ++i) {
599 skip_opt_field_handler();
603 num_opt_field_per_tree_ = 0;
604 num_opt_field_per_node_ = 0;
608 template <
typename ThresholdType,
typename LeafOutputType>
610 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
611 auto scalar_handler = [dest](
auto* field) {
612 dest->push_back(GetPyBufferFromScalar(field));
614 auto primitive_array_handler = [dest](
auto* field) {
615 dest->push_back(GetPyBufferFromArray(field));
617 auto composite_array_handler = [dest](
auto* field,
const char* format) {
618 dest->push_back(GetPyBufferFromArray(field, format));
620 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
623 template <
typename ThresholdType,
typename LeafOutputType>
625 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
626 auto scalar_handler = [dest_fp](
auto* field) {
627 WriteScalarToFile(field, dest_fp);
629 auto primitive_array_handler = [dest_fp](
auto* field) {
630 WriteArrayToFile(field, dest_fp);
632 auto composite_array_handler = [dest_fp](
auto* field,
const char* format) {
633 WriteArrayToFile(field, dest_fp);
635 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
638 template <
typename ThresholdType,
typename LeafOutputType>
639 inline std::vector<PyBufferFrame>::iterator
640 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it) {
641 std::vector<PyBufferFrame>::iterator new_it = it;
642 auto scalar_handler = [&new_it](
auto* field) {
643 InitScalarFromPyBuffer(field, *(new_it++));
645 auto array_handler = [&new_it](
auto* field) {
646 InitArrayFromPyBuffer(field, *(new_it++));
648 auto skip_opt_field_handler = [&new_it]() {
651 DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
655 template <
typename ThresholdType,
typename LeafOutputType>
657 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
658 auto scalar_handler = [src_fp](
auto* field) {
659 ReadScalarFromFile(field, src_fp);
661 auto array_handler = [src_fp](
auto* field) {
662 ReadArrayFromFile(field, src_fp);
664 auto skip_opt_field_handler = [src_fp]() {
665 SkipOptFieldInFile(src_fp);
667 DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
670 template <
typename ThresholdType,
typename LeafOutputType>
672 std::memset(
this, 0,
sizeof(
Node));
673 cleft_ = cright_ = -1;
675 info_.leaf_value =
static_cast<LeafOutputType
>(0);
676 info_.threshold =
static_cast<ThresholdType
>(0);
678 sum_hess_ = gain_ = 0.0;
679 data_count_present_ = sum_hess_present_ = gain_present_ =
false;
680 categories_list_right_child_ =
false;
681 split_type_ = SplitFeatureType::kNone;
682 cmp_ = Operator::kNone;
685 template <
typename ThresholdType,
typename LeafOutputType>
688 int nd = num_nodes++;
689 if (nodes_.Size() !=
static_cast<std::size_t
>(nd)) {
690 throw Error(
"Invariant violated: nodes_ contains incorrect number of nodes");
692 for (
int nid = nd; nid < num_nodes; ++nid) {
693 leaf_vector_begin_.PushBack(0);
694 leaf_vector_end_.PushBack(0);
695 matching_categories_offset_.PushBack(matching_categories_offset_.Back());
696 nodes_.Resize(nodes_.Size() + 1);
697 nodes_.Back().Init();
702 template <
typename ThresholdType,
typename LeafOutputType>
706 has_categorical_split_ =
false;
707 leaf_vector_.Clear();
708 leaf_vector_begin_.Resize(1, {});
709 leaf_vector_end_.Resize(1, {});
710 matching_categories_.Clear();
711 matching_categories_offset_.Resize(2, 0);
714 SetLeaf(0, static_cast<LeafOutputType>(0));
717 template <
typename ThresholdType,
typename LeafOutputType>
720 const int cleft = this->AllocNode();
721 const int cright = this->AllocNode();
722 nodes_.at(nid).cleft_ = cleft;
723 nodes_.at(nid).cright_ = cright;
726 template <
typename ThresholdType,
typename LeafOutputType>
729 int nid,
unsigned split_index, ThresholdType threshold,
bool default_left,
Operator cmp) {
730 Node& node = nodes_.at(nid);
731 if (split_index >= ((1U << 31U) - 1)) {
732 throw Error(
"split_index too big");
734 if (default_left) split_index |= (1U << 31U);
735 node.sindex_ = split_index;
736 (node.info_).threshold = threshold;
738 node.split_type_ = SplitFeatureType::kNumerical;
739 node.categories_list_right_child_ =
false;
742 template <
typename ThresholdType,
typename LeafOutputType>
745 int nid,
unsigned split_index,
bool default_left,
746 const std::vector<uint32_t>& categories_list,
bool categories_list_right_child) {
747 if (split_index >= ((1U << 31U) - 1)) {
748 throw Error(
"split_index too big");
751 const std::size_t end_oft = matching_categories_offset_.Back();
752 const std::size_t new_end_oft = end_oft + categories_list.size();
753 if (end_oft != matching_categories_.Size()) {
754 throw Error(
"Invariant violated");
756 if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
757 [end_oft](std::size_t x) {
return (x == end_oft); })) {
758 throw Error(
"Invariant violated");
761 matching_categories_.Extend(categories_list);
762 if (new_end_oft != matching_categories_.Size()) {
763 throw Error(
"Invariant violated");
765 std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
766 [new_end_oft](std::size_t& x) { x = new_end_oft; });
767 if (!matching_categories_.Empty()) {
768 std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
771 Node& node = nodes_.at(nid);
772 if (default_left) split_index |= (1U << 31U);
773 node.sindex_ = split_index;
774 node.split_type_ = SplitFeatureType::kCategorical;
775 node.categories_list_right_child_ = categories_list_right_child;
777 has_categorical_split_ =
true;
780 template <
typename ThresholdType,
typename LeafOutputType>
783 Node& node = nodes_.at(nid);
784 (node.info_).leaf_value = value;
787 node.split_type_ = SplitFeatureType::kNone;
790 template <
typename ThresholdType,
typename LeafOutputType>
793 int nid,
const std::vector<LeafOutputType>& node_leaf_vector) {
794 std::size_t begin = leaf_vector_.Size();
795 std::size_t end = begin + node_leaf_vector.size();
796 leaf_vector_.Extend(node_leaf_vector);
797 leaf_vector_begin_[nid] = begin;
798 leaf_vector_end_[nid] = end;
799 Node &node = nodes_.at(nid);
802 node.split_type_ = SplitFeatureType::kNone;
805 template <
typename ThresholdType,
typename LeafOutputType>
806 inline std::unique_ptr<Model>
808 std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
809 model->threshold_type_ = TypeToInfo<ThresholdType>();
810 model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
814 template <
typename ThresholdType,
typename LeafOutputType>
817 inline static std::unique_ptr<Model> Dispatch() {
818 return Model::Create<ThresholdType, LeafOutputType>();
822 inline std::unique_ptr<Model>
824 return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
827 template <
typename ThresholdType,
typename LeafOutputType>
830 template <
typename Func>
831 inline static auto Dispatch(
Model* model, Func func) {
835 template <
typename Func>
836 inline static auto Dispatch(
const Model* model, Func func) {
841 template <
typename Func>
843 Model::Dispatch(Func func) {
844 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
847 template <
typename Func>
849 Model::Dispatch(Func func)
const {
850 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
853 template <
typename HeaderPrimitiveFieldHandlerFunc>
855 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
856 major_ver_ = TREELITE_VER_MAJOR;
857 minor_ver_ = TREELITE_VER_MINOR;
858 patch_ver_ = TREELITE_VER_PATCH;
859 header_primitive_field_handler(&major_ver_);
860 header_primitive_field_handler(&minor_ver_);
861 header_primitive_field_handler(&patch_ver_);
862 header_primitive_field_handler(&threshold_type_);
863 header_primitive_field_handler(&leaf_output_type_);
866 template <
typename HeaderPrimitiveFieldHandlerFunc>
868 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
869 int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
871 header_primitive_field_handler(&major_ver);
872 header_primitive_field_handler(&minor_ver);
873 header_primitive_field_handler(&patch_ver);
874 if (major_ver != TREELITE_VER_MAJOR && !(major_ver == 2 && minor_ver == 4)) {
875 std::ostringstream oss;
876 oss <<
"Cannot deserialize model from a different major Treelite version or " 877 <<
"a version before 2.4.0." << std::endl
878 <<
"Currently running Treelite version " << TREELITE_VER_MAJOR <<
"." 879 << TREELITE_VER_MINOR <<
"." << TREELITE_VER_PATCH << std::endl
880 <<
"The model checkpoint was generated from Treelite version " << major_ver <<
"." 881 << minor_ver <<
"." << patch_ver;
882 throw Error(oss.str());
884 header_primitive_field_handler(&threshold_type);
885 header_primitive_field_handler(&leaf_output_type);
888 template <
typename ThresholdType,
typename LeafOutputType>
889 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
890 typename TreeHandlerFunc>
893 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
894 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
895 TreeHandlerFunc tree_handler) {
897 header_primitive_field_handler(&num_feature);
898 header_primitive_field_handler(&task_type);
899 header_primitive_field_handler(&average_tree_output);
900 header_composite_field_handler(&task_param,
"T{=B=?xx=I=I}");
901 header_composite_field_handler(
902 ¶m,
"T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH)
"s=f=f=f}");
905 num_opt_field_per_model_ = 0;
906 header_primitive_field_handler(&num_opt_field_per_model_);
914 template <
typename ThresholdType,
typename LeafOutputType>
915 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc,
916 typename SkipOptFieldHandlerFunc>
919 std::size_t num_tree,
920 HeaderFieldHandlerFunc header_field_handler,
921 TreeHandlerFunc tree_handler,
922 SkipOptFieldHandlerFunc skip_opt_field_handler) {
924 header_field_handler(&num_feature);
925 header_field_handler(&task_type);
926 header_field_handler(&average_tree_output);
927 header_field_handler(&task_param);
928 header_field_handler(¶m);
931 const bool use_opt_field = (major_ver_ >= 3);
933 header_field_handler(&num_opt_field_per_model_);
935 for (int32_t i = 0; i < num_opt_field_per_model_; ++i) {
936 skip_opt_field_handler();
940 num_opt_field_per_model_ = 0;
945 for (std::size_t i = 0; i < num_tree; ++i) {
946 trees.emplace_back(use_opt_field);
947 tree_handler(trees.back());
951 template <
typename ThresholdType,
typename LeafOutputType>
954 num_tree_ =
static_cast<uint64_t
>(this->trees.size());
955 auto header_primitive_field_handler = [dest](
auto* field) {
956 dest->push_back(GetPyBufferFromScalar(field));
958 auto header_composite_field_handler = [dest](
auto* field,
const char* format) {
959 dest->push_back(GetPyBufferFromScalar(field, format));
962 tree.GetPyBuffer(dest);
964 header_primitive_field_handler(&num_tree_);
965 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
968 template <
typename ThresholdType,
typename LeafOutputType>
971 num_tree_ =
static_cast<uint64_t
>(this->trees.size());
972 auto header_primitive_field_handler = [dest_fp](
auto* field) {
973 WriteScalarToFile(field, dest_fp);
975 auto header_composite_field_handler = [dest_fp](
auto* field,
const char* format) {
976 WriteScalarToFile(field, dest_fp);
979 tree.SerializeToFile(dest_fp);
981 header_primitive_field_handler(&num_tree_);
982 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
985 template <
typename ThresholdType,
typename LeafOutputType>
986 inline std::vector<PyBufferFrame>::iterator
988 std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) {
989 std::vector<PyBufferFrame>::iterator new_it = it;
990 auto header_field_handler = [&new_it](
auto* field) {
991 InitScalarFromPyBuffer(field, *(new_it++));
994 auto skip_opt_field_handler = [&new_it]() {
999 new_it = tree.InitFromPyBuffer(new_it);
1002 if (major_ver_ == 2) {
1005 constexpr std::size_t kNumFrameInHeader = 5;
1006 constexpr std::size_t kNumFramePerTree = 8;
1007 num_tree_ = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
1011 header_field_handler(&num_tree_);
1014 DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1015 TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1020 template <
typename ThresholdType,
typename LeafOutputType>
1023 ReadScalarFromFile(&num_tree_, src_fp);
1025 auto header_field_handler = [src_fp](
auto* field) {
1026 ReadScalarFromFile(field, src_fp);
1029 auto skip_opt_field_handler = [src_fp]() {
1030 SkipOptFieldInFile(src_fp);
1034 tree.DeserializeFromFile(src_fp);
1037 DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1038 TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1041 inline void InitParamAndCheck(
ModelParam* param,
1042 const std::vector<std::pair<std::string, std::string>>& cfg) {
1043 auto unknown = param->InitAllowUnknown(cfg);
1044 if (!unknown.empty()) {
1045 std::ostringstream oss;
1046 for (
const auto& kv : unknown) {
1047 oss << kv.first <<
", ";
1049 std::cerr <<
"\033[1;31mWarning: Unknown parameters found; " 1050 <<
"they have been ignored\u001B[0m: " << oss.str() << std::endl;
1055 #endif // TREELITE_TREE_IMPL_H_ void Init()
initialize the model with a single root node
Exception class that will be thrown by Treelite.
TaskType
Enum type representing the task type.
in-memory representation of a decision tree
logging facility for Treelite
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
thin wrapper for tree ensemble model
Operator
comparison operators