Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13 #include <map>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 #include <unordered_map>
18 #include <sstream>
19 #include <iomanip>
20 #include <typeinfo>
21 #include <stdexcept>
22 #include <iostream>
23 
namespace {

/*!
 * \brief Convert a value to its textual form via std::to_string.
 * \param x value to convert
 * \return string representation of x
 */
template <typename T>
inline std::string GetString(T x) {
  return std::to_string(x);
}

/*!
 * \brief Convert a float to text using max_digits10 significant digits, so that
 *        the printed value can be parsed back to the same float.
 */
template <>
inline std::string GetString<float>(float x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<float>::max_digits10);
  stream << x;
  return stream.str();
}

/*!
 * \brief Convert a double to text using max_digits10 significant digits, so that
 *        the printed value can be parsed back to the same double.
 */
template <>
inline std::string GetString<double>(double x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<double>::max_digits10);
  stream << x;
  return stream.str();
}

}  // anonymous namespace
46 
47 namespace treelite {
48 
49 template <typename T>
50 ContiguousArray<T>::ContiguousArray()
51  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
52 
53 template <typename T>
54 ContiguousArray<T>::~ContiguousArray() {
55  if (buffer_ && owned_buffer_) {
56  std::free(buffer_);
57  }
58 }
59 
60 template <typename T>
61 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
62  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
63  owned_buffer_(other.owned_buffer_) {
64  other.buffer_ = nullptr;
65  other.size_ = other.capacity_ = 0;
66 }
67 
68 template <typename T>
69 ContiguousArray<T>&
70 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
71  if (buffer_ && owned_buffer_) {
72  std::free(buffer_);
73  }
74  buffer_ = other.buffer_;
75  size_ = other.size_;
76  capacity_ = other.capacity_;
77  owned_buffer_ = other.owned_buffer_;
78  other.buffer_ = nullptr;
79  other.size_ = other.capacity_ = 0;
80  return *this;
81 }
82 
83 template <typename T>
84 inline ContiguousArray<T>
85 ContiguousArray<T>::Clone() const {
86  ContiguousArray clone;
87  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
88  if (!clone.buffer_) {
89  throw std::runtime_error("Could not allocate memory for the clone");
90  }
91  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
92  clone.size_ = size_;
93  clone.capacity_ = capacity_;
94  clone.owned_buffer_ = true;
95  return clone;
96 }
97 
98 template <typename T>
99 inline void
100 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, size_t size) {
101  if (buffer_ && owned_buffer_) {
102  std::free(buffer_);
103  }
104  buffer_ = static_cast<T*>(prealloc_buf);
105  size_ = size;
106  capacity_ = size;
107  owned_buffer_ = false;
108 }
109 
110 template <typename T>
111 inline T*
112 ContiguousArray<T>::Data() {
113  return buffer_;
114 }
115 
116 template <typename T>
117 inline const T*
118 ContiguousArray<T>::Data() const {
119  return buffer_;
120 }
121 
122 template <typename T>
123 inline T*
124 ContiguousArray<T>::End() {
125  return &buffer_[Size()];
126 }
127 
128 template <typename T>
129 inline const T*
130 ContiguousArray<T>::End() const {
131  return &buffer_[Size()];
132 }
133 
134 template <typename T>
135 inline T&
136 ContiguousArray<T>::Back() {
137  return buffer_[Size() - 1];
138 }
139 
140 template <typename T>
141 inline const T&
142 ContiguousArray<T>::Back() const {
143  return buffer_[Size() - 1];
144 }
145 
146 template <typename T>
147 inline size_t
148 ContiguousArray<T>::Size() const {
149  return size_;
150 }
151 
152 template <typename T>
153 inline void
154 ContiguousArray<T>::Reserve(size_t newsize) {
155  if (!owned_buffer_) {
156  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
157  }
158  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
159  if (!newbuf) {
160  throw std::runtime_error("Could not expand buffer");
161  }
162  buffer_ = newbuf;
163  capacity_ = newsize;
164 }
165 
166 template <typename T>
167 inline void
168 ContiguousArray<T>::Resize(size_t newsize) {
169  if (!owned_buffer_) {
170  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
171  }
172  if (newsize > capacity_) {
173  size_t newcapacity = capacity_;
174  if (newcapacity == 0) {
175  newcapacity = 1;
176  }
177  while (newcapacity <= newsize) {
178  newcapacity *= 2;
179  }
180  Reserve(newcapacity);
181  }
182  size_ = newsize;
183 }
184 
185 template <typename T>
186 inline void
187 ContiguousArray<T>::Resize(size_t newsize, T t) {
188  if (!owned_buffer_) {
189  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
190  }
191  size_t oldsize = Size();
192  Resize(newsize);
193  for (size_t i = oldsize; i < newsize; ++i) {
194  buffer_[i] = t;
195  }
196 }
197 
198 template <typename T>
199 inline void
200 ContiguousArray<T>::Clear() {
201  if (!owned_buffer_) {
202  throw std::runtime_error("Cannot clear when using a foreign buffer; clone first");
203  }
204  Resize(0);
205 }
206 
207 template <typename T>
208 inline void
209 ContiguousArray<T>::PushBack(T t) {
210  if (!owned_buffer_) {
211  throw std::runtime_error("Cannot add element when using a foreign buffer; clone first");
212  }
213  if (size_ == capacity_) {
214  Reserve(capacity_ * 2);
215  }
216  buffer_[size_++] = t;
217 }
218 
219 template <typename T>
220 inline void
221 ContiguousArray<T>::Extend(const std::vector<T>& other) {
222  if (!owned_buffer_) {
223  throw std::runtime_error("Cannot add elements when using a foreign buffer; clone first");
224  }
225  size_t newsize = size_ + other.size();
226  if (newsize > capacity_) {
227  size_t newcapacity = capacity_;
228  if (newcapacity == 0) {
229  newcapacity = 1;
230  }
231  while (newcapacity <= newsize) {
232  newcapacity *= 2;
233  }
234  Reserve(newcapacity);
235  }
236  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
237  size_ = newsize;
238 }
239 
240 template <typename T>
241 inline T&
242 ContiguousArray<T>::operator[](size_t idx) {
243  return buffer_[idx];
244 }
245 
246 template <typename T>
247 inline const T&
248 ContiguousArray<T>::operator[](size_t idx) const {
249  return buffer_[idx];
250 }
251 
252 template <typename T>
253 inline T&
254 ContiguousArray<T>::at(size_t idx) {
255  if (idx >= Size()) {
256  throw std::runtime_error("nid out of range");
257  }
258  return buffer_[idx];
259 }
260 
261 template <typename T>
262 inline const T&
263 ContiguousArray<T>::at(size_t idx) const {
264  if (idx >= Size()) {
265  throw std::runtime_error("nid out of range");
266  }
267  return buffer_[idx];
268 }
269 
270 template <typename T>
271 inline T&
272 ContiguousArray<T>::at(int idx) {
273  if (idx < 0 || static_cast<size_t>(idx) >= Size()) {
274  throw std::runtime_error("nid out of range");
275  }
276  return buffer_[static_cast<size_t>(idx)];
277 }
278 
279 template <typename T>
280 inline const T&
281 ContiguousArray<T>::at(int idx) const {
282  if (idx < 0 || static_cast<size_t>(idx) >= Size()) {
283  throw std::runtime_error("nid out of range");
284  }
285  return buffer_[static_cast<size_t>(idx)];
286 }
287 
288 template<typename Container>
289 inline std::vector<std::pair<std::string, std::string> >
290 ModelParam::InitAllowUnknown(const Container& kwargs) {
291  std::vector<std::pair<std::string, std::string>> unknowns;
292  for (const auto& e : kwargs) {
293  if (e.first == "pred_transform") {
294  std::strncpy(this->pred_transform, e.second.c_str(),
295  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
296  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
297  } else if (e.first == "sigmoid_alpha") {
298  this->sigmoid_alpha = dmlc::stof(e.second, nullptr);
299  } else if (e.first == "global_bias") {
300  this->global_bias = dmlc::stof(e.second, nullptr);
301  }
302  }
303  return unknowns;
304 }
305 
306 inline std::map<std::string, std::string>
307 ModelParam::__DICT__() const {
308  std::map<std::string, std::string> ret;
309  ret.emplace("pred_transform", std::string(this->pred_transform));
310  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
311  ret.emplace("global_bias", GetString(this->global_bias));
312  return ret;
313 }
314 
315 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
316  size_t itemsize, size_t nitem) {
317  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
318 }
319 
/*!
 * \brief Infer the buffer format string for a type from its size and category.
 *        1- and 2-byte types map to "=B"/"=b" and "=H"/"=h" depending on
 *        signedness; 4- and 8-byte types map to "=L"/"=l" and "=Q"/"=q" when
 *        integral, or to "=f" and "=d" when floating-point.
 * \return format string for T
 * \throws std::runtime_error if T is a 4- or 8-byte type that is neither
 *         integral nor floating-point ("Could not infer format string"), or if
 *         its size is not 1, 2, 4 or 8 bytes ("Unrecognized type")
 */
