Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13 #include <map>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 #include <unordered_map>
18 #include <sstream>
19 #include <iomanip>
20 #include <typeinfo>
21 #include <stdexcept>
22 #include <iostream>
23 #include <cstddef>
24 
namespace {

/*!
 * \brief Convert a value to its string form via std::to_string
 * \param x value to convert
 * \return string representation of x
 */
template <typename T>
inline std::string GetString(T x) {
  return std::to_string(x);
}

/*!
 * \brief Specialization for float: print with max_digits10 significant digits
 *        instead of the fixed formatting std::to_string would use
 */
template <>
inline std::string GetString<float>(float x) {
  constexpr int kDigits = std::numeric_limits<float>::max_digits10;
  std::ostringstream stream;
  stream.precision(kDigits);
  stream << x;
  return stream.str();
}

/*!
 * \brief Specialization for double: print with max_digits10 significant digits
 *        instead of the fixed formatting std::to_string would use
 */
template <>
inline std::string GetString<double>(double x) {
  constexpr int kDigits = std::numeric_limits<double>::max_digits10;
  std::ostringstream stream;
  stream.precision(kDigits);
  stream << x;
  return stream.str();
}

}  // anonymous namespace
47 
48 namespace treelite {
49 
50 template <typename T>
51 ContiguousArray<T>::ContiguousArray()
52  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
53 
54 template <typename T>
55 ContiguousArray<T>::~ContiguousArray() {
56  if (buffer_ && owned_buffer_) {
57  std::free(buffer_);
58  }
59 }
60 
61 template <typename T>
62 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
63  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
64  owned_buffer_(other.owned_buffer_) {
65  other.buffer_ = nullptr;
66  other.size_ = other.capacity_ = 0;
67 }
68 
69 template <typename T>
70 ContiguousArray<T>&
71 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
72  if (buffer_ && owned_buffer_) {
73  std::free(buffer_);
74  }
75  buffer_ = other.buffer_;
76  size_ = other.size_;
77  capacity_ = other.capacity_;
78  owned_buffer_ = other.owned_buffer_;
79  other.buffer_ = nullptr;
80  other.size_ = other.capacity_ = 0;
81  return *this;
82 }
83 
84 template <typename T>
85 inline ContiguousArray<T>
86 ContiguousArray<T>::Clone() const {
87  ContiguousArray clone;
88  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
89  if (!clone.buffer_) {
90  throw std::runtime_error("Could not allocate memory for the clone");
91  }
92  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
93  clone.size_ = size_;
94  clone.capacity_ = capacity_;
95  clone.owned_buffer_ = true;
96  return clone;
97 }
98 
99 template <typename T>
100 inline void
101 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, std::size_t size) {
102  if (buffer_ && owned_buffer_) {
103  std::free(buffer_);
104  }
105  buffer_ = static_cast<T*>(prealloc_buf);
106  size_ = size;
107  capacity_ = size;
108  owned_buffer_ = false;
109 }
110 
111 template <typename T>
112 inline T*
113 ContiguousArray<T>::Data() {
114  return buffer_;
115 }
116 
117 template <typename T>
118 inline const T*
119 ContiguousArray<T>::Data() const {
120  return buffer_;
121 }
122 
123 template <typename T>
124 inline T*
125 ContiguousArray<T>::End() {
126  return &buffer_[Size()];
127 }
128 
129 template <typename T>
130 inline const T*
131 ContiguousArray<T>::End() const {
132  return &buffer_[Size()];
133 }
134 
135 template <typename T>
136 inline T&
137 ContiguousArray<T>::Back() {
138  return buffer_[Size() - 1];
139 }
140 
141 template <typename T>
142 inline const T&
143 ContiguousArray<T>::Back() const {
144  return buffer_[Size() - 1];
145 }
146 
147 template <typename T>
148 inline std::size_t
149 ContiguousArray<T>::Size() const {
150  return size_;
151 }
152 
153 template <typename T>
154 inline void
155 ContiguousArray<T>::Reserve(std::size_t newsize) {
156  if (!owned_buffer_) {
157  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
158  }
159  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
160  if (!newbuf) {
161  throw std::runtime_error("Could not expand buffer");
162  }
163  buffer_ = newbuf;
164  capacity_ = newsize;
165 }
166 
167 template <typename T>
168 inline void
169 ContiguousArray<T>::Resize(std::size_t newsize) {
170  if (!owned_buffer_) {
171  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
172  }
173  if (newsize > capacity_) {
174  std::size_t newcapacity = capacity_;
175  if (newcapacity == 0) {
176  newcapacity = 1;
177  }
178  while (newcapacity <= newsize) {
179  newcapacity *= 2;
180  }
181  Reserve(newcapacity);
182  }
183  size_ = newsize;
184 }
185 
186 template <typename T>
187 inline void
188 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
189  if (!owned_buffer_) {
190  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
191  }
192  std::size_t oldsize = Size();
193  Resize(newsize);
194  for (std::size_t i = oldsize; i < newsize; ++i) {
195  buffer_[i] = t;
196  }
197 }
198 
199 template <typename T>
200 inline void
201 ContiguousArray<T>::Clear() {
202  if (!owned_buffer_) {
203  throw std::runtime_error("Cannot clear when using a foreign buffer; clone first");
204  }
205  Resize(0);
206 }
207 
208 template <typename T>
209 inline void
210 ContiguousArray<T>::PushBack(T t) {
211  if (!owned_buffer_) {
212  throw std::runtime_error("Cannot add element when using a foreign buffer; clone first");
213  }
214  if (size_ == capacity_) {
215  Reserve(capacity_ * 2);
216  }
217  buffer_[size_++] = t;
218 }
219 
220 template <typename T>
221 inline void
222 ContiguousArray<T>::Extend(const std::vector<T>& other) {
223  if (!owned_buffer_) {
224  throw std::runtime_error("Cannot add elements when using a foreign buffer; clone first");
225  }
226  std::size_t newsize = size_ + other.size();
227  if (newsize > capacity_) {
228  std::size_t newcapacity = capacity_;
229  if (newcapacity == 0) {
230  newcapacity = 1;
231  }
232  while (newcapacity <= newsize) {
233  newcapacity *= 2;
234  }
235  Reserve(newcapacity);
236  }
237  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
238  size_ = newsize;
239 }
240 
241 template <typename T>
242 inline T&
243 ContiguousArray<T>::operator[](std::size_t idx) {
244  return buffer_[idx];
245 }
246 
247 template <typename T>
248 inline const T&
249 ContiguousArray<T>::operator[](std::size_t idx) const {
250  return buffer_[idx];
251 }
252 
253 template <typename T>
254 inline T&
255 ContiguousArray<T>::at(std::size_t idx) {
256  if (idx >= Size()) {
257  throw std::runtime_error("nid out of range");
258  }
259  return buffer_[idx];
260 }
261 
262 template <typename T>
263 inline const T&
264 ContiguousArray<T>::at(std::size_t idx) const {
265  if (idx >= Size()) {
266  throw std::runtime_error("nid out of range");
267  }
268  return buffer_[idx];
269 }
270 
271 template <typename T>
272 inline T&
273 ContiguousArray<T>::at(int idx) {
274  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
275  throw std::runtime_error("nid out of range");
276  }
277  return buffer_[static_cast<std::size_t>(idx)];
278 }
279 
280 template <typename T>
281 inline const T&
282 ContiguousArray<T>::at(int idx) const {
283  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
284  throw std::runtime_error("nid out of range");
285  }
286  return buffer_[static_cast<std::size_t>(idx)];
287 }
288 
289 template<typename Container>
290 inline std::vector<std::pair<std::string, std::string> >
291 ModelParam::InitAllowUnknown(const Container& kwargs) {
292  std::vector<std::pair<std::string, std::string>> unknowns;
293  for (const auto& e : kwargs) {
294  if (e.first == "pred_transform") {
295  std::strncpy(this->pred_transform, e.second.c_str(),
296  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
297  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
298  } else if (e.first == "sigmoid_alpha") {
299  this->sigmoid_alpha = dmlc::stof(e.second, nullptr);
300  } else if (e.first == "global_bias") {
301  this->global_bias = dmlc::stof(e.second, nullptr);
302  }
303  }
304  return unknowns;
305 }
306 
307 inline std::map<std::string, std::string>
308 ModelParam::__DICT__() const {
309  std::map<std::string, std::string> ret;
310  ret.emplace("pred_transform", std::string(this->pred_transform));
311  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
312  ret.emplace("global_bias", GetString(this->global_bias));
313  return ret;
314 }
315 
316 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
317  std::size_t itemsize, std::size_t nitem) {
318  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
319 }
320 
/*!
 * \brief Infer the buffer format string for a type from its size and category
 * \tparam T type to describe
 * \return "=b"/"=B" (1 byte), "=h"/"=H" (2 bytes), "=l"/"=L" or "=f" (4 bytes),
 *         "=q"/"=Q" or "=d" (8 bytes); the upper-case letter is chosen for unsigned types
 * \throws std::runtime_error if T has any other size, or is a 4- or 8-byte type
 *         that is neither integral nor floating-point
 */
