7 #ifndef TREELITE_TREE_IMPL_H_ 8 #define TREELITE_TREE_IMPL_H_ 17 #include <unordered_map> 28 inline std::string GetString(T x) {
29 return std::to_string(x);
33 inline std::string GetString<float>(
float x) {
34 std::ostringstream oss;
35 oss << std::setprecision(std::numeric_limits<float>::max_digits10) << x;
40 inline std::string GetString<double>(
double x) {
41 std::ostringstream oss;
42 oss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
51 ContiguousArray<T>::ContiguousArray()
52 : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
55 ContiguousArray<T>::~ContiguousArray() {
56 if (buffer_ && owned_buffer_) {
62 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
63 : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
64 owned_buffer_(other.owned_buffer_) {
65 other.buffer_ =
nullptr;
66 other.size_ = other.capacity_ = 0;
71 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
72 if (buffer_ && owned_buffer_) {
75 buffer_ = other.buffer_;
77 capacity_ = other.capacity_;
78 owned_buffer_ = other.owned_buffer_;
79 other.buffer_ =
nullptr;
80 other.size_ = other.capacity_ = 0;
85 inline ContiguousArray<T>
86 ContiguousArray<T>::Clone()
const {
87 ContiguousArray clone;
88 clone.buffer_ =
static_cast<T*
>(std::malloc(
sizeof(T) * capacity_));
90 throw std::runtime_error(
"Could not allocate memory for the clone");
92 std::memcpy(clone.buffer_, buffer_,
sizeof(T) * size_);
94 clone.capacity_ = capacity_;
95 clone.owned_buffer_ =
true;
101 ContiguousArray<T>::UseForeignBuffer(
void* prealloc_buf, std::size_t size) {
102 if (buffer_ && owned_buffer_) {
105 buffer_ =
static_cast<T*
>(prealloc_buf);
108 owned_buffer_ =
false;
111 template <
typename T>
113 ContiguousArray<T>::Data() {
117 template <
typename T>
119 ContiguousArray<T>::Data()
const {
123 template <
typename T>
125 ContiguousArray<T>::End() {
126 return &buffer_[Size()];
129 template <
typename T>
131 ContiguousArray<T>::End()
const {
132 return &buffer_[Size()];
135 template <
typename T>
137 ContiguousArray<T>::Back() {
138 return buffer_[Size() - 1];
141 template <
typename T>
143 ContiguousArray<T>::Back()
const {
144 return buffer_[Size() - 1];
147 template <
typename T>
149 ContiguousArray<T>::Size()
const {
153 template <
typename T>
155 ContiguousArray<T>::Empty()
const {
156 return (Size() == 0);
159 template <
typename T>
161 ContiguousArray<T>::Reserve(std::size_t newsize) {
162 if (!owned_buffer_) {
163 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
165 T* newbuf =
static_cast<T*
>(std::realloc(static_cast<void*>(buffer_),
sizeof(T) * newsize));
167 throw std::runtime_error(
"Could not expand buffer");
173 template <
typename T>
175 ContiguousArray<T>::Resize(std::size_t newsize) {
176 if (!owned_buffer_) {
177 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
179 if (newsize > capacity_) {
180 std::size_t newcapacity = capacity_;
181 if (newcapacity == 0) {
184 while (newcapacity <= newsize) {
187 Reserve(newcapacity);
192 template <
typename T>
194 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
195 if (!owned_buffer_) {
196 throw std::runtime_error(
"Cannot resize when using a foreign buffer; clone first");
198 std::size_t oldsize = Size();
200 for (std::size_t i = oldsize; i < newsize; ++i) {
205 template <
typename T>
207 ContiguousArray<T>::Clear() {
208 if (!owned_buffer_) {
209 throw std::runtime_error(
"Cannot clear when using a foreign buffer; clone first");
214 template <
typename T>
216 ContiguousArray<T>::PushBack(T t) {
217 if (!owned_buffer_) {
218 throw std::runtime_error(
"Cannot add element when using a foreign buffer; clone first");
220 if (size_ == capacity_) {
221 Reserve(capacity_ * 2);
223 buffer_[size_++] = t;
226 template <
typename T>
228 ContiguousArray<T>::Extend(
const std::vector<T>& other) {
229 if (!owned_buffer_) {
230 throw std::runtime_error(
"Cannot add elements when using a foreign buffer; clone first");
235 std::size_t newsize = size_ + other.size();
236 if (newsize > capacity_) {
237 std::size_t newcapacity = capacity_;
238 if (newcapacity == 0) {
241 while (newcapacity <= newsize) {
244 Reserve(newcapacity);
246 std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()),
sizeof(T) * other.size());
250 template <
typename T>
252 ContiguousArray<T>::operator[](std::size_t idx) {
256 template <
typename T>
258 ContiguousArray<T>::operator[](std::size_t idx)
const {
262 template <
typename T>
264 ContiguousArray<T>::at(std::size_t idx) {
266 throw std::runtime_error(
"nid out of range");
271 template <
typename T>
273 ContiguousArray<T>::at(std::size_t idx)
const {
275 throw std::runtime_error(
"nid out of range");
280 template <
typename T>
282 ContiguousArray<T>::at(
int idx) {
283 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
284 throw std::runtime_error(
"nid out of range");
286 return buffer_[
static_cast<std::size_t
>(idx)];
289 template <
typename T>
291 ContiguousArray<T>::at(
int idx)
const {
292 if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
293 throw std::runtime_error(
"nid out of range");
295 return buffer_[
static_cast<std::size_t
>(idx)];
298 template<
typename Container>
299 inline std::vector<std::pair<std::string, std::string> >
300 ModelParam::InitAllowUnknown(
const Container& kwargs) {
301 std::vector<std::pair<std::string, std::string>> unknowns;
302 for (
const auto& e : kwargs) {
303 if (e.first ==
"pred_transform") {
304 std::strncpy(this->pred_transform, e.second.c_str(),
305 TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
306 this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] =
'\0';
307 }
else if (e.first ==
"sigmoid_alpha") {
308 this->sigmoid_alpha = std::stof(e.second,
nullptr);
309 }
else if (e.first ==
"ratio_c") {
310 this->ratio_c = std::stof(e.second,
nullptr);
311 }
else if (e.first ==
"global_bias") {
312 this->global_bias = std::stof(e.second,
nullptr);
318 inline std::map<std::string, std::string>
319 ModelParam::__DICT__()
const {
320 std::map<std::string, std::string> ret;
321 ret.emplace(
"pred_transform", std::string(this->pred_transform));
322 ret.emplace(
"sigmoid_alpha", GetString(this->sigmoid_alpha));
323 ret.emplace(
"ratio_c", GetString(this->ratio_c));
324 ret.emplace(
"global_bias", GetString(this->global_bias));
328 inline PyBufferFrame GetPyBufferFromArray(
void* data,
const char* format,
329 std::size_t itemsize, std::size_t nitem) {
330 return PyBufferFrame{data,
const_cast<char*
>(format), itemsize, nitem};
/*!
 * \brief Choose a format string for type T from its size and whether it is an
 *        integral, unsigned or floating-point type.
 *
 * 1-byte types map to "=B"/"=b" and 2-byte types to "=H"/"=h" (unsigned /
 * signed). 4-byte and 8-byte types map to "=L"/"=l" and "=Q"/"=q" when
 * integral, or to "=f" and "=d" when floating-point.
 *
 * \tparam T type to describe
 * \return format string literal
 * \throws std::runtime_error if T is 4 or 8 bytes wide but neither integral
 *         nor floating-point, or if its size is not 1, 2, 4 or 8 bytes
 */
template <typename T>
inline const char* InferFormatString() {
  constexpr std::size_t kWidth = sizeof(T);
  constexpr bool kIsUnsigned = std::is_unsigned<T>::value;
  constexpr bool kIsIntegral = std::is_integral<T>::value;
  constexpr bool kIsFloating = std::is_floating_point<T>::value;

  if (kWidth == 1) {
    return kIsUnsigned ? "=B" : "=b";
  }
  if (kWidth == 2) {
    return kIsUnsigned ? "=H" : "=h";
  }
  if (kWidth == 4 || kWidth == 8) {
    const bool is_wide = (kWidth == 8);
    if (kIsIntegral) {
      if (is_wide) {
        return kIsUnsigned ? "=Q" : "=q";
      }
      return kIsUnsigned ? "=L" : "=l";
    }
    if (!kIsFloating) {
      throw std::runtime_error("Could not infer format string");
    }
    return is_wide ? "=d" : "=f";
  }
  throw std::runtime_error("Unrecognized type");
}
365 template <
typename T>
366 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec,
const char* format) {
367 return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format,
sizeof(T), vec->Size());
370 template <
typename T>
371 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
372 static_assert(std::is_arithmetic<T>::value,
373 "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
374 return GetPyBufferFromArray(vec, InferFormatString<T>());
377 inline PyBufferFrame GetPyBufferFromScalar(
void* data,
const char* format, std::size_t itemsize) {
378 return GetPyBufferFromArray(data, format, itemsize, 1);
381 template <
typename T>
382 inline PyBufferFrame GetPyBufferFromScalar(T* scalar,
const char* format) {
383 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
384 return GetPyBufferFromScalar(static_cast<void*>(scalar), format,
sizeof(T));
387 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
388 using T = std::underlying_type<TypeInfo>::type;
389 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
392 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
393 using T = std::underlying_type<TaskType>::type;
394 return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
397 template <
typename T>
398 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
399 static_assert(std::is_arithmetic<T>::value,
400 "Use GetPyBufferFromScalar(scalar, format) for composite types; " 401 "specify format string manually");
402 return GetPyBufferFromScalar(scalar, InferFormatString<T>());
405 template <
typename T>
406 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
407 if (
sizeof(T) != frame.itemsize) {
408 throw std::runtime_error(
"Incorrect itemsize");
410 vec->UseForeignBuffer(frame.buf, frame.nitem);
413 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
414 using T = std::underlying_type<TypeInfo>::type;
415 if (
sizeof(T) != buffer.itemsize) {
416 throw std::runtime_error(
"Incorrect itemsize");
418 if (buffer.nitem != 1) {
419 throw std::runtime_error(
"nitem must be 1 for a scalar");
421 T* t =
static_cast<T*
>(buffer.buf);
422 *scalar =
static_cast<TypeInfo>(*t);
425 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
426 using T = std::underlying_type<TaskType>::type;
427 if (
sizeof(T) != buffer.itemsize) {
428 throw std::runtime_error(
"Incorrect itemsize");
430 if (buffer.nitem != 1) {
431 throw std::runtime_error(
"nitem must be 1 for a scalar");
433 T* t =
static_cast<T*
>(buffer.buf);
434 *scalar =
static_cast<TaskType>(*t);
437 template <
typename T>
438 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
439 static_assert(std::is_standard_layout<T>::value,
"T must be in the standard layout");
440 if (
sizeof(T) != buffer.itemsize) {
441 throw std::runtime_error(
"Incorrect itemsize");
443 if (buffer.nitem != 1) {
444 throw std::runtime_error(
"nitem must be 1 for a scalar");
446 T* t =
static_cast<T*
>(buffer.buf);
/*!
 * \brief Read the raw bytes of one standard-layout object from a file stream.
 * \param scalar destination; receives sizeof(T) bytes
 * \param fp stream to read from
 * \throws std::runtime_error if a complete object could not be read
 */
template <typename T>
inline void ReadScalarFromFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t objects_read = std::fread(scalar, sizeof(T), 1, fp);
  if (objects_read == 0) {
    throw std::runtime_error("Could not read a scalar");
  }
}
/*!
 * \brief Write the raw bytes of one standard-layout object to a file stream.
 * \param scalar object to write; sizeof(T) bytes are emitted
 * \param fp stream to write to
 * \throws std::runtime_error if the object could not be written
 */
template <typename T>
inline void WriteScalarToFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t objects_written = std::fwrite(scalar, sizeof(T), 1, fp);
  if (objects_written == 0) {
    throw std::runtime_error("Could not write a scalar");
  }
}
466 template <
typename T>
467 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
469 if (std::fread(&nelem,
sizeof(nelem), 1, fp) < 1) {
470 throw std::runtime_error(
"Could not read the number of elements");
477 const auto nelem_size_t =
static_cast<std::size_t
>(nelem);
478 if (std::fread(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
479 throw std::runtime_error(
"Could not read an array");
483 template <
typename T>
484 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
485 static_assert(
sizeof(uint64_t) >=
sizeof(
size_t),
"size_t too large");
486 const auto nelem =
static_cast<uint64_t
>(vec->Size());
487 if (std::fwrite(&nelem,
sizeof(nelem), 1, fp) < 1) {
488 throw std::runtime_error(
"Could not write the number of elements");
493 const auto nelem_size_t = vec->Size();
494 if (std::fwrite(vec->Data(),
sizeof(T), nelem_size_t, fp) < nelem_size_t) {
495 throw std::runtime_error(
"Could not write an array");
499 template <
typename ThresholdType,
typename LeafOutputType>
500 inline Tree<ThresholdType, LeafOutputType>
501 Tree<ThresholdType, LeafOutputType>::Clone()
const {
502 Tree<ThresholdType, LeafOutputType> tree;
503 tree.num_nodes = num_nodes;
504 tree.nodes_ = nodes_.Clone();
505 tree.leaf_vector_ = leaf_vector_.Clone();
506 tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
507 tree.leaf_vector_end_ = leaf_vector_end_.Clone();
508 tree.matching_categories_ = matching_categories_.Clone();
509 tree.matching_categories_offset_ = matching_categories_offset_.Clone();
513 template <
typename ThresholdType,
typename LeafOutputType>
515 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
516 if (std::is_same<ThresholdType, float>::value) {
517 return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
519 return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
523 constexpr std::size_t kNumFramePerTree = 7;
525 template <
typename ThresholdType,
typename LeafOutputType>
526 template <
typename ScalarHandler,
typename PrimitiveArrayHandler,
typename CompositeArrayHandler>
528 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
529 ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
530 CompositeArrayHandler composite_array_handler) {
531 scalar_handler(&num_nodes);
532 composite_array_handler(&nodes_, GetFormatStringForNode());
533 primitive_array_handler(&leaf_vector_);
534 primitive_array_handler(&leaf_vector_begin_);
535 primitive_array_handler(&leaf_vector_end_);
536 primitive_array_handler(&matching_categories_);
537 primitive_array_handler(&matching_categories_offset_);
540 template <
typename ThresholdType,
typename LeafOutputType>
541 template <
typename ScalarHandler,
typename ArrayHandler>
543 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
544 ScalarHandler scalar_handler, ArrayHandler array_handler) {
545 scalar_handler(&num_nodes);
546 array_handler(&nodes_);
547 if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
548 throw std::runtime_error(
"Could not load the correct number of nodes");
550 array_handler(&leaf_vector_);
551 array_handler(&leaf_vector_begin_);
552 array_handler(&leaf_vector_end_);
553 array_handler(&matching_categories_);
554 array_handler(&matching_categories_offset_);
557 template <
typename ThresholdType,
typename LeafOutputType>
559 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
560 auto scalar_handler = [dest](
auto* field) {
561 dest->push_back(GetPyBufferFromScalar(field));
563 auto primitive_array_handler = [dest](
auto* field) {
564 dest->push_back(GetPyBufferFromArray(field));
566 auto composite_array_handler = [dest](
auto* field,
const char* format) {
567 dest->push_back(GetPyBufferFromArray(field, format));
569 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
572 template <
typename ThresholdType,
typename LeafOutputType>
574 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
575 auto scalar_handler = [dest_fp](
auto* field) {
576 WriteScalarToFile(field, dest_fp);
578 auto primitive_array_handler = [dest_fp](
auto* field) {
579 WriteArrayToFile(field, dest_fp);
581 auto composite_array_handler = [dest_fp](
auto* field,
const char* format) {
582 WriteArrayToFile(field, dest_fp);
584 SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
587 template <
typename ThresholdType,
typename LeafOutputType>
589 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
590 std::vector<PyBufferFrame>::iterator end) {
591 if (std::distance(begin, end) != kNumFramePerTree) {
592 throw std::runtime_error(
"Wrong number of frames specified");
594 auto scalar_handler = [&begin](
auto* field) {
595 InitScalarFromPyBuffer(field, *begin++);
597 auto array_handler = [&begin](
auto* field) {
598 InitArrayFromPyBuffer(field, *begin++);
600 DeserializeTemplate(scalar_handler, array_handler);
603 template <
typename ThresholdType,
typename LeafOutputType>
605 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
606 auto scalar_handler = [src_fp](
auto* field) {
607 ReadScalarFromFile(field, src_fp);
609 auto array_handler = [src_fp](
auto* field) {
610 ReadArrayFromFile(field, src_fp);
612 DeserializeTemplate(scalar_handler, array_handler);
615 template <
typename ThresholdType,
typename LeafOutputType>
617 std::memset(
this, 0,
sizeof(
Node));
618 cleft_ = cright_ = -1;
620 info_.leaf_value =
static_cast<LeafOutputType
>(0);
621 info_.threshold =
static_cast<ThresholdType
>(0);
623 sum_hess_ = gain_ = 0.0;
624 data_count_present_ = sum_hess_present_ = gain_present_ =
false;
625 categories_list_right_child_ =
false;
626 split_type_ = SplitFeatureType::kNone;
627 cmp_ = Operator::kNone;
630 template <
typename ThresholdType,
typename LeafOutputType>
633 int nd = num_nodes++;
634 if (nodes_.Size() !=
static_cast<std::size_t
>(nd)) {
635 throw std::runtime_error(
"Invariant violated: nodes_ contains incorrect number of nodes");
637 for (
int nid = nd; nid < num_nodes; ++nid) {
638 leaf_vector_begin_.PushBack(0);
639 leaf_vector_end_.PushBack(0);
640 matching_categories_offset_.PushBack(matching_categories_offset_.Back());
641 nodes_.Resize(nodes_.Size() + 1);
642 nodes_.Back().Init();
647 template <
typename ThresholdType,
typename LeafOutputType>
651 leaf_vector_.Clear();
652 leaf_vector_begin_.Resize(1, {});
653 leaf_vector_end_.Resize(1, {});
654 matching_categories_.Clear();
655 matching_categories_offset_.Resize(2, 0);
658 SetLeaf(0, static_cast<LeafOutputType>(0));
661 template <
typename ThresholdType,
typename LeafOutputType>
664 const int cleft = this->AllocNode();
665 const int cright = this->AllocNode();
666 nodes_.at(nid).cleft_ = cleft;
667 nodes_.at(nid).cright_ = cright;
670 template <
typename ThresholdType,
typename LeafOutputType>
671 inline std::vector<unsigned>
673 std::unordered_map<unsigned, bool> tmp;
674 for (
int nid = 0; nid < num_nodes; ++nid) {
676 if (type != SplitFeatureType::kNone) {
677 const bool flag = (type == SplitFeatureType::kCategorical);
678 const uint32_t split_index = SplitIndex(nid);
679 if (tmp.count(split_index) == 0) {
680 tmp[split_index] = flag;
682 if (tmp[split_index] != flag) {
683 throw std::runtime_error(
"Feature " + std::to_string(split_index) +
684 " cannot be simultaneously be categorical and numerical.");
689 std::vector<unsigned> result;
690 for (
const auto& kv : tmp) {
692 result.push_back(kv.first);
695 std::sort(result.begin(), result.end());
699 template <
typename ThresholdType,
typename LeafOutputType>
702 int nid,
unsigned split_index, ThresholdType threshold,
bool default_left,
Operator cmp) {
703 Node& node = nodes_.at(nid);
704 if (split_index >= ((1U << 31U) - 1)) {
705 throw std::runtime_error(
"split_index too big");
707 if (default_left) split_index |= (1U << 31U);
708 node.sindex_ = split_index;
709 (node.info_).threshold = threshold;
711 node.split_type_ = SplitFeatureType::kNumerical;
712 node.categories_list_right_child_ =
false;
715 template <
typename ThresholdType,
typename LeafOutputType>
718 int nid,
unsigned split_index,
bool default_left,
719 const std::vector<uint32_t>& categories_list,
bool categories_list_right_child) {
720 if (split_index >= ((1U << 31U) - 1)) {
721 throw std::runtime_error(
"split_index too big");
724 const std::size_t end_oft = matching_categories_offset_.Back();
725 const std::size_t new_end_oft = end_oft + categories_list.size();
726 if (end_oft != matching_categories_.Size()) {
727 throw std::runtime_error(
"Invariant violated");
729 if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
730 [end_oft](std::size_t x) {
return (x == end_oft); })) {
731 throw std::runtime_error(
"Invariant violated");
734 matching_categories_.Extend(categories_list);
735 if (new_end_oft != matching_categories_.Size()) {
736 throw std::runtime_error(
"Invariant violated");
738 std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
739 [new_end_oft](std::size_t& x) { x = new_end_oft; });
740 if (!matching_categories_.Empty()) {
741 std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
744 Node& node = nodes_.at(nid);
745 if (default_left) split_index |= (1U << 31U);
746 node.sindex_ = split_index;
747 node.split_type_ = SplitFeatureType::kCategorical;
748 node.categories_list_right_child_ = categories_list_right_child;
751 template <
typename ThresholdType,
typename LeafOutputType>
754 Node& node = nodes_.at(nid);
755 (node.info_).leaf_value = value;
758 node.split_type_ = SplitFeatureType::kNone;
761 template <
typename ThresholdType,
typename LeafOutputType>
764 int nid,
const std::vector<LeafOutputType>& node_leaf_vector) {
765 std::size_t begin = leaf_vector_.Size();
766 std::size_t end = begin + node_leaf_vector.size();
767 leaf_vector_.Extend(node_leaf_vector);
768 leaf_vector_begin_[nid] = begin;
769 leaf_vector_end_[nid] = end;
770 Node &node = nodes_.at(nid);
773 node.split_type_ = SplitFeatureType::kNone;
776 template <
typename ThresholdType,
typename LeafOutputType>
777 inline std::unique_ptr<Model>
779 std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
780 model->threshold_type_ = TypeToInfo<ThresholdType>();
781 model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
785 template <
typename ThresholdType,
typename LeafOutputType>
788 inline static std::unique_ptr<Model> Dispatch() {
789 return Model::Create<ThresholdType, LeafOutputType>();
793 inline std::unique_ptr<Model>
795 return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
798 template <
typename ThresholdType,
typename LeafOutputType>
801 template <
typename Func>
802 inline static auto Dispatch(
Model* model, Func func) {
806 template <
typename Func>
807 inline static auto Dispatch(
const Model* model, Func func) {
812 template <
typename Func>
814 Model::Dispatch(Func func) {
815 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
818 template <
typename Func>
820 Model::Dispatch(Func func)
const {
821 return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_,
this, func);
824 template <
typename HeaderPrimitiveFieldHandlerFunc>
826 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
827 header_primitive_field_handler(&major_ver_);
828 header_primitive_field_handler(&minor_ver_);
829 header_primitive_field_handler(&patch_ver_);
830 header_primitive_field_handler(&threshold_type_);
831 header_primitive_field_handler(&leaf_output_type_);
834 template <
typename HeaderPrimitiveFieldHandlerFunc>
836 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
838 int major_ver, minor_ver, patch_ver;
839 header_primitive_field_handler(&major_ver);
840 header_primitive_field_handler(&minor_ver);
841 header_primitive_field_handler(&patch_ver);
842 if (major_ver != TREELITE_VER_MAJOR || minor_ver != TREELITE_VER_MINOR) {
843 std::ostringstream oss;
844 oss <<
"Cannot deserialize model from a different version of Treelite." << std::endl
845 <<
"Currently running Treelite version " << TREELITE_VER_MAJOR <<
"." 846 << TREELITE_VER_MINOR <<
"." << TREELITE_VER_PATCH << std::endl
847 <<
"The model checkpoint was generated from Treelite version " << major_ver <<
"." 848 << minor_ver <<
"." << patch_ver;
849 throw std::runtime_error(oss.str());
851 header_primitive_field_handler(&threshold_type);
852 header_primitive_field_handler(&leaf_output_type);
855 template <
typename ThresholdType,
typename LeafOutputType>
856 template <
typename HeaderPrimitiveFieldHandlerFunc,
typename HeaderCompositeFieldHandlerFunc,
857 typename TreeHandlerFunc>
860 HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
861 HeaderCompositeFieldHandlerFunc header_composite_field_handler,
862 TreeHandlerFunc tree_handler) {
864 header_primitive_field_handler(&num_feature);
865 header_primitive_field_handler(&task_type);
866 header_primitive_field_handler(&average_tree_output);
867 header_composite_field_handler(&task_param,
"T{=B=?xx=I=I}");
868 header_composite_field_handler(
869 ¶m,
"T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH)
"s=f=f=f}");
877 template <
typename ThresholdType,
typename LeafOutputType>
878 template <
typename HeaderFieldHandlerFunc,
typename TreeHandlerFunc>
881 std::size_t num_tree,
882 HeaderFieldHandlerFunc header_field_handler,
883 TreeHandlerFunc tree_handler) {
885 header_field_handler(&num_feature);
886 header_field_handler(&task_type);
887 header_field_handler(&average_tree_output);
888 header_field_handler(&task_param);
889 header_field_handler(¶m);
892 for (std::size_t i = 0; i < num_tree; ++i) {
893 trees.emplace_back();
894 tree_handler(trees.back());
898 template <
typename ThresholdType,
typename LeafOutputType>
901 auto header_primitive_field_handler = [dest](
auto* field) {
902 dest->push_back(GetPyBufferFromScalar(field));
904 auto header_composite_field_handler = [dest](
auto* field,
const char* format) {
905 dest->push_back(GetPyBufferFromScalar(field, format));
908 tree.GetPyBuffer(dest);
910 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
913 template <
typename ThresholdType,
typename LeafOutputType>
916 const auto num_tree =
static_cast<uint64_t
>(this->trees.size());
917 WriteScalarToFile(&num_tree, dest_fp);
918 auto header_primitive_field_handler = [dest_fp](
auto* field) {
919 WriteScalarToFile(field, dest_fp);
921 auto header_composite_field_handler = [dest_fp](
auto* field,
const char* format) {
922 WriteScalarToFile(field, dest_fp);
925 tree.SerializeToFile(dest_fp);
927 SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
930 template <
typename ThresholdType,
typename LeafOutputType>
933 std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
934 const std::size_t num_frame = std::distance(begin, end);
935 constexpr std::size_t kNumFrameInHeader = 5;
936 if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
937 throw std::runtime_error(
"Wrong number of frames");
939 const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
941 auto header_field_handler = [&begin](
auto* field) {
942 InitScalarFromPyBuffer(field, *begin++);
947 tree.InitFromPyBuffer(begin, begin + kNumFramePerTree);
948 begin += kNumFramePerTree;
952 DeserializeTemplate(num_tree, header_field_handler, tree_handler);
955 template <
typename ThresholdType,
typename LeafOutputType>
959 ReadScalarFromFile(&num_tree, src_fp);
961 auto header_field_handler = [src_fp](
auto* field) {
962 ReadScalarFromFile(field, src_fp);
966 tree.DeserializeFromFile(src_fp);
969 DeserializeTemplate(num_tree, header_field_handler, tree_handler);
972 inline void InitParamAndCheck(
ModelParam* param,
973 const std::vector<std::pair<std::string, std::string>>& cfg) {
974 auto unknown = param->InitAllowUnknown(cfg);
975 if (!unknown.empty()) {
976 std::ostringstream oss;
977 for (
const auto& kv : unknown) {
978 oss << kv.first <<
", ";
980 std::cerr <<
"\033[1;31mWarning: Unknown parameters found; " 981 <<
"they have been ignored\u001B[0m: " << oss.str() << std::endl;
986 #endif // TREELITE_TREE_IMPL_H_ SplitFeatureType
feature split type
void Init()
initialize the model with a single root node
TaskType
Enum type representing the task type.
in-memory representation of a decision tree
TypeInfo
Types used by thresholds and leaf outputs.
thin wrapper for tree ensemble model
Operator
comparison operators