Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13 #include <map>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 #include <unordered_map>
18 #include <sstream>
19 #include <iomanip>
20 #include <typeinfo>
21 #include <stdexcept>
22 #include <iostream>
23 #include <cstddef>
24 
25 namespace {
26 
/*! \brief Convert a value to its decimal string form via std::to_string */
template <typename T>
inline std::string GetString(T x) {
  return std::to_string(x);
}

/*! \brief float overload: print with max_digits10 significant digits so the text round-trips */
template <>
inline std::string GetString<float>(float x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<float>::max_digits10);
  stream << x;
  return stream.str();
}

/*! \brief double overload: print with max_digits10 significant digits so the text round-trips */
template <>
inline std::string GetString<double>(double x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<double>::max_digits10);
  stream << x;
  return stream.str();
}
45 
46 } // anonymous namespace
47 
48 namespace treelite {
49 
50 template <typename T>
51 ContiguousArray<T>::ContiguousArray()
52  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
53 
54 template <typename T>
55 ContiguousArray<T>::~ContiguousArray() {
56  if (buffer_ && owned_buffer_) {
57  std::free(buffer_);
58  }
59 }
60 
61 template <typename T>
62 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
63  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
64  owned_buffer_(other.owned_buffer_) {
65  other.buffer_ = nullptr;
66  other.size_ = other.capacity_ = 0;
67 }
68 
69 template <typename T>
70 ContiguousArray<T>&
71 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
72  if (buffer_ && owned_buffer_) {
73  std::free(buffer_);
74  }
75  buffer_ = other.buffer_;
76  size_ = other.size_;
77  capacity_ = other.capacity_;
78  owned_buffer_ = other.owned_buffer_;
79  other.buffer_ = nullptr;
80  other.size_ = other.capacity_ = 0;
81  return *this;
82 }
83 
84 template <typename T>
85 inline ContiguousArray<T>
86 ContiguousArray<T>::Clone() const {
87  ContiguousArray clone;
88  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
89  if (!clone.buffer_) {
90  throw std::runtime_error("Could not allocate memory for the clone");
91  }
92  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
93  clone.size_ = size_;
94  clone.capacity_ = capacity_;
95  clone.owned_buffer_ = true;
96  return clone;
97 }
98 
99 template <typename T>
100 inline void
101 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, std::size_t size) {
102  if (buffer_ && owned_buffer_) {
103  std::free(buffer_);
104  }
105  buffer_ = static_cast<T*>(prealloc_buf);
106  size_ = size;
107  capacity_ = size;
108  owned_buffer_ = false;
109 }
110 
111 template <typename T>
112 inline T*
113 ContiguousArray<T>::Data() {
114  return buffer_;
115 }
116 
117 template <typename T>
118 inline const T*
119 ContiguousArray<T>::Data() const {
120  return buffer_;
121 }
122 
123 template <typename T>
124 inline T*
125 ContiguousArray<T>::End() {
126  return &buffer_[Size()];
127 }
128 
129 template <typename T>
130 inline const T*
131 ContiguousArray<T>::End() const {
132  return &buffer_[Size()];
133 }
134 
135 template <typename T>
136 inline T&
137 ContiguousArray<T>::Back() {
138  return buffer_[Size() - 1];
139 }
140 
141 template <typename T>
142 inline const T&
143 ContiguousArray<T>::Back() const {
144  return buffer_[Size() - 1];
145 }
146 
147 template <typename T>
148 inline std::size_t
149 ContiguousArray<T>::Size() const {
150  return size_;
151 }
152 
153 template <typename T>
154 inline bool
155 ContiguousArray<T>::Empty() const {
156  return (Size() == 0);
157 }
158 
159 template <typename T>
160 inline void
161 ContiguousArray<T>::Reserve(std::size_t newsize) {
162  if (!owned_buffer_) {
163  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
164  }
165  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
166  if (!newbuf) {
167  throw std::runtime_error("Could not expand buffer");
168  }
169  buffer_ = newbuf;
170  capacity_ = newsize;
171 }
172 
173 template <typename T>
174 inline void
175 ContiguousArray<T>::Resize(std::size_t newsize) {
176  if (!owned_buffer_) {
177  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
178  }
179  if (newsize > capacity_) {
180  std::size_t newcapacity = capacity_;
181  if (newcapacity == 0) {
182  newcapacity = 1;
183  }
184  while (newcapacity <= newsize) {
185  newcapacity *= 2;
186  }
187  Reserve(newcapacity);
188  }
189  size_ = newsize;
190 }
191 
192 template <typename T>
193 inline void
194 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
195  if (!owned_buffer_) {
196  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
197  }
198  std::size_t oldsize = Size();
199  Resize(newsize);
200  for (std::size_t i = oldsize; i < newsize; ++i) {
201  buffer_[i] = t;
202  }
203 }
204 
205 template <typename T>
206 inline void
207 ContiguousArray<T>::Clear() {
208  if (!owned_buffer_) {
209  throw std::runtime_error("Cannot clear when using a foreign buffer; clone first");
210  }
211  Resize(0);
212 }
213 
214 template <typename T>
215 inline void
216 ContiguousArray<T>::PushBack(T t) {
217  if (!owned_buffer_) {
218  throw std::runtime_error("Cannot add element when using a foreign buffer; clone first");
219  }
220  if (size_ == capacity_) {
221  Reserve(capacity_ * 2);
222  }
223  buffer_[size_++] = t;
224 }
225 
226 template <typename T>
227 inline void
228 ContiguousArray<T>::Extend(const std::vector<T>& other) {
229  if (!owned_buffer_) {
230  throw std::runtime_error("Cannot add elements when using a foreign buffer; clone first");
231  }
232  if (other.empty()) {
233  return; // appending an empty vector is a no-op
234  }
235  std::size_t newsize = size_ + other.size();
236  if (newsize > capacity_) {
237  std::size_t newcapacity = capacity_;
238  if (newcapacity == 0) {
239  newcapacity = 1;
240  }
241  while (newcapacity <= newsize) {
242  newcapacity *= 2;
243  }
244  Reserve(newcapacity);
245  }
246  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
247  size_ = newsize;
248 }
249 
250 template <typename T>
251 inline T&
252 ContiguousArray<T>::operator[](std::size_t idx) {
253  return buffer_[idx];
254 }
255 
256 template <typename T>
257 inline const T&
258 ContiguousArray<T>::operator[](std::size_t idx) const {
259  return buffer_[idx];
260 }
261 
262 template <typename T>
263 inline T&
264 ContiguousArray<T>::at(std::size_t idx) {
265  if (idx >= Size()) {
266  throw std::runtime_error("nid out of range");
267  }
268  return buffer_[idx];
269 }
270 
271 template <typename T>
272 inline const T&
273 ContiguousArray<T>::at(std::size_t idx) const {
274  if (idx >= Size()) {
275  throw std::runtime_error("nid out of range");
276  }
277  return buffer_[idx];
278 }
279 
280 template <typename T>
281 inline T&
282 ContiguousArray<T>::at(int idx) {
283  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
284  throw std::runtime_error("nid out of range");
285  }
286  return buffer_[static_cast<std::size_t>(idx)];
287 }
288 
289 template <typename T>
290 inline const T&
291 ContiguousArray<T>::at(int idx) const {
292  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
293  throw std::runtime_error("nid out of range");
294  }
295  return buffer_[static_cast<std::size_t>(idx)];
296 }
297 
298 template<typename Container>
299 inline std::vector<std::pair<std::string, std::string> >
300 ModelParam::InitAllowUnknown(const Container& kwargs) {
301  std::vector<std::pair<std::string, std::string>> unknowns;
302  for (const auto& e : kwargs) {
303  if (e.first == "pred_transform") {
304  std::strncpy(this->pred_transform, e.second.c_str(),
305  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
306  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
307  } else if (e.first == "sigmoid_alpha") {
308  this->sigmoid_alpha = std::stof(e.second, nullptr);
309  } else if (e.first == "ratio_c") {
310  this->ratio_c = std::stof(e.second, nullptr);
311  } else if (e.first == "global_bias") {
312  this->global_bias = std::stof(e.second, nullptr);
313  }
314  }
315  return unknowns;
316 }
317 
318 inline std::map<std::string, std::string>
319 ModelParam::__DICT__() const {
320  std::map<std::string, std::string> ret;
321  ret.emplace("pred_transform", std::string(this->pred_transform));
322  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
323  ret.emplace("ratio_c", GetString(this->ratio_c));
324  ret.emplace("global_bias", GetString(this->global_bias));
325  return ret;
326 }
327 
328 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
329  std::size_t itemsize, std::size_t nitem) {
330  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
331 }
332 
/*!
 * \brief Infer the format string for an arithmetic type from its size and signedness
 *
 * 1-, 2-, 4- and 8-byte integers map to "=b"/"=B", "=h"/"=H", "=l"/"=L" and "=q"/"=Q"
 * (lower case for signed, upper case for unsigned); 4- and 8-byte floating-point types
 * map to "=f" and "=d".
 * \tparam T type to describe
 * \return pointer to a string literal
 * \throws std::runtime_error if T has an unsupported size ("Unrecognized type") or is
 *         neither integral nor floating-point ("Could not infer format string")
 */