template <typename T>
inline const char* InferFormatString() {
  const bool is_unsigned = std::is_unsigned<T>::value;
  const bool is_integer = std::is_integral<T>::value;
  const bool is_real = std::is_floating_point<T>::value;
  const auto width = sizeof(T);

  if (width == 1) {
    return is_unsigned ? "=B" : "=b";
  }
  if (width == 2) {
    return is_unsigned ? "=H" : "=h";
  }
  if (width != 4 && width != 8) {
    throw std::runtime_error("Unrecognized type");
  }
  // From here on the type is 4 or 8 bytes wide
  if (is_integer) {
    if (width == 4) {
      return is_unsigned ? "=L" : "=l";
    }
    return is_unsigned ? "=Q" : "=q";
  }
  if (!is_real) {
    throw std::runtime_error("Could not infer format string");
  }
  return (width == 4) ? "=f" : "=d";
}
351 
352 template <typename T>
353 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
354  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
355 }
356 
357 template <typename T>
358 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
359  static_assert(std::is_arithmetic<T>::value,
360  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
361  return GetPyBufferFromArray(vec, InferFormatString<T>());
362 }
363 
364 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, size_t itemsize) {
365  return GetPyBufferFromArray(data, format, itemsize, 1);
366 }
367 
368 template <typename T>
369 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
370  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
371 }
372 
373 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
374  using T = std::underlying_type<TypeInfo>::type;
375  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
376 }
377 
378 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
379  using T = std::underlying_type<TaskType>::type;
380  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
381 }
382 
383 template <typename T>
384 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
385  static_assert(std::is_arithmetic<T>::value,
386  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
387  "specify format string manually");
388  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
389 }
390 
391 template <typename T>
392 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame buffer) {
393  if (sizeof(T) != buffer.itemsize) {
394  throw std::runtime_error("Incorrect itemsize");
395  }
396  vec->UseForeignBuffer(buffer.buf, buffer.nitem);
397 }
398 
399 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
400  using T = std::underlying_type<TypeInfo>::type;
401  if (sizeof(T) != buffer.itemsize) {
402  throw std::runtime_error("Incorrect itemsize");
403  }
404  if (buffer.nitem != 1) {
405  throw std::runtime_error("nitem must be 1 for a scalar");
406  }
407  T* t = static_cast<T*>(buffer.buf);
408  *scalar = static_cast<TypeInfo>(*t);
409 }
410 
411 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
412  using T = std::underlying_type<TaskType>::type;
413  if (sizeof(T) != buffer.itemsize) {
414  throw std::runtime_error("Incorrect itemsize");
415  }
416  if (buffer.nitem != 1) {
417  throw std::runtime_error("nitem must be 1 for a scalar");
418  }
419  T* t = static_cast<T*>(buffer.buf);
420  *scalar = static_cast<TaskType>(*t);
421 }
422 
423 template <typename T>
424 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
425  if (sizeof(T) != buffer.itemsize) {
426  throw std::runtime_error("Incorrect itemsize");
427  }
428  if (buffer.nitem != 1) {
429  throw std::runtime_error("nitem must be 1 for a scalar");
430  }
431  T* t = static_cast<T*>(buffer.buf);
432  *scalar = *t;
433 }
434 
435 template <typename ThresholdType, typename LeafOutputType>
436 inline Tree<ThresholdType, LeafOutputType>
437 Tree<ThresholdType, LeafOutputType>::Clone() const {
438  Tree<ThresholdType, LeafOutputType> tree;
439  tree.num_nodes = num_nodes;
440  tree.nodes_ = nodes_.Clone();
441  tree.leaf_vector_ = leaf_vector_.Clone();
442  tree.leaf_vector_offset_ = leaf_vector_offset_.Clone();
443  tree.matching_categories_ = matching_categories_.Clone();
444  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
445  return tree;
446 }
447 
448 template <typename ThresholdType, typename LeafOutputType>
449 inline const char*
450 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
451  if (std::is_same<ThresholdType, float>::value) {
452  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
453  } else {
454  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
455  }
456 }
457 
// Number of buffer frames that make up one serialized tree: the frame count
// pushed by Tree::GetPyBuffer() and required by Tree::InitFromPyBuffer().
458 constexpr size_t kNumFramePerTree = 6;
459 
460 template <typename ThresholdType, typename LeafOutputType>
461 inline void
462 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
463  dest->push_back(GetPyBufferFromScalar(&num_nodes));
464  dest->push_back(GetPyBufferFromArray(&nodes_, GetFormatStringForNode()));
465  dest->push_back(GetPyBufferFromArray(&leaf_vector_));
466  dest->push_back(GetPyBufferFromArray(&leaf_vector_offset_));
467  dest->push_back(GetPyBufferFromArray(&matching_categories_));
468  dest->push_back(GetPyBufferFromArray(&matching_categories_offset_));
469 }
470 
471 template <typename ThresholdType, typename LeafOutputType>
472 inline void
473 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(
474  std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
475  if (std::distance(begin, end) != kNumFramePerTree) {
476  throw std::runtime_error("Wrong number of frames specified");
477  }
478  InitScalarFromPyBuffer(&num_nodes, *begin++);
479  InitArrayFromPyBuffer(&nodes_, *begin++);
480  if (static_cast<size_t>(num_nodes) != nodes_.Size()) {
481  throw std::runtime_error("Could not load the correct number of nodes");
482  }
483  InitArrayFromPyBuffer(&leaf_vector_, *begin++);
484  InitArrayFromPyBuffer(&leaf_vector_offset_, *begin++);
485  InitArrayFromPyBuffer(&matching_categories_, *begin++);
486  InitArrayFromPyBuffer(&matching_categories_offset_, *begin++);
487 }
488 
// Reset every field of a node to its default state: no children, no split,
// zeroed values and all "present" flags cleared.
// NOTE(review): the declarator line (listing line 490) is absent from this
// rendered listing; judging from the call sites `nodes_.Back().Init()` and
// `nodes_.at(0).Init()`, this is presumably the node's Init() member — confirm
// against the actual header.
489 template <typename ThresholdType, typename LeafOutputType>
// -1 is also the child value written by SetLeaf(), i.e. "no child"
491  cleft_ = cright_ = -1;
492  sindex_ = 0;
// leaf_value and threshold are both members of info_ and are zeroed one after the other
493  info_.leaf_value = static_cast<LeafOutputType>(0);
494  info_.threshold = static_cast<ThresholdType>(0);
495  data_count_ = 0;
496  sum_hess_ = gain_ = 0.0;
// None of the optional statistics is available for a fresh node
497  data_count_present_ = sum_hess_present_ = gain_present_ = false;
498  categories_list_right_child_ = false;
499  split_type_ = SplitFeatureType::kNone;
500  cmp_ = Operator::kNone;
501 }
502 
// Append one freshly initialized node to the tree and return its index.
// Throws std::runtime_error if nodes_ does not hold exactly the previous number
// of nodes; note that num_nodes has already been incremented at that point.
// NOTE(review): the declarator line (listing line 505) is absent from this
// rendered listing; the call `this->AllocNode()` elsewhere in the file suggests
// this is AllocNode() — confirm against the actual header.
503 template <typename ThresholdType, typename LeafOutputType>
504 inline int
// nd is the index of the node being created (the node count before the increment)
506  int nd = num_nodes++;
507  if (nodes_.Size() != static_cast<size_t>(nd)) {
508  throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
509  }
// num_nodes == nd + 1 here, so this loop body executes exactly once
510  for (int nid = nd; nid < num_nodes; ++nid) {
// The new node gets an empty range in both offset arrays: its end offset
// repeats the previous last offset. Back() requires the arrays to be non-empty.
511  leaf_vector_offset_.PushBack(leaf_vector_offset_.Back());
512  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
513  nodes_.Resize(nodes_.Size() + 1);
514  nodes_.Back().Init();
515  }
516  return nd;
517 }
518 
// Initialize the tree with a single root node, which is set up as a leaf with
// output value 0.
// NOTE(review): the declarator line (listing line 521) is absent from this
// rendered listing; the cross-reference at the end of the page lists
// "void Init()" as defined at tree_impl.h:521.
519 template <typename ThresholdType, typename LeafOutputType>
520 inline void
522  num_nodes = 1;
523  leaf_vector_.Clear();
// Two zero offsets: the root's (empty) range is [offset[0], offset[1]) = [0, 0)
524  leaf_vector_offset_.Resize(2, 0);
525  matching_categories_.Clear();
526  matching_categories_offset_.Resize(2, 0);
527  nodes_.Resize(1);
528  nodes_.at(0).Init();
529  SetLeaf(0, static_cast<LeafOutputType>(0));
530 }
531 
// Allocate two new nodes and record them as the left and right child of node nid.
// NOTE(review): the declarator line (listing line 534) is absent from this
// rendered listing; the body uses a node index named `nid`. The function name
// cannot be determined from this file — confirm against the actual header.
532 template <typename ThresholdType, typename LeafOutputType>
533 inline void
535  const int cleft = this->AllocNode();
536  const int cright = this->AllocNode();
// The parent is looked up only after both allocations: AllocNode() resizes
// nodes_, which may reallocate its storage.
537  nodes_.at(nid).cleft_ = cleft;
538  nodes_.at(nid).cright_ = cright;
539 }
540 
// Collect the indices of all features that are used in categorical splits and
// return them in ascending order. Throws std::runtime_error if the same feature
// is used by both a categorical and a non-categorical split.
// NOTE(review): the declarator line (listing line 543) is absent from this
// rendered listing, so the function name and qualifiers cannot be confirmed here.
541 template <typename ThresholdType, typename LeafOutputType>
542 inline std::vector<unsigned>
// Maps feature index -> whether that feature was seen in a categorical split
544  std::unordered_map<unsigned, bool> tmp;
545  for (int nid = 0; nid < num_nodes; ++nid) {
546  const SplitFeatureType type = SplitType(nid);
// Nodes whose split type is kNone (SetLeaf() assigns it) carry no split and are skipped
547  if (type != SplitFeatureType::kNone) {
548  const bool flag = (type == SplitFeatureType::kCategorical);
549  const uint32_t split_index = SplitIndex(nid);
550  if (tmp.count(split_index) == 0) {
551  tmp[split_index] = flag;
552  } else {
// A feature seen before must have been used with the same kind of split
553  if (tmp[split_index] != flag) {
554  throw std::runtime_error("Feature " + std::to_string(split_index) +
555  " cannot be simultaneously be categorical and numerical.");
556  }
557  }
558  }
559  }
560  std::vector<unsigned> result;
561  for (const auto& kv : tmp) {
562  if (kv.second) {
563  result.push_back(kv.first);
564  }
565  }
// unordered_map iteration order is unspecified, hence the explicit sort
566  std::sort(result.begin(), result.end());
567  return result;
568 }
569 
// Turn node nid into a numerical test node: store the feature index, threshold,
// comparison operator and default direction.
// Throws std::runtime_error if split_index is 2^31 - 1 or larger, or (via
// nodes_.at) if nid is out of range.
// NOTE(review): the first declarator line (listing line 572) is absent from this
// rendered listing, so the function name cannot be confirmed here.
570 template <typename ThresholdType, typename LeafOutputType>
571 inline void
573  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
574  Node& node = nodes_.at(nid);
// The top bit of sindex_ is reserved for the default direction (see below),
// so the feature index itself must stay below it
575  if (split_index >= ((1U << 31U) - 1)) {
576  throw std::runtime_error("split_index too big");
577  }
// Encode "default direction is left" in the most significant bit
578  if (default_left) split_index |= (1U << 31U);
579  node.sindex_ = split_index;
580  (node.info_).threshold = threshold;
581  node.cmp_ = cmp;
582  node.split_type_ = SplitFeatureType::kNumerical;
583  node.categories_list_right_child_ = false;
584 }
585 
// Turn node nid into a categorical test node: append its category list (stored
// sorted) to matching_categories_, update the offsets of node nid and all later
// nodes, and store the feature index and default direction.
// Throws std::runtime_error if split_index is 2^31 - 1 or larger, if the
// offset/category arrays are inconsistent ("Invariant violated"), or (via at())
// if an index is out of range.
// NOTE(review): the first declarator line (listing line 588) is absent from this
// rendered listing, so the function name cannot be confirmed here.
586 template <typename ThresholdType, typename LeafOutputType>
587 inline void
589  int nid, unsigned split_index, bool default_left,
590  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
591  if (split_index >= ((1U << 31U) - 1)) {
592  throw std::runtime_error("split_index too big");
593  }
594 
595  const size_t end_oft = matching_categories_offset_.Back();
596  const size_t new_end_oft = end_oft + categories_list.size();
// The last offset must coincide with the end of the category storage
597  if (end_oft != matching_categories_.Size()) {
598  throw std::runtime_error("Invariant violated");
599  }
// Node nid and every later node must not own any categories yet: all their
// end offsets have to equal the current end of the storage
600  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
601  [end_oft](size_t x) { return (x == end_oft); })) {
602  throw std::runtime_error("Invariant violated");
603  }
604  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
605  matching_categories_.Extend(categories_list);
606  if (new_end_oft != matching_categories_.Size()) {
607  throw std::runtime_error("Invariant violated");
608  }
// Shift the end offset of node nid and of all later nodes to the new end
609  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
610  [new_end_oft](size_t& x) { x = new_end_oft; });
// Sort only the categories that were just appended.
// NOTE(review): at(end_oft) throws when categories_list is empty, because
// end_oft then equals the array size — confirm whether empty lists are expected.
611  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
612 
613  Node& node = nodes_.at(nid);
// Encode "default direction is left" in the most significant bit
614  if (default_left) split_index |= (1U << 31U);
615  node.sindex_ = split_index;
616  node.split_type_ = SplitFeatureType::kCategorical;
617  node.categories_list_right_child_ = categories_list_right_child;
618 }
619 
620 template <typename ThresholdType, typename LeafOutputType>
621 inline void
622 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
623  Node& node = nodes_.at(nid);
624  (node.info_).leaf_value = value;
625  node.cleft_ = -1;
626  node.cright_ = -1;
627  node.split_type_ = SplitFeatureType::kNone;
628 }
629 
// Turn node nid into a leaf whose output is a vector: append the vector to
// leaf_vector_, update the offsets of node nid and all later nodes, and clear
// the node's child links and split type.
// Throws std::runtime_error if the offset/leaf-vector arrays are inconsistent
// ("Invariant violated"), or (via at()) if an index is out of range.
// NOTE(review): the first declarator line (listing line 632) is absent from this
// rendered listing, so the function name cannot be confirmed here.
630 template <typename ThresholdType, typename LeafOutputType>
631 inline void
633  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
634  const size_t end_oft = leaf_vector_offset_.Back();
635  const size_t new_end_oft = end_oft + node_leaf_vector.size();
// The last offset must coincide with the end of the leaf vector storage
636  if (end_oft != leaf_vector_.Size()) {
637  throw std::runtime_error("Invariant violated");
638  }
// Node nid and every later node must not own any leaf vector elements yet
639  if (!std::all_of(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
640  [end_oft](size_t x) { return (x == end_oft); })) {
641  throw std::runtime_error("Invariant violated");
642  }
643  // Hopefully we won't have to move any element as we add leaf vector elements for node nid
644  leaf_vector_.Extend(node_leaf_vector);
645  if (new_end_oft != leaf_vector_.Size()) {
646  throw std::runtime_error("Invariant violated");
647  }
// Shift the end offset of node nid and of all later nodes to the new end
648  std::for_each(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
649  [new_end_oft](size_t& x) { x = new_end_oft; });
650 
651  Node& node = nodes_.at(nid);
652  node.cleft_ = -1;
653  node.cright_ = -1;
654  node.split_type_ = SplitFeatureType::kNone;
655 }
656 
657 template <typename ThresholdType, typename LeafOutputType>
658 inline std::unique_ptr<Model>
659 Model::Create() {
660  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
661  model->threshold_type_ = TypeToInfo<ThresholdType>();
662  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
663  return model;
664 }
665 
// Dispatch target that creates a model for one concrete
// (ThresholdType, LeafOutputType) combination.
// NOTE(review): the class header (listing line 667) is absent from this rendered
// listing; the use `DispatchWithModelTypes<ModelCreateImpl>` elsewhere in the
// file suggests the class is named ModelCreateImpl — confirm against the
// actual header.
666 template <typename ThresholdType, typename LeafOutputType>
668  public:
// Forwards to the templated Model::Create() for the selected types
669  inline static std::unique_ptr<Model> Dispatch() {
670  return Model::Create<ThresholdType, LeafOutputType>();
671  }
672 };
673 
674 inline std::unique_ptr<Model>
675 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
676  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
677 }
678 
// Dispatch target that downcasts a Model to the concrete
// ModelImpl<ThresholdType, LeafOutputType> and passes it, by reference, to a
// caller-supplied functor.
// NOTE(review): the class header (listing line 680) is absent from this rendered
// listing; the use `DispatchWithModelTypes<ModelDispatchImpl>` elsewhere in the
// file suggests the class is named ModelDispatchImpl — confirm against the
// actual header.
679 template <typename ThresholdType, typename LeafOutputType>
681  public:
// Mutable overload.
// NOTE(review): the result of dynamic_cast is dereferenced without a null check,
// so the model's dynamic type must match the template arguments.
682  template <typename Func>
683  inline static auto Dispatch(Model* model, Func func) {
684  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
685  }
686 
// Read-only overload; the functor receives a const reference
687  template <typename Func>
688  inline static auto Dispatch(const Model* model, Func func) {
689  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
690  }
691 };
692 
693 template <typename Func>
694 inline auto
695 Model::Dispatch(Func func) {
696  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
697 }
698 
699 template <typename Func>
700 inline auto
701 Model::Dispatch(Func func) const {
702  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
703 }
704 
705 inline std::vector<PyBufferFrame>
706 Model::GetPyBuffer() {
707  std::vector<PyBufferFrame> buffer;
708  buffer.push_back(GetPyBufferFromScalar(&threshold_type_));
709  buffer.push_back(GetPyBufferFromScalar(&leaf_output_type_));
710  this->GetPyBuffer(&buffer);
711  return buffer;
712 }
713 
714 inline std::unique_ptr<Model>
715 Model::CreateFromPyBuffer(std::vector<PyBufferFrame> frames) {
716  TypeInfo threshold_type, leaf_output_type;
717  if (frames.size() < 2) {
718  throw std::runtime_error("Insufficient number of frames: there must be at least two");
719  }
720  InitScalarFromPyBuffer(&threshold_type, frames[0]);
721  InitScalarFromPyBuffer(&leaf_output_type, frames[1]);
722 
723  std::unique_ptr<Model> model = Model::Create(threshold_type, leaf_output_type);
724  model->InitFromPyBuffer(frames.begin() + 2, frames.end());
725  return model;
726 }
727 
728 
729 template <typename ThresholdType, typename LeafOutputType>
730 inline void
731 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
732  /* Header */
733  dest->push_back(GetPyBufferFromScalar(&num_feature));
734  dest->push_back(GetPyBufferFromScalar(&task_type));
735  dest->push_back(GetPyBufferFromScalar(&average_tree_output));
736  dest->push_back(GetPyBufferFromScalar(&task_param, "T{=B=?xx=I=I}"));
737  dest->push_back(GetPyBufferFromScalar(
738  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}"));
739 
740  /* Body */
741  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
742  tree.GetPyBuffer(dest);
743  }
744 }
745 
// Initialize the model from a range of frames: five header frames followed by
// kNumFramePerTree frames for each tree. Existing trees are discarded.
// Throws std::runtime_error if the range is shorter than the header, if the
// number of remaining frames is not a multiple of kNumFramePerTree, or if a
// frame is rejected while loading.
// NOTE(review): the first declarator line (listing line 748) is absent from this
// rendered listing; Model::CreateFromPyBuffer() calls
// `model->InitFromPyBuffer(begin, end)` with matching arguments, so this is
// presumably that function's implementation — confirm against the actual header.
746 template <typename ThresholdType, typename LeafOutputType>
747 inline void
749  std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
750  const size_t num_frame = std::distance(begin, end);
751  /* Header */
752  constexpr size_t kNumFrameInHeader = 5;
753  if (num_frame < kNumFrameInHeader) {
754  throw std::runtime_error("Wrong number of frames");
755  }
// The header frames are consumed in this fixed order
756  InitScalarFromPyBuffer(&num_feature, *begin++);
757  InitScalarFromPyBuffer(&task_type, *begin++);
758  InitScalarFromPyBuffer(&average_tree_output, *begin++);
759  InitScalarFromPyBuffer(&task_param, *begin++);
760  InitScalarFromPyBuffer(&param, *begin++);
761  /* Body */
// What remains must split evenly into per-tree groups of frames
762  if ((num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
763  throw std::runtime_error("Wrong number of frames");
764  }
765  trees.clear();
766  for (; begin < end; begin += kNumFramePerTree) {
767  trees.emplace_back();
768  trees.back().InitFromPyBuffer(begin, begin + kNumFramePerTree);
769  }
770 }
771 
772 inline void InitParamAndCheck(ModelParam* param,
773  const std::vector<std::pair<std::string, std::string>>& cfg) {
774  auto unknown = param->InitAllowUnknown(cfg);
775  if (!unknown.empty()) {
776  std::ostringstream oss;
777  for (const auto& kv : unknown) {
778  oss << kv.first << ", ";
779  }
780  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
781  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
782  }
783 }
784 
785 } // namespace treelite
786 #endif // TREELITE_TREE_IMPL_H_
SplitFeatureType
feature split type
Definition: base.h:22
void Init()
initialize the model with a single root node
Definition: tree_impl.h:521
TaskType
Enum type representing the task type.
Definition: tree.h:94
in-memory representation of a decision tree
Definition: tree.h:191
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
thin wrapper for tree ensemble model
Definition: tree.h:615
Operator
comparison operators
Definition: base.h:26