template <typename T>
inline const char* InferFormatString() {
  constexpr std::size_t kWidth = sizeof(T);
  constexpr bool kIsUnsigned = std::is_unsigned<T>::value;
  constexpr bool kIsIntegral = std::is_integral<T>::value;
  constexpr bool kIsFloating = std::is_floating_point<T>::value;

  if (kWidth == 1) {
    return kIsUnsigned ? "=B" : "=b";
  }
  if (kWidth == 2) {
    return kIsUnsigned ? "=H" : "=h";
  }
  if (kWidth != 4 && kWidth != 8) {
    throw std::runtime_error("Unrecognized type");
  }
  // Only 4- and 8-byte types remain
  if (kIsIntegral) {
    if (kWidth == 4) {
      return kIsUnsigned ? "=L" : "=l";
    }
    return kIsUnsigned ? "=Q" : "=q";
  }
  if (!kIsFloating) {
    throw std::runtime_error("Could not infer format string");
  }
  return (kWidth == 4) ? "=f" : "=d";
}
352 
353 template <typename T>
354 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
355  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
356 }
357 
358 template <typename T>
359 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
360  static_assert(std::is_arithmetic<T>::value,
361  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
362  return GetPyBufferFromArray(vec, InferFormatString<T>());
363 }
364 
365 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, std::size_t itemsize) {
366  return GetPyBufferFromArray(data, format, itemsize, 1);
367 }
368 
369 template <typename T>
370 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
371  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
372  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
373 }
374 
375 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
376  using T = std::underlying_type<TypeInfo>::type;
377  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
378 }
379 
380 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
381  using T = std::underlying_type<TaskType>::type;
382  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
383 }
384 
385 template <typename T>
386 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
387  static_assert(std::is_arithmetic<T>::value,
388  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
389  "specify format string manually");
390  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
391 }
392 
393 template <typename T>
394 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
395  if (sizeof(T) != frame.itemsize) {
396  throw std::runtime_error("Incorrect itemsize");
397  }
398  vec->UseForeignBuffer(frame.buf, frame.nitem);
399 }
400 
401 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
402  using T = std::underlying_type<TypeInfo>::type;
403  if (sizeof(T) != buffer.itemsize) {
404  throw std::runtime_error("Incorrect itemsize");
405  }
406  if (buffer.nitem != 1) {
407  throw std::runtime_error("nitem must be 1 for a scalar");
408  }
409  T* t = static_cast<T*>(buffer.buf);
410  *scalar = static_cast<TypeInfo>(*t);
411 }
412 
413 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
414  using T = std::underlying_type<TaskType>::type;
415  if (sizeof(T) != buffer.itemsize) {
416  throw std::runtime_error("Incorrect itemsize");
417  }
418  if (buffer.nitem != 1) {
419  throw std::runtime_error("nitem must be 1 for a scalar");
420  }
421  T* t = static_cast<T*>(buffer.buf);
422  *scalar = static_cast<TaskType>(*t);
423 }
424 
425 template <typename T>
426 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
427  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
428  if (sizeof(T) != buffer.itemsize) {
429  throw std::runtime_error("Incorrect itemsize");
430  }
431  if (buffer.nitem != 1) {
432  throw std::runtime_error("nitem must be 1 for a scalar");
433  }
434  T* t = static_cast<T*>(buffer.buf);
435  *scalar = *t;
436 }
437 
/*!
 * \brief Read the raw bytes of one standard-layout value from a file stream
 * \param scalar destination
 * \param fp stream to read from
 * \throws std::runtime_error if the value could not be read in full
 */