template <typename T>
inline const char* InferFormatString() {
  const bool is_integral = std::is_integral<T>::value;
  const bool is_unsigned = std::is_unsigned<T>::value;
  const bool is_floating = std::is_floating_point<T>::value;
  switch (sizeof(T)) {
   case 1:
    // Small non-integral types (e.g. 1-byte structs) must not be reported as integers
    if (is_integral) {
      return (is_unsigned ? "=B" : "=b");
    }
    break;
   case 2:
    if (is_integral) {
      return (is_unsigned ? "=H" : "=h");
    }
    break;
   case 4:
    if (is_integral) {
      return (is_unsigned ? "=L" : "=l");
    }
    if (is_floating) {
      return "=f";
    }
    break;
   case 8:
    if (is_integral) {
      return (is_unsigned ? "=Q" : "=q");
    }
    if (is_floating) {
      return "=d";
    }
    break;
   default:
    throw std::runtime_error("Unrecognized type");
  }
  throw std::runtime_error("Could not infer format string");
}
364 
365 template <typename T>
366 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
367  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
368 }
369 
370 template <typename T>
371 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
372  static_assert(std::is_arithmetic<T>::value,
373  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
374  return GetPyBufferFromArray(vec, InferFormatString<T>());
375 }
376 
377 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, std::size_t itemsize) {
378  return GetPyBufferFromArray(data, format, itemsize, 1);
379 }
380 
381 template <typename T>
382 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
383  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
384  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
385 }
386 
387 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
388  using T = std::underlying_type<TypeInfo>::type;
389  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
390 }
391 
392 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
393  using T = std::underlying_type<TaskType>::type;
394  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
395 }
396 
397 template <typename T>
398 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
399  static_assert(std::is_arithmetic<T>::value,
400  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
401  "specify format string manually");
402  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
403 }
404 
405 template <typename T>
406 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
407  if (sizeof(T) != frame.itemsize) {
408  throw std::runtime_error("Incorrect itemsize");
409  }
410  vec->UseForeignBuffer(frame.buf, frame.nitem);
411 }
412 
413 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
414  using T = std::underlying_type<TypeInfo>::type;
415  if (sizeof(T) != buffer.itemsize) {
416  throw std::runtime_error("Incorrect itemsize");
417  }
418  if (buffer.nitem != 1) {
419  throw std::runtime_error("nitem must be 1 for a scalar");
420  }
421  T* t = static_cast<T*>(buffer.buf);
422  *scalar = static_cast<TypeInfo>(*t);
423 }
424 
425 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
426  using T = std::underlying_type<TaskType>::type;
427  if (sizeof(T) != buffer.itemsize) {
428  throw std::runtime_error("Incorrect itemsize");
429  }
430  if (buffer.nitem != 1) {
431  throw std::runtime_error("nitem must be 1 for a scalar");
432  }
433  T* t = static_cast<T*>(buffer.buf);
434  *scalar = static_cast<TaskType>(*t);
435 }
436 
437 template <typename T>
438 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
439  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
440  if (sizeof(T) != buffer.itemsize) {
441  throw std::runtime_error("Incorrect itemsize");
442  }
443  if (buffer.nitem != 1) {
444  throw std::runtime_error("nitem must be 1 for a scalar");
445  }
446  T* t = static_cast<T*>(buffer.buf);
447  *scalar = *t;
448 }
449 
/*!
 * \brief Read the raw bytes of one standard-layout value from a file
 * \param scalar destination of the read
 * \param fp file to read from
 * \throws std::runtime_error if a complete value could not be read
 */
