// Include guard and headers, followed by GetString(): converts a numeric value
// to std::string.
//   - generic T: std::to_string(x)
//   - float / double specializations: stream x through std::ostringstream with
//     std::numeric_limits<...>::max_digits10 precision, i.e. enough significant
//     digits for the decimal text to identify the exact binary value.
// NOTE(review): this looks like a rendered source listing. The leading numbers
// are original line numbers, and some lines (the `template` headers, and
// presumably the statements returning oss.str()) are not shown here.
7 #ifndef TREELITE_TREE_IMPL_H_ 8 #define TREELITE_TREE_IMPL_H_ 11 #include <treelite/version.h> 20 #include <unordered_map> 31 inline std::string GetString(T x) {
32 return std::to_string(x);
36 inline std::string GetString<float>(
float x) {
37 std::ostringstream oss;
38 oss << std::setprecision(std::numeric_limits<float>::max_digits10) << x;
43 inline std::string GetString<double>(
double x) {
44 std::ostringstream oss;
45 oss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
// Default constructor: empty array (null buffer, zero size and capacity).
// owned_buffer_ starts true so the growth/mutation methods, which all reject
// foreign buffers, are allowed on a fresh array.
54 ContiguousArray<T>::ContiguousArray()
55 : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
// Destructor: releases the buffer only if one exists AND it is owned. Foreign
// buffers installed by UseForeignBuffer are never released here.
// NOTE(review): the release statement is not visible in this view (presumably
// std::free, matching the malloc/realloc used by Clone/Reserve).
58 ContiguousArray<T>::~ContiguousArray() {
59 if (buffer_ && owned_buffer_) {
// Move constructor: takes over other's buffer pointer, size, capacity and
// ownership flag, then resets other to empty. Nulling other.buffer_ keeps
// other's destructor from releasing the buffer that was just taken over.
65 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
66 : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
67 owned_buffer_(other.owned_buffer_) {
68 other.buffer_ =
nullptr;
69 other.size_ = other.capacity_ = 0;
// Move assignment: releases this array's owned buffer (release statement not
// visible here), then takes over other's buffer, capacity and ownership flag.
// Afterwards other is left empty.
// NOTE(review): there is no `this != &other` guard. A self-move releases the
// buffer, and the final `other.buffer_ = nullptr` then empties *this. Confirm
// that no caller self-move-assigns.
74 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
75 if (buffer_ && owned_buffer_) {
78 buffer_ = other.buffer_;
80 capacity_ = other.capacity_;
81 owned_buffer_ = other.owned_buffer_;
82 other.buffer_ =
nullptr;
83 other.size_ = other.capacity_ = 0;
// Clone(): deep copy.
// - Allocates a buffer of the same capacity with std::malloc (throws Error if
//   the allocation fails).
// - Copies the first size_ elements with std::memcpy, so T is assumed
//   trivially copyable.
// - Marks the clone as owning its buffer, so cloning a foreign-buffer view
//   yields an owned, growable array.
// In the other branch (condition not visible; presumably "no buffer"), size_
// and capacity_ must both be zero and the clone gets a null buffer.
88 inline ContiguousArray<T>
89 ContiguousArray<T>::Clone()
const {
90 ContiguousArray clone;
92 clone.buffer_ =
static_cast<T*
>(std::malloc(
sizeof(T) * capacity_));
94 throw Error(
"Could not allocate memory for the clone");
96 std::memcpy(clone.buffer_, buffer_,
sizeof(T) * size_);
98 TREELITE_CHECK_EQ(size_, 0);
99 TREELITE_CHECK_EQ(capacity_, 0);
100 clone.buffer_ =
nullptr;
103 clone.capacity_ = capacity_;
104 clone.owned_buffer_ =
true;
// Turns the array into a non-owning view over prealloc_buf. Any buffer it
// currently owns is released first (release statement not visible here).
// Because owned_buffer_ becomes false, Reserve/Resize/Clear/PushBack/Extend
// will throw afterwards and the destructor will not release prealloc_buf.
// NOTE(review): the size bookkeeping is not visible (presumably
// size_ = capacity_ = size).
108 template <
typename T>
110 ContiguousArray<T>::UseForeignBuffer(
void* prealloc_buf, std::size_t size) {
111 if (buffer_ && owned_buffer_) {
114 buffer_ =
static_cast<T*
>(prealloc_buf);
117 owned_buffer_ =
false;
// Data(): mutable access to the underlying storage. The return type and body
// are not visible in this view (presumably returns buffer_).
120 template <
typename T>
122 ContiguousArray<T>::Data() {
// Data() const: read-only access to the underlying storage. The return type
// and body are not visible in this view (presumably returns buffer_).
126 template <
typename T>
128 ContiguousArray<T>::Data()
const {
// End(): pointer one past the last element (&buffer_[Size()]), usable as the
// end of an iterator range in <algorithm> calls.
132 template <
typename T>
134 ContiguousArray<T>::End() {
135 return &buffer_[Size()];
// End() const: read-only pointer one past the last element.
138 template <
typename T>
140 ContiguousArray<T>::End()
const {
141 return &buffer_[Size()];
// Back(): the last element. There is no emptiness check: on an empty array,
// Size() - 1 wraps around (std::size_t) and the access is out of bounds.
144 template <
typename T>
146 ContiguousArray<T>::Back() {
147 return buffer_[Size() - 1];
// Back() const: the last element, read-only. Same precondition as the
// non-const overload: the array must not be empty.
150 template <
typename T>
152 ContiguousArray<T>::Back()
const {
153 return buffer_[Size() - 1];
// Size(): number of elements in use. The body is not visible in this view
// (presumably returns size_).
156 template <
typename T>
158 ContiguousArray<T>::Size()
const {
// Empty(): true when the array holds no elements.
162 template <
typename T>
164 ContiguousArray<T>::Empty()
const {
165 return (Size() == 0);
// Reserve(newsize): reallocates the owned buffer with std::realloc to hold
// exactly newsize elements.
// Throws Error when the buffer is foreign, and (presumably on a null result;
// the condition is not visible) when reallocation fails.
// NOTE(review): no guard against newsize < capacity_ is visible, so this may
// shrink. std::realloc with a zero size may return nullptr, which would
// surface as "Could not expand buffer".
168 template <
typename T>
170 ContiguousArray<T>::Reserve(std::size_t newsize) {
171 if (!owned_buffer_) {
172 throw Error(
"Cannot resize when using a foreign buffer; clone first");
174 T* newbuf =
static_cast<T*
>(std::realloc(static_cast<void*>(buffer_),
sizeof(T) * newsize));
176 throw Error(
"Could not expand buffer");
// Resize(newsize): changes the logical size. If newsize exceeds capacity, the
// capacity is grown until it is strictly greater than newsize, then Reserve is
// called. The starting value for an empty array and the growth step are not
// visible here (presumably 1 and doubling). Throws for foreign buffers.
// NOTE(review): the size_ update is not visible. New slots do not appear to be
// initialized by this overload.
182 template <
typename T>
184 ContiguousArray<T>::Resize(std::size_t newsize) {
185 if (!owned_buffer_) {
186 throw Error(
"Cannot resize when using a foreign buffer; clone first");
188 if (newsize > capacity_) {
189 std::size_t newcapacity = capacity_;
190 if (newcapacity == 0) {
193 while (newcapacity <= newsize) {
196 Reserve(newcapacity);
// Resize(newsize, t): like Resize(newsize), but slots in [oldsize, newsize)
// are meant to receive t. The resize call and loop body are not visible here.
// Throws for foreign buffers.
201 template <
typename T>
203 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
204 if (!owned_buffer_) {
205 throw Error(
"Cannot resize when using a foreign buffer; clone first");
207 std::size_t oldsize = Size();
209 for (std::size_t i = oldsize; i < newsize; ++i) {
// Clear(): removes all elements. Throws for foreign buffers; the code after
// the ownership check is not visible here.
214 template <
typename T>
216 ContiguousArray<T>::Clear() {
217 if (!owned_buffer_) {
218 throw Error(
"Cannot clear when using a foreign buffer; clone first");
223 template <
typename T>
225 ContiguousArray<T>::PushBack(T t) {
226 if (!owned_buffer_) {
227 throw Error(
"Cannot add element when using a foreign buffer; clone first");
229 if (size_ == capacity_) {
230 Reserve(capacity_ * 2);
232 buffer_[size_++] = t;
// Extend(other): appends every element of `other` with a single memcpy
// (T assumed trivially copyable). Capacity grows the same way as in
// Resize(newsize). Throws for foreign buffers.
// NOTE(review): the lines between the ownership check and the size
// computation, and the size_ update after memcpy, are not visible. Confirm
// that an empty `other` on a null buffer is handled, since &buffer_[size_]
// would then be computed from a null pointer.
235 template <
typename T>
237 ContiguousArray<T>::Extend(
const std::vector<T>& other) {
238 if (!owned_buffer_) {
239 throw Error(
"Cannot add elements when using a foreign buffer; clone first");
244 std::size_t newsize = size_ + other.size();
245 if (newsize > capacity_) {
246 std::size_t newcapacity = capacity_;
247 if (newcapacity == 0) {
250 while (newcapacity <= newsize) {
253 Reserve(newcapacity);
255 std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()),
sizeof(T) * other.size());
// operator[]: element access by index. The body is not visible here;
// presumably unchecked, in contrast with at().
259 template <
typename T>
261 ContiguousArray<T>::operator[](std::size_t idx) {
// operator[] const: read-only element access by index. The body is not
// visible here; presumably unchecked.
265 template <
typename T>
267 ContiguousArray<T>::operator[](std::size_t idx)
const {
// at(idx): bounds-checked element access. Throws Error("nid out of range") for
// an invalid index (the condition is not visible; presumably idx >= Size()).
// NOTE(review): the message says "nid" although this container is generic.
271 template <
typename T>
273 ContiguousArray<T>::at(std::size_t idx) {
275 throw Error(
"nid out of range");
// at(idx) const: read-only bounds-checked access. Throws
// Error("nid out of range") for an invalid index (condition not visible).
280 template <
typename T>
282 ContiguousArray<T>::at(std::size_t idx)
const {
284 throw Error(
"nid out of range");
// at(int idx): bounds-checked access with a signed index. Negative indices are
// rejected before the std::size_t conversion, so they cannot wrap into a
// huge valid-looking index.
289 template <
typename T>
291 ContiguousArray<T>::at(
int idx) {
292 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
293 throw Error(
"nid out of range");
295 return buffer_[
static_cast<std::size_t
>(idx)];
// at(int idx) const: read-only counterpart of at(int). Rejects negative and
// >= Size() indices.
298 template <
typename T>
300 ContiguousArray<T>::at(
int idx)
const {
301 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
302 throw Error(
"nid out of range");
304 return buffer_[
static_cast<std::size_t
>(idx)];
// Applies recognized (key, value) string pairs from kwargs to this ModelParam:
//   pred_transform : copied with strncpy, truncated to
//                    TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1 characters and
//                    always NUL-terminated (strncpy alone would not guarantee it)
//   sigmoid_alpha, ratio_c, global_bias : parsed with std::stof, which throws
//                    std::invalid_argument / std::out_of_range on bad input
// Returns `unknowns`. The code that fills and returns it is not visible here
// (presumably it collects unrecognized pairs).
307 template<
typename Container>
308 inline std::vector<std::pair<std::string, std::string> >
309 ModelParam::InitAllowUnknown(
const Container& kwargs) {
310 std::vector<std::pair<std::string, std::string>> unknowns;
311 for (
const auto& e : kwargs) {
312 if (e.first ==
"pred_transform") {
313 std::strncpy(this->pred_transform, e.second.c_str(),
314 TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
315 this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] =
'\0';
316 }
else if (e.first ==
"sigmoid_alpha") {
317 this->sigmoid_alpha = std::stof(e.second,
nullptr);
318 }
else if (e.first ==
"ratio_c") {
319 this->ratio_c = std::stof(e.second,
nullptr);
320 }
else if (e.first ==
"global_bias") {
321 this->global_bias = std::stof(e.second,
nullptr);
// __DICT__(): exports the parameters as a string-to-string map. Float fields
// go through GetString, which uses max_digits10 precision for float/double.
// The return statement is not visible here (presumably `return ret;`).
327 inline std::map<std::string, std::string>
328 ModelParam::__DICT__()
const {
329 std::map<std::string, std::string> ret;
330 ret.emplace(
"pred_transform", std::string(this->pred_transform));
331 ret.emplace(
"sigmoid_alpha", GetString(this->sigmoid_alpha));
332 ret.emplace(
"ratio_c", GetString(this->ratio_c));
333 ret.emplace(
"global_bias", GetString(this->global_bias));
337 inline PyBufferFrame GetPyBufferFromArray(
void* data,
const char* format,
338 std::size_t itemsize, std::size_t nitem) {
339 return PyBufferFrame{data,
const_cast<char*
>(format), itemsize, nitem};
// InferFormatString<T>(): Python `struct`-module format string for an
// arithmetic type T, using the '=' (native order, standard size) prefix.
// Codes returned: "=B"/"=b", "=H"/"=h", "=L"/"=l" for integral T,
// "=Q"/"=q" for integral T; unsigned/signed is chosen by std::is_unsigned.
// For the two larger size groups, a non-integral type that is also not
// floating point throws "Could not infer format string". Anything unmatched
// throws "Unrecognized type".
// NOTE(review): the sizeof(T) dispatch and the floating-point returns are not
// visible here. The size grouping (1/2/4/8 bytes) is inferred from the struct
// codes only.
343 template <
typename T>
344 inline const char* InferFormatString() {
347 return (std::is_unsigned<T>::value ?
"=B" :
"=b");
349 return (std::is_unsigned<T>::value ?
"=H" :
"=h");
351 if (std::is_integral<T>::value) {
352 return (std::is_unsigned<T>::value ?
"=L" :
"=l");
354 if (!std::is_floating_point<T>::value) {
355 throw Error(
"Could not infer format string");
360 if (std::is_integral<T>::value) {
361 return (std::is_unsigned<T>::value ?
"=Q" :
"=q");
363 if (!std::is_floating_point<T>::value) {
364 throw Error(
"Could not infer format string");
369 throw Error(
"Unrecognized type");
374 template <
typename T>
375 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec,
const char* format) {
376 return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format,
sizeof(T), vec->Size());
379 template <
typename T>
380 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
381 static_assert(std::is_arithmetic<T>::value,
382 "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
383 return GetPyBufferFromArray(vec, InferFormatString<T>());
386 inline PyBufferFrame GetPyBufferFromScalar(
void* data,
const char* format, std::size_t itemsize) {
387 return GetPyBufferFromArray(data, format, itemsize, 1);
390 template <
typename T>
391 inline PyBufferFrame GetPyBufferFromScalar(T* scalar,
const char* format) {
392 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
393 return GetPyBufferFromScalar(static_cast<void*>(scalar), format,
sizeof(T));
396 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
397 using T = std::underlying_type<TypeInfo>::type;
398 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
401 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
402 using T = std::underlying_type<TaskType>::type;
403 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
406 template <
typename T>
407 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
408 static_assert(std::is_arithmetic<T>::value,
409 "Use GetPyBufferFromScalar(scalar, format) for composite types; " 410 "specify format string manually");
411 return GetPyBufferFromScalar(scalar, InferFormatString<T>());
414 template <
typename T>
415 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
416 if (
sizeof(T) != frame.itemsize) {
417 throw Error(
"Incorrect itemsize");
419 vec->UseForeignBuffer(frame.buf, frame.nitem);
422 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
423 using T = std::underlying_type<TypeInfo>::type;
424 if (
sizeof(T) != buffer.itemsize) {
425 throw Error(
"Incorrect itemsize");
427 if (buffer.nitem != 1) {
428 throw Error(
"nitem must be 1 for a scalar");
430 T* t =
static_cast<T*
>(buffer.buf);
431 *scalar =
static_cast<TypeInfo>(*t);
434 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
435 using T = std::underlying_type<TaskType>::type;
436 if (
sizeof(T) != buffer.itemsize) {
437 throw Error(
"Incorrect itemsize");
439 if (buffer.nitem != 1) {
440 throw Error(
"nitem must be 1 for a scalar");
442 T* t =
static_cast<T*
>(buffer.buf);
443 *scalar =
static_cast<TaskType>(*t);
// Generic scalar load from a PyBufferFrame. T must be standard-layout, the
// frame's item size must equal sizeof(T), and the frame must hold exactly one
// item. The buffer is then reinterpreted as T*.
// NOTE(review): the assignment into *scalar is not visible here (presumably
// `*scalar = *t;`).
446 template <
typename T>
447 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
448 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
449 if (
sizeof(T) != buffer.itemsize) {
450 throw Error(
"Incorrect itemsize");
452 if (buffer.nitem != 1) {
453 throw Error(
"nitem must be 1 for a scalar");
455 T* t =
static_cast<T*
>(buffer.buf);
459 template <
typename T>
460 inline void ReadScalarFromFile(T* scalar, FILE* fp) {
461 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
462 if (std::fread(scalar,
sizeof(T), 1, fp) < 1) {
463 throw Error(
"Could not read a scalar");
467 template <
typename T>
468 inline void WriteScalarToFile(T* scalar, FILE* fp) {
469 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
470 if (std::fwrite(scalar,
sizeof(T), 1, fp) < 1) {
471 throw Error(
"Could not write a scalar");
// Reads an element count followed by that many raw T items into vec. This is
// the counterpart of WriteArrayToFile, which writes the count as uint64_t.
// NOTE(review): the declaration of nelem and the code that sizes vec before
// the second fread are not visible here. vec->Data() must have room for
// nelem_size_t items when that fread runs.
475 template <
typename T>
476 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
478 if (std::fread(&nelem,
sizeof(nelem), 1, fp) < 1) {
479 throw Error(
"Could not read the number of elements");
486 const auto nelem_size_t =
static_cast<std::size_t
>(nelem);
487 if (std::fread(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
488 throw Error(
"Could not read an array");
// Writes vec as a uint64_t element count followed by the raw items. The
// static_assert guarantees the size_t count fits in uint64_t.
// NOTE(review): the lines between the two writes are not visible (possibly an
// early return for an empty array).
492 template <
typename T>
493 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
494 static_assert(
sizeof(uint64_t) >=
sizeof(
size_t),
"size_t too large");
495 const auto nelem =
static_cast<uint64_t
>(vec->Size());
496 if (std::fwrite(&nelem,
sizeof(nelem), 1, fp) < 1) {
497 throw Error(
"Could not write the number of elements");
502 const auto nelem_size_t = vec->Size();
503 if (std::fwrite(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
504 throw Error(
"Could not write an array");
// Skips one optional field in a file stream. It reads the element size and
// element count, then seeks forward past elem_size * nelem bytes. Throws if
// the seek fails.
// NOTE(review): elem_size * nelem is not checked for overflow. Also,
// TREELITE_CHECK_LE compares a uint64_t against a signed long, which triggers
// a sign-compare warning. Casting the limit to uint64_t would be cleaner.
// (The declarations of elem_size/nelem are not visible here.)
508 inline void SkipOptFieldInFile(FILE* fp) {
511 ReadScalarFromFile(&elem_size, fp);
512 ReadScalarFromFile(&nelem, fp);
514 const uint64_t nbytes = elem_size * nelem;
515 TREELITE_CHECK_LE(nbytes, std::numeric_limits<long>::max());
516 if (std::fseek(fp, static_cast<long>(nbytes), SEEK_CUR) != 0) {
517 throw Error(
"Reached end of file");
// Tree constructor: records whether optional fields are expected when this
// tree is deserialized (read by DeserializeTemplate). The rest of the
// initializer list and body are not visible here.
521 template <
typename ThresholdType,
typename LeafOutputType>
522 Tree<ThresholdType, LeafOutputType>::Tree(
bool use_opt_field)
523 : use_opt_field_(use_opt_field)
// Deep copy of the tree: copies num_nodes and Clone()s every ContiguousArray
// member, so the copy owns its storage even if this tree views foreign
// buffers.
// NOTE(review): copying of has_categorical_split_ / use_opt_field_ /
// optional-field counts, and the return statement, are not visible here.
526 template <
typename ThresholdType,
typename LeafOutputType>
527 inline Tree<ThresholdType, LeafOutputType>
528 Tree<ThresholdType, LeafOutputType>::Clone()
const {
529 Tree<ThresholdType, LeafOutputType> tree;
530 tree.num_nodes = num_nodes;
531 tree.nodes_ = nodes_.Clone();
532 tree.leaf_vector_ = leaf_vector_.Clone();
533 tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
534 tree.leaf_vector_end_ = leaf_vector_end_.Clone();
535 tree.matching_categories_ = matching_categories_.Clone();
536 tree.matching_categories_offset_ = matching_categories_offset_.Clone();
// Python `struct` format string describing one Node's memory layout, used when
// exporting nodes_ through the buffer protocol.
// - float thresholds: "=f" in the threshold slot.
// - other thresholds (presumably double; the else is not visible): "xxxx=d",
//   i.e. 4 padding bytes and then an 8-byte double.
// These strings must stay in sync with Node's field order and padding.
540 template <
typename ThresholdType,
typename LeafOutputType>
542 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
543 if (std::is_same<ThresholdType, float>::value) {
544 return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
546 return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
// Visits the tree's fields in a fixed order, shared by GetPyBuffer and
// SerializeToFile:
//   num_nodes, has_categorical_split_,
//   nodes_ (composite; explicit format string),
//   leaf_vector_, leaf_vector_begin_, leaf_vector_end_,
//   matching_categories_, matching_categories_offset_,
//   num_opt_field_per_tree_ and num_opt_field_per_node_ (both forced to 0).
// This order must match DeserializeTemplate.
550 template <
typename ThresholdType,
typename LeafOutputType>
551 template <
typename ScalarHandler,
typename PrimitiveArrayHandler,
typename CompositeArrayHandler>
553 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
554 ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
555 CompositeArrayHandler composite_array_handler) {
556 scalar_handler(&num_nodes);
557 scalar_handler(&has_categorical_split_);
558 composite_array_handler(&nodes_, GetFormatStringForNode());
559 primitive_array_handler(&leaf_vector_);
560 primitive_array_handler(&leaf_vector_begin_);
561 primitive_array_handler(&leaf_vector_end_);
562 primitive_array_handler(&matching_categories_);
563 primitive_array_handler(&matching_categories_offset_);
566 num_opt_field_per_tree_ = 0;
567 scalar_handler(&num_opt_field_per_tree_);
570 num_opt_field_per_node_ = 0;
571 scalar_handler(&num_opt_field_per_node_);
// Mirror of SerializeTemplate: loads fields in the same order and checks that
// the number of nodes loaded matches num_nodes.
// When use_opt_field_ is set, it reads the per-tree and per-node optional
// field counts and skips each such field via skip_opt_field_handler.
// Otherwise (the else branch is only partially visible) both counts are reset
// to 0.
574 template <
typename ThresholdType,
typename LeafOutputType>
575 template <
typename ScalarHandler,
typename ArrayHandler,
typename SkipOptFieldHandlerFunc>
577 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
578 ScalarHandler scalar_handler,
579 ArrayHandler array_handler,
580 SkipOptFieldHandlerFunc skip_opt_field_handler) {
581 scalar_handler(&num_nodes);
582 scalar_handler(&has_categorical_split_);
583 array_handler(&nodes_);
584 if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
585 throw Error(
"Could not load the correct number of nodes");
587 array_handler(&leaf_vector_);
588 array_handler(&leaf_vector_begin_);
589 array_handler(&leaf_vector_end_);
590 array_handler(&matching_categories_);
591 array_handler(&matching_categories_offset_);
593 if (use_opt_field_) {
595 scalar_handler(&num_opt_field_per_tree_);
597 for (int32_t i = 0; i < num_opt_field_per_tree_; ++i) {
598 skip_opt_field_handler();
602 scalar_handler(&num_opt_field_per_node_);
604 for (int32_t i = 0; i < num_opt_field_per_node_; ++i) {
605 skip_opt_field_handler();
609 num_opt_field_per_tree_ = 0;
610 num_opt_field_per_node_ = 0;
// Appends one PyBufferFrame per serialized field to *dest, in SerializeTemplate
// order. The frames point into this tree's memory; nothing is copied.
614 template <
typename ThresholdType,
typename LeafOutputType>
616 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
617 auto scalar_handler = [dest](
auto* field) {
618 dest->push_back(GetPyBufferFromScalar(field));
620 auto primitive_array_handler = [dest](
auto* field) {
621 dest->push_back(GetPyBufferFromArray(field));
623 auto composite_array_handler = [dest](
auto* field,
const char* format) {
624 dest->push_back(GetPyBufferFromArray(field, format));
626 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
// Writes the tree to dest_fp in SerializeTemplate order. Composite arrays are
// written as raw bytes like primitive ones; the format argument is accepted
// only to fit the handler signature and is unused.
629 template <
typename ThresholdType,
typename LeafOutputType>
631 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
632 auto scalar_handler = [dest_fp](
auto* field) {
633 WriteScalarToFile(field, dest_fp);
635 auto primitive_array_handler = [dest_fp](
auto* field) {
636 WriteArrayToFile(field, dest_fp);
638 auto composite_array_handler = [dest_fp](
auto* field,
const char* format) {
639 WriteArrayToFile(field, dest_fp);
641 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
// Loads the tree from consecutive frames starting at `it`, consuming one frame
// per field. Arrays become non-owning views onto the frames' memory
// (InitArrayFromPyBuffer uses UseForeignBuffer), so that memory must outlive
// the tree.
// NOTE(review): the skip handler's body and the return statement (presumably
// new_it) are not visible here.
644 template <
typename ThresholdType,
typename LeafOutputType>
645 inline std::vector<PyBufferFrame>::iterator
646 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it) {
647 std::vector<PyBufferFrame>::iterator new_it = it;
648 auto scalar_handler = [&new_it](
auto* field) {
649 InitScalarFromPyBuffer(field, *(new_it++));
651 auto array_handler = [&new_it](
auto* field) {
652 InitArrayFromPyBuffer(field, *(new_it++));
654 auto skip_opt_field_handler = [&new_it]() {
657 DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
// Reads the tree from src_fp in DeserializeTemplate order. Optional fields are
// skipped by seeking past them (SkipOptFieldInFile).
661 template <
typename ThresholdType,
typename LeafOutputType>
663 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
664 auto scalar_handler = [src_fp](
auto* field) {
665 ReadScalarFromFile(field, src_fp);
667 auto array_handler = [src_fp](
auto* field) {
668 ReadArrayFromFile(field, src_fp);
670 auto skip_opt_field_handler = [src_fp]() {
671 SkipOptFieldInFile(src_fp);
673 DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
// Node reset (the function header is not visible; presumably Node::Init()).
// - Zeroes the whole object with memset, so Node is assumed to be trivially
//   copyable.
// - Marks both children as -1.
// - Sets the leaf value and threshold to 0.
// - Sets sum_hess_/gain_ to 0 and clears the presence flags.
// - Sets the split type and comparison operator to kNone.
676 template <
typename ThresholdType,
typename LeafOutputType>
678 std::memset(
this, 0,
sizeof(
Node));
679 cleft_ = cright_ = -1;
681 info_.leaf_value =
static_cast<LeafOutputType
>(0);
682 info_.threshold =
static_cast<ThresholdType
>(0);
684 sum_hess_ = gain_ = 0.0;
685 data_count_present_ = sum_hess_present_ = gain_present_ =
false;
686 categories_list_right_child_ =
false;
687 split_type_ = SplitFeatureType::kNone;
688 cmp_ = Operator::kNone;
// Appends one node (id nd == old num_nodes) initialized by Node::Init, plus
// matching bookkeeping entries:
// - zero leaf-vector bounds;
// - a copy of the last categories offset, so the node starts with an empty
//   category list.
// PushBack takes its argument by value, so Back() is copied before any
// reallocation. Throws if nodes_ and num_nodes were already out of sync.
// The return statement is not visible (presumably returns nd).
691 template <
typename ThresholdType,
typename LeafOutputType>
694 int nd = num_nodes++;
695 if (nodes_.Size() !=
static_cast<std::size_t
>(nd)) {
696 throw Error(
"Invariant violated: nodes_ contains incorrect number of nodes");
698 for (
int nid = nd; nid < num_nodes; ++nid) {
699 leaf_vector_begin_.PushBack(0);
700 leaf_vector_end_.PushBack(0);
701 matching_categories_offset_.PushBack(matching_categories_offset_.Back());
702 nodes_.Resize(nodes_.Size() + 1);
703 nodes_.Back().Init();
// Initializes the tree with a single root node: clears the categorical flag
// and data, resizes the per-node bookkeeping arrays for one node (two offsets
// of 0), and makes node 0 a leaf with value 0.
// The lines that reset num_nodes / nodes_ are not visible here.
708 template <
typename ThresholdType,
typename LeafOutputType>
712 has_categorical_split_ =
false;
713 leaf_vector_.Clear();
714 leaf_vector_begin_.Resize(1, {});
715 leaf_vector_end_.Resize(1, {});
716 matching_categories_.Clear();
717 matching_categories_offset_.Resize(2, 0);
720 SetLeaf(0, static_cast<LeafOutputType>(0));
// Allocates two new nodes and links them as the left/right children of nid.
// nodes_.at(nid) is looked up again after both allocations rather than held as
// a reference, because AllocNode may reallocate nodes_.
723 template <
typename ThresholdType,
typename LeafOutputType>
726 const int cleft = this->AllocNode();
727 const int cright = this->AllocNode();
728 nodes_.at(nid).cleft_ = cleft;
729 nodes_.at(nid).cright_ = cright;
// Makes nid a numerical split.
// - sindex_ holds the feature index, with bit 31 set when missing values go
//   left.
// - The threshold is stored in info_.
// - The split type becomes kNumerical. (Storing `cmp` happens on a line not
//   visible here.)
// Throws if split_index cannot leave bit 31 free.
// NOTE(review): the bound `>= (1U << 31U) - 1` also rejects 2^31 - 1, which
// still fits in 31 bits. Confirm whether that value is intentionally reserved.
732 template <
typename ThresholdType,
typename LeafOutputType>
735 int nid,
unsigned split_index, ThresholdType threshold,
bool default_left,
Operator cmp) {
736 Node& node = nodes_.at(nid);
737 if (split_index >= ((1U << 31U) - 1)) {
738 throw Error(
"split_index too big");
740 if (default_left) split_index |= (1U << 31U);
741 node.sindex_ = split_index;
742 (node.info_).threshold = threshold;
744 node.split_type_ = SplitFeatureType::kNumerical;
745 node.categories_list_right_child_ =
false;
// Makes nid a categorical split.
// - categories_list is appended to matching_categories_.
// - Offsets nid+1 .. end move to the new end, so the new segment belongs to
//   nid.
// - The new segment is sorted.
// - sindex_ gets the default-left bit 31; the split type and
//   categories_list_right_child_ are set.
// - has_categorical_split_ is marked.
// Invariant checked up front: nid's segment must currently be the last one,
// i.e. every offset after nid equals the current end.
// NOTE(review): if categories_list is empty while matching_categories_ already
// has entries, end_oft == Size() and `at(end_oft)` presumably throws
// "nid out of range". Guarding the sort on `new_end_oft > end_oft` would
// avoid that.
748 template <
typename ThresholdType,
typename LeafOutputType>
751 int nid,
unsigned split_index,
bool default_left,
752 const std::vector<uint32_t>& categories_list,
bool categories_list_right_child) {
753 if (split_index >= ((1U << 31U) - 1)) {
754 throw Error(
"split_index too big");
757 const std::size_t end_oft = matching_categories_offset_.Back();
758 const std::size_t new_end_oft = end_oft + categories_list.size();
759 if (end_oft != matching_categories_.Size()) {
760 throw Error(
"Invariant violated");
762 if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
763 [end_oft](std::size_t x) {
return (x == end_oft); })) {
764 throw Error(
"Invariant violated");
767 matching_categories_.Extend(categories_list);
768 if (new_end_oft != matching_categories_.Size()) {
769 throw Error(
"Invariant violated");
771 std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
772 [new_end_oft](std::size_t& x) { x = new_end_oft; });
773 if (!matching_categories_.Empty()) {
774 std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
777 Node& node = nodes_.at(nid);
778 if (default_left) split_index |= (1U << 31U);
779 node.sindex_ = split_index;
780 node.split_type_ = SplitFeatureType::kCategorical;
781 node.categories_list_right_child_ = categories_list_right_child;
783 has_categorical_split_ =
true;
// Makes nid a leaf holding a scalar output value and sets its split type to
// kNone. The signature and the remaining lines are not visible here.
786 template <
typename ThresholdType,
typename LeafOutputType>
789 Node& node = nodes_.at(nid);
790 (node.info_).leaf_value = value;
793 node.split_type_ = SplitFeatureType::kNone;
// Makes nid a leaf whose output is a vector. node_leaf_vector is appended to
// leaf_vector_, and its [begin, end) range is recorded for nid.
// NOTE(review): leaf_vector_ is extended and leaf_vector_begin_/end_ are
// written with operator[] (presumably unchecked) BEFORE the checked
// nodes_.at(nid). An out-of-range nid would therefore corrupt memory before
// any error is raised. Consider validating nid first.
796 template <
typename ThresholdType,
typename LeafOutputType>
799 int nid,
const std::vector<LeafOutputType>& node_leaf_vector) {
800 std::size_t begin = leaf_vector_.Size();
801 std::size_t end = begin + node_leaf_vector.size();
802 leaf_vector_.Extend(node_leaf_vector);
803 leaf_vector_begin_[nid] = begin;
804 leaf_vector_end_[nid] = end;
805 Node &node = nodes_.at(nid);
808 node.split_type_ = SplitFeatureType::kNone;
// Statically typed factory: creates an empty ModelImpl<ThresholdType,
// LeafOutputType> and records both type tags (TypeToInfo) on the model so
// that the types can be recovered at runtime.
811 template <
typename ThresholdType,
typename LeafOutputType>
812 inline std::unique_ptr<Model>
814 std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
815 model->threshold_type_ = TypeToInfo<ThresholdType>();
816 model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
// Dispatch target for the runtime-typed Model::Create: forwards to the
// statically typed factory.
820 template <
typename ThresholdType,
typename LeafOutputType>
823 inline static std::unique_ptr<Model> Dispatch() {
824 return Model::Create<ThresholdType, LeafOutputType>();
// Runtime-typed factory: maps (threshold_type, leaf_output_type) to a concrete
// ModelImpl through DispatchWithModelTypes / ModelCreateImpl.
828 inline std::unique_ptr<Model>
830 return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
// Dispatch target for Model::Dispatch, with mutable and const overloads. The
// bodies are not visible here; presumably they cast `model` to the concrete
// ModelImpl<ThresholdType, LeafOutputType> and invoke func on it.
833 template <
typename ThresholdType,
typename LeafOutputType>
836 template <
typename Func>
837 inline static auto Dispatch(
Model* model, Func func) {
841 template <
typename Func>
842 inline static auto Dispatch(
const Model* model, Func func) {
// Invokes func with this model's concrete type, selected at runtime from
// threshold_type_ / leaf_output_type_.
847 template <
typename Func>
849 Model::Dispatch(Func func) {
850 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
// Const counterpart of Dispatch: passes a const Model* to ModelDispatchImpl.
853 template <
typename Func>
855 Model::Dispatch(Func func)
const {
856 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
// Emits the model-level header. It first stamps the running Treelite version
// into major/minor/patch_ver_, then visits the version fields and the two
// type tags in that order.
859 template <
typename HeaderPrimitiveFieldHandlerFunc>
861 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
862 major_ver_ = TREELITE_VER_MAJOR;
863 minor_ver_ = TREELITE_VER_MINOR;
864 patch_ver_ = TREELITE_VER_PATCH;
865 header_primitive_field_handler(&major_ver_);
866 header_primitive_field_handler(&minor_ver_);
867 header_primitive_field_handler(&patch_ver_);
868 header_primitive_field_handler(&threshold_type_);
869 header_primitive_field_handler(&leaf_output_type_);
// Reads the model-level header into the caller's out-parameters.
// A checkpoint is accepted when its major version equals the running major
// version, or when it is exactly version 2.4.x (any patch). Otherwise it
// throws with both versions in the message. The two type tags are read last.
872 template <
typename HeaderPrimitiveFieldHandlerFunc>
874 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
875 int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
877 header_primitive_field_handler(&major_ver);
878 header_primitive_field_handler(&minor_ver);
879 header_primitive_field_handler(&patch_ver);
880 if (major_ver != TREELITE_VER_MAJOR && !(major_ver == 2 && minor_ver == 4)) {
881 std::ostringstream oss;
882 oss <<
"Cannot deserialize model from a different major Treelite version or " 883 <<
"a version before 2.4.0." << std::endl
884 <<
"Currently running Treelite version " << TREELITE_VER_MAJOR <<
"." 885 << TREELITE_VER_MINOR <<
"." << TREELITE_VER_PATCH << std::endl
886 <<
"The model checkpoint was generated from Treelite version " << major_ver <<
"." 887 << minor_ver <<
"." << patch_ver;
888 throw Error(oss.str());
890 header_primitive_field_handler(&threshold_type);
891 header_primitive_field_handler(&leaf_output_type);
// Emits the typed model header in this order:
// - num_feature, task_type, average_tree_output;
// - task_param and param as composite structs with explicit struct-format
//   strings (param's string length comes from
//   TREELITE_MAX_PRED_TRANSFORM_LENGTH);
// - num_opt_field_per_model_, forced to 0.
// Trees are emitted via tree_handler (loop not visible here).
894 template <
typename ThresholdType,
typename LeafOutputType>
895 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
896 typename TreeHandlerFunc>
899 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
900 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
901 TreeHandlerFunc tree_handler) {
903 header_primitive_field_handler(&num_feature);
904 header_primitive_field_handler(&task_type);
905 header_primitive_field_handler(&average_tree_output);
906 header_composite_field_handler(&task_param,
"T{=B=?xx=I=I}");
907 header_composite_field_handler(
908 ¶m,
"T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH)
"s=f=f=f}");
911 num_opt_field_per_model_ = 0;
912 header_primitive_field_handler(&num_opt_field_per_model_);
// Reads the typed model header in SerializeTemplate order, then builds
// num_tree trees and loads each through tree_handler.
// Optional fields are only expected for checkpoints with major version >= 3.
// In that case the model-level optional fields are skipped; otherwise the
// count is reset to 0. Each tree is constructed with the same use_opt_field
// flag.
920 template <
typename ThresholdType,
typename LeafOutputType>
921 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc,
922 typename SkipOptFieldHandlerFunc>
925 std::size_t num_tree,
926 HeaderFieldHandlerFunc header_field_handler,
927 TreeHandlerFunc tree_handler,
928 SkipOptFieldHandlerFunc skip_opt_field_handler) {
930 header_field_handler(&num_feature);
931 header_field_handler(&task_type);
932 header_field_handler(&average_tree_output);
933 header_field_handler(&task_param);
934 header_field_handler(¶m);
937 const bool use_opt_field = (major_ver_ >= 3);
939 header_field_handler(&num_opt_field_per_model_);
941 for (int32_t i = 0; i < num_opt_field_per_model_; ++i) {
942 skip_opt_field_handler();
946 num_opt_field_per_model_ = 0;
951 for (std::size_t i = 0; i < num_tree; ++i) {
952 trees.emplace_back(use_opt_field);
953 tree_handler(trees.back());
// Exports the model as PyBufferFrames appended to *dest: num_tree_ first, then
// the typed header, then each tree's frames (via Tree::GetPyBuffer). Header
// structs are exported as single-item buffers with an explicit format.
957 template <
typename ThresholdType,
typename LeafOutputType>
960 num_tree_ =
static_cast<uint64_t
>(this->trees.size());
961 auto header_primitive_field_handler = [dest](
auto* field) {
962 dest->push_back(GetPyBufferFromScalar(field));
964 auto header_composite_field_handler = [dest](
auto* field,
const char* format) {
965 dest->push_back(GetPyBufferFromScalar(field, format));
968 tree.GetPyBuffer(dest);
970 header_primitive_field_handler(&num_tree_);
971 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
// Writes the model to dest_fp: num_tree_ first, then the typed header, then
// each tree (Tree::SerializeToFile). Composite header structs are written as
// raw bytes; the format argument is unused.
974 template <
typename ThresholdType,
typename LeafOutputType>
977 num_tree_ =
static_cast<uint64_t
>(this->trees.size());
978 auto header_primitive_field_handler = [dest_fp](
auto* field) {
979 WriteScalarToFile(field, dest_fp);
981 auto header_composite_field_handler = [dest_fp](
auto* field,
const char* format) {
982 WriteScalarToFile(field, dest_fp);
985 tree.SerializeToFile(dest_fp);
987 header_primitive_field_handler(&num_tree_);
988 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
// Loads the model from frames starting at `it`, then checks that the number
// of trees built equals num_tree_.
// - major_ver_ == 2 checkpoints do not store the tree count. It is derived as
//   (num_frame - 5 header frames) / 8 frames per tree.
// - Newer checkpoints store the count in a frame.
// NOTE(review): num_frame - kNumFrameInHeader underflows (std::size_t) if
// fewer than 5 frames are supplied. The skip handler's body is not visible.
991 template <
typename ThresholdType,
typename LeafOutputType>
992 inline std::vector<PyBufferFrame>::iterator
994 std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) {
995 std::vector<PyBufferFrame>::iterator new_it = it;
996 auto header_field_handler = [&new_it](
auto* field) {
997 InitScalarFromPyBuffer(field, *(new_it++));
1000 auto skip_opt_field_handler = [&new_it]() {
1005 new_it = tree.InitFromPyBuffer(new_it);
1008 if (major_ver_ == 2) {
1011 constexpr std::size_t kNumFrameInHeader = 5;
1012 constexpr std::size_t kNumFramePerTree = 8;
1013 num_tree_ = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
1017 header_field_handler(&num_tree_);
1020 DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1021 TREELITE_CHECK_EQ(num_tree_, this->trees.size());
// Reads the model from src_fp: the tree count first, then the typed header and
// trees (Tree::DeserializeFromFile). Optional fields are skipped by seeking.
// Finally it verifies that the number of trees built matches num_tree_.
1026 template <
typename ThresholdType,
typename LeafOutputType>
1029 ReadScalarFromFile(&num_tree_, src_fp);
1031 auto header_field_handler = [src_fp](
auto* field) {
1032 ReadScalarFromFile(field, src_fp);
1035 auto skip_opt_field_handler = [src_fp]() {
1036 SkipOptFieldInFile(src_fp);
1040 tree.DeserializeFromFile(src_fp);
1043 DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1044 TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1047 inline void InitParamAndCheck(
ModelParam* param,
1048 const std::vector<std::pair<std::string, std::string>>& cfg) {
1049 auto unknown = param->InitAllowUnknown(cfg);
1050 if (!unknown.empty()) {
1051 std::ostringstream oss;
1052 for (
const auto& kv : unknown) {
1053 oss << kv.first <<
", ";
1055 std::cerr <<
"\033[1;31mWarning: Unknown parameters found; " 1056 <<
"they have been ignored\u001B[0m: " << oss.str() << std::endl;
1061 #endif // TREELITE_TREE_IMPL_H_ void Init()
initialize the model with a single root node
Exception class that will be thrown by Treelite.
TaskType
Enum type representing the task type.
in-memory representation of a decision tree
logging facility for Treelite
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
thin wrapper for tree ensemble model
Operator
comparison operators