template <typename T>
inline void ReadScalarFromFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t items_read = std::fread(scalar, sizeof(T), 1, fp);
  if (items_read < 1) {
    throw std::runtime_error("Could not read a scalar");
  }
}
445 
/*!
 * \brief Write the raw bytes of one standard-layout value to a file stream
 * \param scalar value to write
 * \param fp stream to write to
 * \throws std::runtime_error if the value could not be written in full
 */
template <typename T>
inline void WriteScalarToFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t items_written = std::fwrite(scalar, sizeof(T), 1, fp);
  if (items_written < 1) {
    throw std::runtime_error("Could not write a scalar");
  }
}
453 
454 template <typename T>
455 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
456  uint64_t nelem;
457  if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
458  throw std::runtime_error("Could not read the number of elements");
459  }
460  vec->Clear();
461  vec->Resize(nelem);
462  const auto nelem_size_t = static_cast<std::size_t>(nelem);
463  if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
464  throw std::runtime_error("Could not read an array");
465  }
466 }
467 
468 template <typename T>
469 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
470  static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
471  const auto nelem = static_cast<uint64_t>(vec->Size());
472  if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
473  throw std::runtime_error("Could not write the number of elements");
474  }
475  const auto nelem_size_t = vec->Size();
476  if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
477  throw std::runtime_error("Could not write an array");
478  }
479 }
480 
481 template <typename ThresholdType, typename LeafOutputType>
482 inline Tree<ThresholdType, LeafOutputType>
483 Tree<ThresholdType, LeafOutputType>::Clone() const {
484  Tree<ThresholdType, LeafOutputType> tree;
485  tree.num_nodes = num_nodes;
486  tree.nodes_ = nodes_.Clone();
487  tree.leaf_vector_ = leaf_vector_.Clone();
488  tree.leaf_vector_offset_ = leaf_vector_offset_.Clone();
489  tree.matching_categories_ = matching_categories_.Clone();
490  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
491  return tree;
492 }
493 
494 template <typename ThresholdType, typename LeafOutputType>
495 inline const char*
496 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
497  if (std::is_same<ThresholdType, float>::value) {
498  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
499  } else {
500  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
501  }
502 }
503 
/*! \brief Number of buffer frames per tree: one scalar (node count) plus five arrays */
constexpr std::size_t kNumFramePerTree{6};
505 
506 template <typename ThresholdType, typename LeafOutputType>
507 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
508 inline void
509 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
510  ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
511  CompositeArrayHandler composite_array_handler) {
512  scalar_handler(&num_nodes);
513  composite_array_handler(&nodes_, GetFormatStringForNode());
514  primitive_array_handler(&leaf_vector_);
515  primitive_array_handler(&leaf_vector_offset_);
516  primitive_array_handler(&matching_categories_);
517  primitive_array_handler(&matching_categories_offset_);
518 }
519 
520 template <typename ThresholdType, typename LeafOutputType>
521 template <typename ScalarHandler, typename ArrayHandler>
522 inline void
523 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
524  ScalarHandler scalar_handler, ArrayHandler array_handler) {
525  scalar_handler(&num_nodes);
526  array_handler(&nodes_);
527  if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
528  throw std::runtime_error("Could not load the correct number of nodes");
529  }
530  array_handler(&leaf_vector_);
531  array_handler(&leaf_vector_offset_);
532  array_handler(&matching_categories_);
533  array_handler(&matching_categories_offset_);
534 }
535 
536 template <typename ThresholdType, typename LeafOutputType>
537 inline void
538 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
539  auto scalar_handler = [dest](auto* field) {
540  dest->push_back(GetPyBufferFromScalar(field));
541  };
542  auto primitive_array_handler = [dest](auto* field) {
543  dest->push_back(GetPyBufferFromArray(field));
544  };
545  auto composite_array_handler = [dest](auto* field, const char* format) {
546  dest->push_back(GetPyBufferFromArray(field, format));
547  };
548  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
549 }
550 
551 template <typename ThresholdType, typename LeafOutputType>
552 inline void
553 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
554  auto scalar_handler = [dest_fp](auto* field) {
555  WriteScalarToFile(field, dest_fp);
556  };
557  auto primitive_array_handler = [dest_fp](auto* field) {
558  WriteArrayToFile(field, dest_fp);
559  };
560  auto composite_array_handler = [dest_fp](auto* field, const char* format) {
561  WriteArrayToFile(field, dest_fp);
562  };
563  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
564 }
565 
566 template <typename ThresholdType, typename LeafOutputType>
567 inline void
568 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
569  std::vector<PyBufferFrame>::iterator end) {
570  if (std::distance(begin, end) != kNumFramePerTree) {
571  throw std::runtime_error("Wrong number of frames specified");
572  }
573  auto scalar_handler = [&begin](auto* field) {
574  InitScalarFromPyBuffer(field, *begin++);
575  };
576  auto array_handler = [&begin](auto* field) {
577  InitArrayFromPyBuffer(field, *begin++);
578  };
579  DeserializeTemplate(scalar_handler, array_handler);
580 }
581 
582 template <typename ThresholdType, typename LeafOutputType>
583 inline void
584 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
585  auto scalar_handler = [src_fp](auto* field) {
586  ReadScalarFromFile(field, src_fp);
587  };
588  auto array_handler = [src_fp](auto* field) {
589  ReadArrayFromFile(field, src_fp);
590  };
591  DeserializeTemplate(scalar_handler, array_handler);
592 }
593 
594 template <typename ThresholdType, typename LeafOutputType>
596  std::memset(this, 0, sizeof(Node));
597  cleft_ = cright_ = -1;
598  sindex_ = 0;
599  info_.leaf_value = static_cast<LeafOutputType>(0);
600  info_.threshold = static_cast<ThresholdType>(0);
601  data_count_ = 0;
602  sum_hess_ = gain_ = 0.0;
603  data_count_present_ = sum_hess_present_ = gain_present_ = false;
604  categories_list_right_child_ = false;
605  split_type_ = SplitFeatureType::kNone;
606  cmp_ = Operator::kNone;
607 }
608 
609 template <typename ThresholdType, typename LeafOutputType>
610 inline int
612  int nd = num_nodes++;
613  if (nodes_.Size() != static_cast<std::size_t>(nd)) {
614  throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
615  }
616  for (int nid = nd; nid < num_nodes; ++nid) {
617  leaf_vector_offset_.PushBack(leaf_vector_offset_.Back());
618  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
619  nodes_.Resize(nodes_.Size() + 1);
620  nodes_.Back().Init();
621  }
622  return nd;
623 }
624 
625 template <typename ThresholdType, typename LeafOutputType>
626 inline void
628  num_nodes = 1;
629  leaf_vector_.Clear();
630  leaf_vector_offset_.Resize(2, 0);
631  matching_categories_.Clear();
632  matching_categories_offset_.Resize(2, 0);
633  nodes_.Resize(1);
634  nodes_.at(0).Init();
635  SetLeaf(0, static_cast<LeafOutputType>(0));
636 }
637 
638 template <typename ThresholdType, typename LeafOutputType>
639 inline void
641  const int cleft = this->AllocNode();
642  const int cright = this->AllocNode();
643  nodes_.at(nid).cleft_ = cleft;
644  nodes_.at(nid).cright_ = cright;
645 }
646 
647 template <typename ThresholdType, typename LeafOutputType>
648 inline std::vector<unsigned>
650  std::unordered_map<unsigned, bool> tmp;
651  for (int nid = 0; nid < num_nodes; ++nid) {
652  const SplitFeatureType type = SplitType(nid);
653  if (type != SplitFeatureType::kNone) {
654  const bool flag = (type == SplitFeatureType::kCategorical);
655  const uint32_t split_index = SplitIndex(nid);
656  if (tmp.count(split_index) == 0) {
657  tmp[split_index] = flag;
658  } else {
659  if (tmp[split_index] != flag) {
660  throw std::runtime_error("Feature " + std::to_string(split_index) +
661  " cannot be simultaneously be categorical and numerical.");
662  }
663  }
664  }
665  }
666  std::vector<unsigned> result;
667  for (const auto& kv : tmp) {
668  if (kv.second) {
669  result.push_back(kv.first);
670  }
671  }
672  std::sort(result.begin(), result.end());
673  return result;
674 }
675 
676 template <typename ThresholdType, typename LeafOutputType>
677 inline void
679  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
680  Node& node = nodes_.at(nid);
681  if (split_index >= ((1U << 31U) - 1)) {
682  throw std::runtime_error("split_index too big");
683  }
684  if (default_left) split_index |= (1U << 31U);
685  node.sindex_ = split_index;
686  (node.info_).threshold = threshold;
687  node.cmp_ = cmp;
688  node.split_type_ = SplitFeatureType::kNumerical;
689  node.categories_list_right_child_ = false;
690 }
691 
692 template <typename ThresholdType, typename LeafOutputType>
693 inline void
695  int nid, unsigned split_index, bool default_left,
696  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
697  if (split_index >= ((1U << 31U) - 1)) {
698  throw std::runtime_error("split_index too big");
699  }
700 
701  const std::size_t end_oft = matching_categories_offset_.Back();
702  const std::size_t new_end_oft = end_oft + categories_list.size();
703  if (end_oft != matching_categories_.Size()) {
704  throw std::runtime_error("Invariant violated");
705  }
706  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
707  [end_oft](std::size_t x) { return (x == end_oft); })) {
708  throw std::runtime_error("Invariant violated");
709  }
710  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
711  matching_categories_.Extend(categories_list);
712  if (new_end_oft != matching_categories_.Size()) {
713  throw std::runtime_error("Invariant violated");
714  }
715  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
716  [new_end_oft](std::size_t& x) { x = new_end_oft; });
717  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
718 
719  Node& node = nodes_.at(nid);
720  if (default_left) split_index |= (1U << 31U);
721  node.sindex_ = split_index;
722  node.split_type_ = SplitFeatureType::kCategorical;
723  node.categories_list_right_child_ = categories_list_right_child;
724 }
725 
726 template <typename ThresholdType, typename LeafOutputType>
727 inline void
728 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
729  Node& node = nodes_.at(nid);
730  (node.info_).leaf_value = value;
731  node.cleft_ = -1;
732  node.cright_ = -1;
733  node.split_type_ = SplitFeatureType::kNone;
734 }
735 
736 template <typename ThresholdType, typename LeafOutputType>
737 inline void
739  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
740  const std::size_t end_oft = leaf_vector_offset_.Back();
741  const std::size_t new_end_oft = end_oft + node_leaf_vector.size();
742  if (end_oft != leaf_vector_.Size()) {
743  throw std::runtime_error("Invariant violated");
744  }
745  if (!std::all_of(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
746  [end_oft](std::size_t x) { return (x == end_oft); })) {
747  throw std::runtime_error("Invariant violated");
748  }
749  // Hopefully we won't have to move any element as we add leaf vector elements for node nid
750  leaf_vector_.Extend(node_leaf_vector);
751  if (new_end_oft != leaf_vector_.Size()) {
752  throw std::runtime_error("Invariant violated");
753  }
754  std::for_each(&leaf_vector_offset_.at(nid + 1), leaf_vector_offset_.End(),
755  [new_end_oft](std::size_t& x) { x = new_end_oft; });
756 
757  Node& node = nodes_.at(nid);
758  node.cleft_ = -1;
759  node.cright_ = -1;
760  node.split_type_ = SplitFeatureType::kNone;
761 }
762 
763 template <typename ThresholdType, typename LeafOutputType>
764 inline std::unique_ptr<Model>
765 Model::Create() {
766  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
767  model->threshold_type_ = TypeToInfo<ThresholdType>();
768  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
769  return model;
770 }
771 
772 template <typename ThresholdType, typename LeafOutputType>
774  public:
775  inline static std::unique_ptr<Model> Dispatch() {
776  return Model::Create<ThresholdType, LeafOutputType>();
777  }
778 };
779 
780 inline std::unique_ptr<Model>
781 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
782  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
783 }
784 
785 template <typename ThresholdType, typename LeafOutputType>
787  public:
788  template <typename Func>
789  inline static auto Dispatch(Model* model, Func func) {
790  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
791  }
792 
793  template <typename Func>
794  inline static auto Dispatch(const Model* model, Func func) {
795  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
796  }
797 };
798 
799 template <typename Func>
800 inline auto
801 Model::Dispatch(Func func) {
802  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
803 }
804 
805 template <typename Func>
806 inline auto
807 Model::Dispatch(Func func) const {
808  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
809 }
810 
811 template <typename HeaderPrimitiveFieldHandlerFunc>
812 inline void
813 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
814  header_primitive_field_handler(&major_ver_);
815  header_primitive_field_handler(&minor_ver_);
816  header_primitive_field_handler(&patch_ver_);
817  header_primitive_field_handler(&threshold_type_);
818  header_primitive_field_handler(&leaf_output_type_);
819 }
820 
821 template <typename HeaderPrimitiveFieldHandlerFunc>
822 inline void
823 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
824  TypeInfo& threshold_type, TypeInfo& leaf_output_type) {
825  int major_ver, minor_ver, patch_ver;
826  header_primitive_field_handler(&major_ver);
827  header_primitive_field_handler(&minor_ver);
828  header_primitive_field_handler(&patch_ver);
829  if (major_ver != TREELITE_VER_MAJOR || minor_ver != TREELITE_VER_MINOR) {
830  throw std::runtime_error("Cannot deserialize model from a different version of Treelite");
831  }
832  header_primitive_field_handler(&threshold_type);
833  header_primitive_field_handler(&leaf_output_type);
834 }
835 
836 template <typename ThresholdType, typename LeafOutputType>
837 template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
838  typename TreeHandlerFunc>
839 inline void
841  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
842  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
843  TreeHandlerFunc tree_handler) {
844  /* Header */
845  header_primitive_field_handler(&num_feature);
846  header_primitive_field_handler(&task_type);
847  header_primitive_field_handler(&average_tree_output);
848  header_composite_field_handler(&task_param, "T{=B=?xx=I=I}");
849  header_composite_field_handler(
850  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}");
851 
852  /* Body */
853  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
854  tree_handler(tree);
855  }
856 }
857 
858 template <typename ThresholdType, typename LeafOutputType>
859 template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
860 inline void
862  std::size_t num_tree,
863  HeaderFieldHandlerFunc header_field_handler,
864  TreeHandlerFunc tree_handler) {
865  /* Header */
866  header_field_handler(&num_feature);
867  header_field_handler(&task_type);
868  header_field_handler(&average_tree_output);
869  header_field_handler(&task_param);
870  header_field_handler(&param);
871  /* Body */
872  trees.clear();
873  for (std::size_t i = 0; i < num_tree; ++i) {
874  trees.emplace_back();
875  tree_handler(trees.back());
876  }
877 }
878 
879 template <typename ThresholdType, typename LeafOutputType>
880 inline void
881 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
882  auto header_primitive_field_handler = [dest](auto* field) {
883  dest->push_back(GetPyBufferFromScalar(field));
884  };
885  auto header_composite_field_handler = [dest](auto* field, const char* format) {
886  dest->push_back(GetPyBufferFromScalar(field, format));
887  };
888  auto tree_handler = [dest](Tree<ThresholdType, LeafOutputType>& tree) {
889  tree.GetPyBuffer(dest);
890  };
891  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
892 }
893 
894 template <typename ThresholdType, typename LeafOutputType>
895 inline void
897  const auto num_tree = static_cast<uint64_t>(this->trees.size());
898  WriteScalarToFile(&num_tree, dest_fp);
899  auto header_primitive_field_handler = [dest_fp](auto* field) {
900  WriteScalarToFile(field, dest_fp);
901  };
902  auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
903  WriteScalarToFile(field, dest_fp);
904  };
905  auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
906  tree.SerializeToFile(dest_fp);
907  };
908  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
909 }
910 
911 template <typename ThresholdType, typename LeafOutputType>
912 inline void
914  std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
915  const std::size_t num_frame = std::distance(begin, end);
916  constexpr std::size_t kNumFrameInHeader = 5;
917  if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
918  throw std::runtime_error("Wrong number of frames");
919  }
920  const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
921 
922  auto header_field_handler = [&begin](auto* field) {
923  InitScalarFromPyBuffer(field, *begin++);
924  };
925 
926  auto tree_handler = [&begin](Tree<ThresholdType, LeafOutputType>& tree) {
927  // Read the frames in the range [begin, begin + kNumFramePerTree) into the tree
928  tree.InitFromPyBuffer(begin, begin + kNumFramePerTree);
929  begin += kNumFramePerTree;
930  // Advance the iterator so that the next tree reads the next kNumFramePerTree frames
931  };
932 
933  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
934 }
935 
936 template <typename ThresholdType, typename LeafOutputType>
937 inline void
939  uint64_t num_tree;
940  ReadScalarFromFile(&num_tree, src_fp);
941 
942  auto header_field_handler = [src_fp](auto* field) {
943  ReadScalarFromFile(field, src_fp);
944  };
945 
946  auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
947  tree.DeserializeFromFile(src_fp);
948  };
949 
950  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
951 }
952 
953 inline void InitParamAndCheck(ModelParam* param,
954  const std::vector<std::pair<std::string, std::string>>& cfg) {
955  auto unknown = param->InitAllowUnknown(cfg);
956  if (!unknown.empty()) {
957  std::ostringstream oss;
958  for (const auto& kv : unknown) {
959  oss << kv.first << ", ";
960  }
961  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
962  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
963  }
964 }
965 
966 } // namespace treelite
967 #endif // TREELITE_TREE_IMPL_H_
SplitFeatureType
feature split type
Definition: base.h:22
void Init()
initialize the model with a single root node
Definition: tree_impl.h:627
TaskType
Enum type representing the task type.
Definition: tree.h:100
tree node
Definition: tree.h:200
in-memory representation of a decision tree
Definition: tree.h:197
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
thin wrapper for tree ensemble model
Definition: tree.h:632
Operator
comparison operators
Definition: base.h:26