template <typename T>
inline void ReadScalarFromFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t num_read = std::fread(scalar, sizeof(T), 1, fp);
  if (num_read != 1) {
    throw std::runtime_error("Could not read a scalar");
  }
}
457 
/*!
 * \brief Write the raw bytes of one standard-layout value to a file
 * \param scalar value to write
 * \param fp file to write to
 * \throws std::runtime_error if the value could not be written
 */
template <typename T>
inline void WriteScalarToFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t num_written = std::fwrite(scalar, sizeof(T), 1, fp);
  if (num_written != 1) {
    throw std::runtime_error("Could not write a scalar");
  }
}
465 
466 template <typename T>
467 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
468  uint64_t nelem;
469  if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
470  throw std::runtime_error("Could not read the number of elements");
471  }
472  vec->Clear();
473  vec->Resize(nelem);
474  if (nelem == 0) {
475  return; // handle empty arrays
476  }
477  const auto nelem_size_t = static_cast<std::size_t>(nelem);
478  if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
479  throw std::runtime_error("Could not read an array");
480  }
481 }
482 
483 template <typename T>
484 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
485  static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
486  const auto nelem = static_cast<uint64_t>(vec->Size());
487  if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
488  throw std::runtime_error("Could not write the number of elements");
489  }
490  if (nelem == 0) {
491  return; // handle empty arrays
492  }
493  const auto nelem_size_t = vec->Size();
494  if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
495  throw std::runtime_error("Could not write an array");
496  }
497 }
498 
499 template <typename ThresholdType, typename LeafOutputType>
500 inline Tree<ThresholdType, LeafOutputType>
501 Tree<ThresholdType, LeafOutputType>::Clone() const {
502  Tree<ThresholdType, LeafOutputType> tree;
503  tree.num_nodes = num_nodes;
504  tree.nodes_ = nodes_.Clone();
505  tree.leaf_vector_ = leaf_vector_.Clone();
506  tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
507  tree.leaf_vector_end_ = leaf_vector_end_.Clone();
508  tree.matching_categories_ = matching_categories_.Clone();
509  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
510  return tree;
511 }
512 
513 template <typename ThresholdType, typename LeafOutputType>
514 inline const char*
515 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
516  if (std::is_same<ThresholdType, float>::value) {
517  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
518  } else {
519  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
520  }
521 }
522 
/*! \brief Number of frames that make up one tree: 2 scalar fields followed by 6 arrays */
constexpr std::size_t kNumFramePerTree{8};
524 
525 template <typename ThresholdType, typename LeafOutputType>
526 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
527 inline void
528 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
529  ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
530  CompositeArrayHandler composite_array_handler) {
531  scalar_handler(&num_nodes);
532  scalar_handler(&has_categorical_split_);
533  composite_array_handler(&nodes_, GetFormatStringForNode());
534  primitive_array_handler(&leaf_vector_);
535  primitive_array_handler(&leaf_vector_begin_);
536  primitive_array_handler(&leaf_vector_end_);
537  primitive_array_handler(&matching_categories_);
538  primitive_array_handler(&matching_categories_offset_);
539 }
540 
541 template <typename ThresholdType, typename LeafOutputType>
542 template <typename ScalarHandler, typename ArrayHandler>
543 inline void
544 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
545  ScalarHandler scalar_handler, ArrayHandler array_handler) {
546  scalar_handler(&num_nodes);
547  scalar_handler(&has_categorical_split_);
548  array_handler(&nodes_);
549  if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
550  throw std::runtime_error("Could not load the correct number of nodes");
551  }
552  array_handler(&leaf_vector_);
553  array_handler(&leaf_vector_begin_);
554  array_handler(&leaf_vector_end_);
555  array_handler(&matching_categories_);
556  array_handler(&matching_categories_offset_);
557 }
558 
559 template <typename ThresholdType, typename LeafOutputType>
560 inline void
561 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
562  auto scalar_handler = [dest](auto* field) {
563  dest->push_back(GetPyBufferFromScalar(field));
564  };
565  auto primitive_array_handler = [dest](auto* field) {
566  dest->push_back(GetPyBufferFromArray(field));
567  };
568  auto composite_array_handler = [dest](auto* field, const char* format) {
569  dest->push_back(GetPyBufferFromArray(field, format));
570  };
571  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
572 }
573 
574 template <typename ThresholdType, typename LeafOutputType>
575 inline void
576 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
577  auto scalar_handler = [dest_fp](auto* field) {
578  WriteScalarToFile(field, dest_fp);
579  };
580  auto primitive_array_handler = [dest_fp](auto* field) {
581  WriteArrayToFile(field, dest_fp);
582  };
583  auto composite_array_handler = [dest_fp](auto* field, const char* format) {
584  WriteArrayToFile(field, dest_fp);
585  };
586  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
587 }
588 
589 template <typename ThresholdType, typename LeafOutputType>
590 inline void
591 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
592  std::vector<PyBufferFrame>::iterator end) {
593  if (std::distance(begin, end) != kNumFramePerTree) {
594  throw std::runtime_error("Wrong number of frames specified");
595  }
596  auto scalar_handler = [&begin](auto* field) {
597  InitScalarFromPyBuffer(field, *begin++);
598  };
599  auto array_handler = [&begin](auto* field) {
600  InitArrayFromPyBuffer(field, *begin++);
601  };
602  DeserializeTemplate(scalar_handler, array_handler);
603 }
604 
605 template <typename ThresholdType, typename LeafOutputType>
606 inline void
607 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
608  auto scalar_handler = [src_fp](auto* field) {
609  ReadScalarFromFile(field, src_fp);
610  };
611  auto array_handler = [src_fp](auto* field) {
612  ReadArrayFromFile(field, src_fp);
613  };
614  DeserializeTemplate(scalar_handler, array_handler);
615 }
616 
617 template <typename ThresholdType, typename LeafOutputType>
619  std::memset(this, 0, sizeof(Node));
620  cleft_ = cright_ = -1;
621  sindex_ = 0;
622  info_.leaf_value = static_cast<LeafOutputType>(0);
623  info_.threshold = static_cast<ThresholdType>(0);
624  data_count_ = 0;
625  sum_hess_ = gain_ = 0.0;
626  data_count_present_ = sum_hess_present_ = gain_present_ = false;
627  categories_list_right_child_ = false;
628  split_type_ = SplitFeatureType::kNone;
629  cmp_ = Operator::kNone;
630 }
631 
632 template <typename ThresholdType, typename LeafOutputType>
633 inline int
635  int nd = num_nodes++;
636  if (nodes_.Size() != static_cast<std::size_t>(nd)) {
637  throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
638  }
639  for (int nid = nd; nid < num_nodes; ++nid) {
640  leaf_vector_begin_.PushBack(0);
641  leaf_vector_end_.PushBack(0);
642  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
643  nodes_.Resize(nodes_.Size() + 1);
644  nodes_.Back().Init();
645  }
646  return nd;
647 }
648 
649 template <typename ThresholdType, typename LeafOutputType>
650 inline void
652  num_nodes = 1;
653  has_categorical_split_ = false;
654  leaf_vector_.Clear();
655  leaf_vector_begin_.Resize(1, {});
656  leaf_vector_end_.Resize(1, {});
657  matching_categories_.Clear();
658  matching_categories_offset_.Resize(2, 0);
659  nodes_.Resize(1);
660  nodes_.at(0).Init();
661  SetLeaf(0, static_cast<LeafOutputType>(0));
662 }
663 
664 template <typename ThresholdType, typename LeafOutputType>
665 inline void
667  const int cleft = this->AllocNode();
668  const int cright = this->AllocNode();
669  nodes_.at(nid).cleft_ = cleft;
670  nodes_.at(nid).cright_ = cright;
671 }
672 
673 template <typename ThresholdType, typename LeafOutputType>
674 inline void
676  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
677  Node& node = nodes_.at(nid);
678  if (split_index >= ((1U << 31U) - 1)) {
679  throw std::runtime_error("split_index too big");
680  }
681  if (default_left) split_index |= (1U << 31U);
682  node.sindex_ = split_index;
683  (node.info_).threshold = threshold;
684  node.cmp_ = cmp;
685  node.split_type_ = SplitFeatureType::kNumerical;
686  node.categories_list_right_child_ = false;
687 }
688 
689 template <typename ThresholdType, typename LeafOutputType>
690 inline void
692  int nid, unsigned split_index, bool default_left,
693  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
694  if (split_index >= ((1U << 31U) - 1)) {
695  throw std::runtime_error("split_index too big");
696  }
697 
698  const std::size_t end_oft = matching_categories_offset_.Back();
699  const std::size_t new_end_oft = end_oft + categories_list.size();
700  if (end_oft != matching_categories_.Size()) {
701  throw std::runtime_error("Invariant violated");
702  }
703  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
704  [end_oft](std::size_t x) { return (x == end_oft); })) {
705  throw std::runtime_error("Invariant violated");
706  }
707  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
708  matching_categories_.Extend(categories_list);
709  if (new_end_oft != matching_categories_.Size()) {
710  throw std::runtime_error("Invariant violated");
711  }
712  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
713  [new_end_oft](std::size_t& x) { x = new_end_oft; });
714  if (!matching_categories_.Empty()) {
715  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
716  }
717 
718  Node& node = nodes_.at(nid);
719  if (default_left) split_index |= (1U << 31U);
720  node.sindex_ = split_index;
721  node.split_type_ = SplitFeatureType::kCategorical;
722  node.categories_list_right_child_ = categories_list_right_child;
723 
724  has_categorical_split_ = true;
725 }
726 
727 template <typename ThresholdType, typename LeafOutputType>
728 inline void
729 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
730  Node& node = nodes_.at(nid);
731  (node.info_).leaf_value = value;
732  node.cleft_ = -1;
733  node.cright_ = -1;
734  node.split_type_ = SplitFeatureType::kNone;
735 }
736 
737 template <typename ThresholdType, typename LeafOutputType>
738 inline void
740  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
741  std::size_t begin = leaf_vector_.Size();
742  std::size_t end = begin + node_leaf_vector.size();
743  leaf_vector_.Extend(node_leaf_vector);
744  leaf_vector_begin_[nid] = begin;
745  leaf_vector_end_[nid] = end;
746  Node &node = nodes_.at(nid);
747  node.cleft_ = -1;
748  node.cright_ = -1;
749  node.split_type_ = SplitFeatureType::kNone;
750 }
751 
752 template <typename ThresholdType, typename LeafOutputType>
753 inline std::unique_ptr<Model>
754 Model::Create() {
755  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
756  model->threshold_type_ = TypeToInfo<ThresholdType>();
757  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
758  return model;
759 }
760 
761 template <typename ThresholdType, typename LeafOutputType>
763  public:
764  inline static std::unique_ptr<Model> Dispatch() {
765  return Model::Create<ThresholdType, LeafOutputType>();
766  }
767 };
768 
769 inline std::unique_ptr<Model>
770 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
771  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
772 }
773 
774 template <typename ThresholdType, typename LeafOutputType>
776  public:
777  template <typename Func>
778  inline static auto Dispatch(Model* model, Func func) {
779  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
780  }
781 
782  template <typename Func>
783  inline static auto Dispatch(const Model* model, Func func) {
784  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
785  }
786 };
787 
788 template <typename Func>
789 inline auto
790 Model::Dispatch(Func func) {
791  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
792 }
793 
794 template <typename Func>
795 inline auto
796 Model::Dispatch(Func func) const {
797  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
798 }
799 
800 template <typename HeaderPrimitiveFieldHandlerFunc>
801 inline void
802 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
803  header_primitive_field_handler(&major_ver_);
804  header_primitive_field_handler(&minor_ver_);
805  header_primitive_field_handler(&patch_ver_);
806  header_primitive_field_handler(&threshold_type_);
807  header_primitive_field_handler(&leaf_output_type_);
808 }
809 
810 template <typename HeaderPrimitiveFieldHandlerFunc>
811 inline void
812 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
813  TypeInfo& threshold_type, TypeInfo& leaf_output_type) {
814  int major_ver, minor_ver, patch_ver;
815  header_primitive_field_handler(&major_ver);
816  header_primitive_field_handler(&minor_ver);
817  header_primitive_field_handler(&patch_ver);
818  if (major_ver != TREELITE_VER_MAJOR || minor_ver != TREELITE_VER_MINOR) {
819  std::ostringstream oss;
820  oss << "Cannot deserialize model from a different version of Treelite." << std::endl
821  << "Currently running Treelite version " << TREELITE_VER_MAJOR << "."
822  << TREELITE_VER_MINOR << "." << TREELITE_VER_PATCH << std::endl
823  << "The model checkpoint was generated from Treelite version " << major_ver << "."
824  << minor_ver << "." << patch_ver;
825  throw std::runtime_error(oss.str());
826  }
827  header_primitive_field_handler(&threshold_type);
828  header_primitive_field_handler(&leaf_output_type);
829 }
830 
831 template <typename ThresholdType, typename LeafOutputType>
832 template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
833  typename TreeHandlerFunc>
834 inline void
836  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
837  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
838  TreeHandlerFunc tree_handler) {
839  /* Header */
840  header_primitive_field_handler(&num_feature);
841  header_primitive_field_handler(&task_type);
842  header_primitive_field_handler(&average_tree_output);
843  header_composite_field_handler(&task_param, "T{=B=?xx=I=I}");
844  header_composite_field_handler(
845  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f=f}");
846 
847  /* Body */
848  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
849  tree_handler(tree);
850  }
851 }
852 
853 template <typename ThresholdType, typename LeafOutputType>
854 template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
855 inline void
857  std::size_t num_tree,
858  HeaderFieldHandlerFunc header_field_handler,
859  TreeHandlerFunc tree_handler) {
860  /* Header */
861  header_field_handler(&num_feature);
862  header_field_handler(&task_type);
863  header_field_handler(&average_tree_output);
864  header_field_handler(&task_param);
865  header_field_handler(&param);
866  /* Body */
867  trees.clear();
868  for (std::size_t i = 0; i < num_tree; ++i) {
869  trees.emplace_back();
870  tree_handler(trees.back());
871  }
872 }
873 
874 template <typename ThresholdType, typename LeafOutputType>
875 inline void
876 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
877  auto header_primitive_field_handler = [dest](auto* field) {
878  dest->push_back(GetPyBufferFromScalar(field));
879  };
880  auto header_composite_field_handler = [dest](auto* field, const char* format) {
881  dest->push_back(GetPyBufferFromScalar(field, format));
882  };
883  auto tree_handler = [dest](Tree<ThresholdType, LeafOutputType>& tree) {
884  tree.GetPyBuffer(dest);
885  };
886  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
887 }
888 
889 template <typename ThresholdType, typename LeafOutputType>
890 inline void
892  const auto num_tree = static_cast<uint64_t>(this->trees.size());
893  WriteScalarToFile(&num_tree, dest_fp);
894  auto header_primitive_field_handler = [dest_fp](auto* field) {
895  WriteScalarToFile(field, dest_fp);
896  };
897  auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
898  WriteScalarToFile(field, dest_fp);
899  };
900  auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
901  tree.SerializeToFile(dest_fp);
902  };
903  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
904 }
905 
906 template <typename ThresholdType, typename LeafOutputType>
907 inline void
909  std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
910  const std::size_t num_frame = std::distance(begin, end);
911  constexpr std::size_t kNumFrameInHeader = 5;
912  if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
913  throw std::runtime_error("Wrong number of frames");
914  }
915  const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
916 
917  auto header_field_handler = [&begin](auto* field) {
918  InitScalarFromPyBuffer(field, *begin++);
919  };
920 
921  auto tree_handler = [&begin](Tree<ThresholdType, LeafOutputType>& tree) {
922  // Read the frames in the range [begin, begin + kNumFramePerTree) into the tree
923  tree.InitFromPyBuffer(begin, begin + kNumFramePerTree);
924  begin += kNumFramePerTree;
925  // Advance the iterator so that the next tree reads the next kNumFramePerTree frames
926  };
927 
928  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
929 }
930 
931 template <typename ThresholdType, typename LeafOutputType>
932 inline void
934  uint64_t num_tree;
935  ReadScalarFromFile(&num_tree, src_fp);
936 
937  auto header_field_handler = [src_fp](auto* field) {
938  ReadScalarFromFile(field, src_fp);
939  };
940 
941  auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
942  tree.DeserializeFromFile(src_fp);
943  };
944 
945  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
946 }
947 
948 inline void InitParamAndCheck(ModelParam* param,
949  const std::vector<std::pair<std::string, std::string>>& cfg) {
950  auto unknown = param->InitAllowUnknown(cfg);
951  if (!unknown.empty()) {
952  std::ostringstream oss;
953  for (const auto& kv : unknown) {
954  oss << kv.first << ", ";
955  }
956  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
957  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
958  }
959 }
960 
961 } // namespace treelite
962 #endif // TREELITE_TREE_IMPL_H_
void Init()
initialize the model with a single root node
Definition: tree_impl.h:651
TaskType
Enum type representing the task type.
Definition: tree.h:100
tree node
Definition: tree.h:218
in-memory representation of a decision tree
Definition: tree.h:215
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
thin wrapper for tree ensemble model
Definition: tree.h:717
Operator
comparison operators
Definition: base.h:26