Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13 #include <map>
14 #include <string>
15 #include <utility>
16 #include <vector>
17 #include <unordered_map>
18 #include <sstream>
19 #include <iomanip>
20 #include <typeinfo>
21 #include <stdexcept>
22 #include <iostream>
23 #include <cstddef>
24 
25 namespace {
26 
// Convert a value to its textual form. The generic version defers to std::to_string.
template <typename T>
inline std::string GetString(T x) {
  std::string text = std::to_string(x);
  return text;
}

// float: stream the value with max_digits10 significant digits instead of
// std::to_string's fixed formatting.
template <>
inline std::string GetString<float>(float x) {
  constexpr int kDigits = std::numeric_limits<float>::max_digits10;
  std::ostringstream stream;
  stream << std::setprecision(kDigits);
  stream << x;
  return stream.str();
}

// double: same approach as the float specialization, using double's max_digits10.
template <>
inline std::string GetString<double>(double x) {
  constexpr int kDigits = std::numeric_limits<double>::max_digits10;
  std::ostringstream stream;
  stream << std::setprecision(kDigits);
  stream << x;
  return stream.str();
}
45 
46 } // anonymous namespace
47 
48 namespace treelite {
49 
50 template <typename T>
51 ContiguousArray<T>::ContiguousArray()
52  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
53 
54 template <typename T>
55 ContiguousArray<T>::~ContiguousArray() {
56  if (buffer_ && owned_buffer_) {
57  std::free(buffer_);
58  }
59 }
60 
61 template <typename T>
62 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
63  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
64  owned_buffer_(other.owned_buffer_) {
65  other.buffer_ = nullptr;
66  other.size_ = other.capacity_ = 0;
67 }
68 
69 template <typename T>
70 ContiguousArray<T>&
71 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
72  if (buffer_ && owned_buffer_) {
73  std::free(buffer_);
74  }
75  buffer_ = other.buffer_;
76  size_ = other.size_;
77  capacity_ = other.capacity_;
78  owned_buffer_ = other.owned_buffer_;
79  other.buffer_ = nullptr;
80  other.size_ = other.capacity_ = 0;
81  return *this;
82 }
83 
84 template <typename T>
85 inline ContiguousArray<T>
86 ContiguousArray<T>::Clone() const {
87  ContiguousArray clone;
88  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
89  if (!clone.buffer_) {
90  throw std::runtime_error("Could not allocate memory for the clone");
91  }
92  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
93  clone.size_ = size_;
94  clone.capacity_ = capacity_;
95  clone.owned_buffer_ = true;
96  return clone;
97 }
98 
99 template <typename T>
100 inline void
101 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, std::size_t size) {
102  if (buffer_ && owned_buffer_) {
103  std::free(buffer_);
104  }
105  buffer_ = static_cast<T*>(prealloc_buf);
106  size_ = size;
107  capacity_ = size;
108  owned_buffer_ = false;
109 }
110 
111 template <typename T>
112 inline T*
113 ContiguousArray<T>::Data() {
114  return buffer_;
115 }
116 
117 template <typename T>
118 inline const T*
119 ContiguousArray<T>::Data() const {
120  return buffer_;
121 }
122 
123 template <typename T>
124 inline T*
125 ContiguousArray<T>::End() {
126  return &buffer_[Size()];
127 }
128 
129 template <typename T>
130 inline const T*
131 ContiguousArray<T>::End() const {
132  return &buffer_[Size()];
133 }
134 
135 template <typename T>
136 inline T&
137 ContiguousArray<T>::Back() {
138  return buffer_[Size() - 1];
139 }
140 
141 template <typename T>
142 inline const T&
143 ContiguousArray<T>::Back() const {
144  return buffer_[Size() - 1];
145 }
146 
147 template <typename T>
148 inline std::size_t
149 ContiguousArray<T>::Size() const {
150  return size_;
151 }
152 
153 template <typename T>
154 inline bool
155 ContiguousArray<T>::Empty() const {
156  return (Size() == 0);
157 }
158 
159 template <typename T>
160 inline void
161 ContiguousArray<T>::Reserve(std::size_t newsize) {
162  if (!owned_buffer_) {
163  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
164  }
165  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
166  if (!newbuf) {
167  throw std::runtime_error("Could not expand buffer");
168  }
169  buffer_ = newbuf;
170  capacity_ = newsize;
171 }
172 
173 template <typename T>
174 inline void
175 ContiguousArray<T>::Resize(std::size_t newsize) {
176  if (!owned_buffer_) {
177  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
178  }
179  if (newsize > capacity_) {
180  std::size_t newcapacity = capacity_;
181  if (newcapacity == 0) {
182  newcapacity = 1;
183  }
184  while (newcapacity <= newsize) {
185  newcapacity *= 2;
186  }
187  Reserve(newcapacity);
188  }
189  size_ = newsize;
190 }
191 
192 template <typename T>
193 inline void
194 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
195  if (!owned_buffer_) {
196  throw std::runtime_error("Cannot resize when using a foreign buffer; clone first");
197  }
198  std::size_t oldsize = Size();
199  Resize(newsize);
200  for (std::size_t i = oldsize; i < newsize; ++i) {
201  buffer_[i] = t;
202  }
203 }
204 
205 template <typename T>
206 inline void
207 ContiguousArray<T>::Clear() {
208  if (!owned_buffer_) {
209  throw std::runtime_error("Cannot clear when using a foreign buffer; clone first");
210  }
211  Resize(0);
212 }
213 
214 template <typename T>
215 inline void
216 ContiguousArray<T>::PushBack(T t) {
217  if (!owned_buffer_) {
218  throw std::runtime_error("Cannot add element when using a foreign buffer; clone first");
219  }
220  if (size_ == capacity_) {
221  Reserve(capacity_ * 2);
222  }
223  buffer_[size_++] = t;
224 }
225 
226 template <typename T>
227 inline void
228 ContiguousArray<T>::Extend(const std::vector<T>& other) {
229  if (!owned_buffer_) {
230  throw std::runtime_error("Cannot add elements when using a foreign buffer; clone first");
231  }
232  std::size_t newsize = size_ + other.size();
233  if (newsize > capacity_) {
234  std::size_t newcapacity = capacity_;
235  if (newcapacity == 0) {
236  newcapacity = 1;
237  }
238  while (newcapacity <= newsize) {
239  newcapacity *= 2;
240  }
241  Reserve(newcapacity);
242  }
243  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
244  size_ = newsize;
245 }
246 
247 template <typename T>
248 inline T&
249 ContiguousArray<T>::operator[](std::size_t idx) {
250  return buffer_[idx];
251 }
252 
253 template <typename T>
254 inline const T&
255 ContiguousArray<T>::operator[](std::size_t idx) const {
256  return buffer_[idx];
257 }
258 
259 template <typename T>
260 inline T&
261 ContiguousArray<T>::at(std::size_t idx) {
262  if (idx >= Size()) {
263  throw std::runtime_error("nid out of range");
264  }
265  return buffer_[idx];
266 }
267 
268 template <typename T>
269 inline const T&
270 ContiguousArray<T>::at(std::size_t idx) const {
271  if (idx >= Size()) {
272  throw std::runtime_error("nid out of range");
273  }
274  return buffer_[idx];
275 }
276 
277 template <typename T>
278 inline T&
279 ContiguousArray<T>::at(int idx) {
280  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
281  throw std::runtime_error("nid out of range");
282  }
283  return buffer_[static_cast<std::size_t>(idx)];
284 }
285 
286 template <typename T>
287 inline const T&
288 ContiguousArray<T>::at(int idx) const {
289  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
290  throw std::runtime_error("nid out of range");
291  }
292  return buffer_[static_cast<std::size_t>(idx)];
293 }
294 
295 template<typename Container>
296 inline std::vector<std::pair<std::string, std::string> >
297 ModelParam::InitAllowUnknown(const Container& kwargs) {
298  std::vector<std::pair<std::string, std::string>> unknowns;
299  for (const auto& e : kwargs) {
300  if (e.first == "pred_transform") {
301  std::strncpy(this->pred_transform, e.second.c_str(),
302  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
303  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
304  } else if (e.first == "sigmoid_alpha") {
305  this->sigmoid_alpha = std::stof(e.second, nullptr);
306  } else if (e.first == "global_bias") {
307  this->global_bias = std::stof(e.second, nullptr);
308  }
309  }
310  return unknowns;
311 }
312 
313 inline std::map<std::string, std::string>
314 ModelParam::__DICT__() const {
315  std::map<std::string, std::string> ret;
316  ret.emplace("pred_transform", std::string(this->pred_transform));
317  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
318  ret.emplace("global_bias", GetString(this->global_bias));
319  return ret;
320 }
321 
322 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
323  std::size_t itemsize, std::size_t nitem) {
324  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
325 }
326 
// Infer format string from data type.
// The choice is driven by sizeof(T): 1- and 2-byte types map to b/B and h/H by
// signedness; 4- and 8-byte types map to l/L and q/Q when integral, f and d when
// floating-point. Throws std::runtime_error for any other type.
template <typename T>
inline const char* InferFormatString() {
  constexpr std::size_t kWidth = sizeof(T);
  constexpr bool kUnsigned = std::is_unsigned<T>::value;
  constexpr bool kInteger = std::is_integral<T>::value;
  constexpr bool kFloating = std::is_floating_point<T>::value;

  if (kWidth == 1) {
    return kUnsigned ? "=B" : "=b";
  }
  if (kWidth == 2) {
    return kUnsigned ? "=H" : "=h";
  }
  if (kWidth == 4 || kWidth == 8) {
    const bool wide = (kWidth == 8);
    if (kInteger) {
      if (wide) {
        return kUnsigned ? "=Q" : "=q";
      }
      return kUnsigned ? "=L" : "=l";
    }
    if (!kFloating) {
      throw std::runtime_error("Could not infer format string");
    }
    return wide ? "=d" : "=f";
  }
  throw std::runtime_error("Unrecognized type");
}
358 
359 template <typename T>
360 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
361  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
362 }
363 
364 template <typename T>
365 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
366  static_assert(std::is_arithmetic<T>::value,
367  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
368  return GetPyBufferFromArray(vec, InferFormatString<T>());
369 }
370 
371 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, std::size_t itemsize) {
372  return GetPyBufferFromArray(data, format, itemsize, 1);
373 }
374 
375 template <typename T>
376 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
377  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
378  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
379 }
380 
381 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
382  using T = std::underlying_type<TypeInfo>::type;
383  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
384 }
385 
386 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
387  using T = std::underlying_type<TaskType>::type;
388  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
389 }
390 
391 template <typename T>
392 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
393  static_assert(std::is_arithmetic<T>::value,
394  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
395  "specify format string manually");
396  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
397 }
398 
399 template <typename T>
400 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
401  if (sizeof(T) != frame.itemsize) {
402  throw std::runtime_error("Incorrect itemsize");
403  }
404  vec->UseForeignBuffer(frame.buf, frame.nitem);
405 }
406 
407 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
408  using T = std::underlying_type<TypeInfo>::type;
409  if (sizeof(T) != buffer.itemsize) {
410  throw std::runtime_error("Incorrect itemsize");
411  }
412  if (buffer.nitem != 1) {
413  throw std::runtime_error("nitem must be 1 for a scalar");
414  }
415  T* t = static_cast<T*>(buffer.buf);
416  *scalar = static_cast<TypeInfo>(*t);
417 }
418 
419 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
420  using T = std::underlying_type<TaskType>::type;
421  if (sizeof(T) != buffer.itemsize) {
422  throw std::runtime_error("Incorrect itemsize");
423  }
424  if (buffer.nitem != 1) {
425  throw std::runtime_error("nitem must be 1 for a scalar");
426  }
427  T* t = static_cast<T*>(buffer.buf);
428  *scalar = static_cast<TaskType>(*t);
429 }
430 
431 template <typename T>
432 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
433  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
434  if (sizeof(T) != buffer.itemsize) {
435  throw std::runtime_error("Incorrect itemsize");
436  }
437  if (buffer.nitem != 1) {
438  throw std::runtime_error("nitem must be 1 for a scalar");
439  }
440  T* t = static_cast<T*>(buffer.buf);
441  *scalar = *t;
442 }
443 
// Read sizeof(T) raw bytes from `fp` into *scalar.
// Throws std::runtime_error when the item cannot be read in full.
template <typename T>
inline void ReadScalarFromFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t items_read = std::fread(scalar, sizeof(T), 1, fp);
  if (items_read != 1) {
    throw std::runtime_error("Could not read a scalar");
  }
}
451 
// Write the sizeof(T) raw bytes of *scalar to `fp`.
// Throws std::runtime_error when the item cannot be written.
template <typename T>
inline void WriteScalarToFile(T* scalar, FILE* fp) {
  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
  const std::size_t items_written = std::fwrite(scalar, sizeof(T), 1, fp);
  if (items_written != 1) {
    throw std::runtime_error("Could not write a scalar");
  }
}
459 
460 template <typename T>
461 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
462  uint64_t nelem;
463  if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
464  throw std::runtime_error("Could not read the number of elements");
465  }
466  vec->Clear();
467  vec->Resize(nelem);
468  const auto nelem_size_t = static_cast<std::size_t>(nelem);
469  if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
470  throw std::runtime_error("Could not read an array");
471  }
472 }
473 
474 template <typename T>
475 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
476  static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
477  const auto nelem = static_cast<uint64_t>(vec->Size());
478  if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
479  throw std::runtime_error("Could not write the number of elements");
480  }
481  const auto nelem_size_t = vec->Size();
482  if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
483  throw std::runtime_error("Could not write an array");
484  }
485 }
486 
487 template <typename ThresholdType, typename LeafOutputType>
488 inline Tree<ThresholdType, LeafOutputType>
489 Tree<ThresholdType, LeafOutputType>::Clone() const {
490  Tree<ThresholdType, LeafOutputType> tree;
491  tree.num_nodes = num_nodes;
492  tree.nodes_ = nodes_.Clone();
493  tree.leaf_vector_ = leaf_vector_.Clone();
494  tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
495  tree.leaf_vector_end_ = leaf_vector_end_.Clone();
496  tree.matching_categories_ = matching_categories_.Clone();
497  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
498  return tree;
499 }
500 
501 template <typename ThresholdType, typename LeafOutputType>
502 inline const char*
503 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
504  if (std::is_same<ThresholdType, float>::value) {
505  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
506  } else {
507  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
508  }
509 }
510 
// Number of buffer frames that make up one tree. This matches the seven handler
// invocations in Tree::SerializeTemplate (one scalar followed by six arrays);
// keep the two in sync.
511 constexpr std::size_t kNumFramePerTree = 7;
512 
513 template <typename ThresholdType, typename LeafOutputType>
514 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
515 inline void
516 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
517  ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
518  CompositeArrayHandler composite_array_handler) {
519  scalar_handler(&num_nodes);
520  composite_array_handler(&nodes_, GetFormatStringForNode());
521  primitive_array_handler(&leaf_vector_);
522  primitive_array_handler(&leaf_vector_begin_);
523  primitive_array_handler(&leaf_vector_end_);
524  primitive_array_handler(&matching_categories_);
525  primitive_array_handler(&matching_categories_offset_);
526 }
527 
528 template <typename ThresholdType, typename LeafOutputType>
529 template <typename ScalarHandler, typename ArrayHandler>
530 inline void
531 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
532  ScalarHandler scalar_handler, ArrayHandler array_handler) {
533  scalar_handler(&num_nodes);
534  array_handler(&nodes_);
535  if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
536  throw std::runtime_error("Could not load the correct number of nodes");
537  }
538  array_handler(&leaf_vector_);
539  array_handler(&leaf_vector_begin_);
540  array_handler(&leaf_vector_end_);
541  array_handler(&matching_categories_);
542  array_handler(&matching_categories_offset_);
543 }
544 
545 template <typename ThresholdType, typename LeafOutputType>
546 inline void
547 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
548  auto scalar_handler = [dest](auto* field) {
549  dest->push_back(GetPyBufferFromScalar(field));
550  };
551  auto primitive_array_handler = [dest](auto* field) {
552  dest->push_back(GetPyBufferFromArray(field));
553  };
554  auto composite_array_handler = [dest](auto* field, const char* format) {
555  dest->push_back(GetPyBufferFromArray(field, format));
556  };
557  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
558 }
559 
560 template <typename ThresholdType, typename LeafOutputType>
561 inline void
562 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
563  auto scalar_handler = [dest_fp](auto* field) {
564  WriteScalarToFile(field, dest_fp);
565  };
566  auto primitive_array_handler = [dest_fp](auto* field) {
567  WriteArrayToFile(field, dest_fp);
568  };
569  auto composite_array_handler = [dest_fp](auto* field, const char* format) {
570  WriteArrayToFile(field, dest_fp);
571  };
572  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
573 }
574 
575 template <typename ThresholdType, typename LeafOutputType>
576 inline void
577 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
578  std::vector<PyBufferFrame>::iterator end) {
579  if (std::distance(begin, end) != kNumFramePerTree) {
580  throw std::runtime_error("Wrong number of frames specified");
581  }
582  auto scalar_handler = [&begin](auto* field) {
583  InitScalarFromPyBuffer(field, *begin++);
584  };
585  auto array_handler = [&begin](auto* field) {
586  InitArrayFromPyBuffer(field, *begin++);
587  };
588  DeserializeTemplate(scalar_handler, array_handler);
589 }
590 
591 template <typename ThresholdType, typename LeafOutputType>
592 inline void
593 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
594  auto scalar_handler = [src_fp](auto* field) {
595  ReadScalarFromFile(field, src_fp);
596  };
597  auto array_handler = [src_fp](auto* field) {
598  ReadArrayFromFile(field, src_fp);
599  };
600  DeserializeTemplate(scalar_handler, array_handler);
601 }
602 
603 template <typename ThresholdType, typename LeafOutputType>
605  std::memset(this, 0, sizeof(Node));
606  cleft_ = cright_ = -1;
607  sindex_ = 0;
608  info_.leaf_value = static_cast<LeafOutputType>(0);
609  info_.threshold = static_cast<ThresholdType>(0);
610  data_count_ = 0;
611  sum_hess_ = gain_ = 0.0;
612  data_count_present_ = sum_hess_present_ = gain_present_ = false;
613  categories_list_right_child_ = false;
614  split_type_ = SplitFeatureType::kNone;
615  cmp_ = Operator::kNone;
616 }
617 
618 template <typename ThresholdType, typename LeafOutputType>
619 inline int
621  int nd = num_nodes++;
622  if (nodes_.Size() != static_cast<std::size_t>(nd)) {
623  throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes");
624  }
625  for (int nid = nd; nid < num_nodes; ++nid) {
626  leaf_vector_begin_.PushBack(0);
627  leaf_vector_end_.PushBack(0);
628  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
629  nodes_.Resize(nodes_.Size() + 1);
630  nodes_.Back().Init();
631  }
632  return nd;
633 }
634 
635 template <typename ThresholdType, typename LeafOutputType>
636 inline void
638  num_nodes = 1;
639  leaf_vector_.Clear();
640  leaf_vector_begin_.Resize(1, {});
641  leaf_vector_end_.Resize(1, {});
642  matching_categories_.Clear();
643  matching_categories_offset_.Resize(2, 0);
644  nodes_.Resize(1);
645  nodes_.at(0).Init();
646  SetLeaf(0, static_cast<LeafOutputType>(0));
647 }
648 
649 template <typename ThresholdType, typename LeafOutputType>
650 inline void
652  const int cleft = this->AllocNode();
653  const int cright = this->AllocNode();
654  nodes_.at(nid).cleft_ = cleft;
655  nodes_.at(nid).cright_ = cright;
656 }
657 
658 template <typename ThresholdType, typename LeafOutputType>
659 inline std::vector<unsigned>
661  std::unordered_map<unsigned, bool> tmp;
662  for (int nid = 0; nid < num_nodes; ++nid) {
663  const SplitFeatureType type = SplitType(nid);
664  if (type != SplitFeatureType::kNone) {
665  const bool flag = (type == SplitFeatureType::kCategorical);
666  const uint32_t split_index = SplitIndex(nid);
667  if (tmp.count(split_index) == 0) {
668  tmp[split_index] = flag;
669  } else {
670  if (tmp[split_index] != flag) {
671  throw std::runtime_error("Feature " + std::to_string(split_index) +
672  " cannot be simultaneously be categorical and numerical.");
673  }
674  }
675  }
676  }
677  std::vector<unsigned> result;
678  for (const auto& kv : tmp) {
679  if (kv.second) {
680  result.push_back(kv.first);
681  }
682  }
683  std::sort(result.begin(), result.end());
684  return result;
685 }
686 
687 template <typename ThresholdType, typename LeafOutputType>
688 inline void
690  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
691  Node& node = nodes_.at(nid);
692  if (split_index >= ((1U << 31U) - 1)) {
693  throw std::runtime_error("split_index too big");
694  }
695  if (default_left) split_index |= (1U << 31U);
696  node.sindex_ = split_index;
697  (node.info_).threshold = threshold;
698  node.cmp_ = cmp;
699  node.split_type_ = SplitFeatureType::kNumerical;
700  node.categories_list_right_child_ = false;
701 }
702 
703 template <typename ThresholdType, typename LeafOutputType>
704 inline void
706  int nid, unsigned split_index, bool default_left,
707  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
708  if (split_index >= ((1U << 31U) - 1)) {
709  throw std::runtime_error("split_index too big");
710  }
711 
712  const std::size_t end_oft = matching_categories_offset_.Back();
713  const std::size_t new_end_oft = end_oft + categories_list.size();
714  if (end_oft != matching_categories_.Size()) {
715  throw std::runtime_error("Invariant violated");
716  }
717  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
718  [end_oft](std::size_t x) { return (x == end_oft); })) {
719  throw std::runtime_error("Invariant violated");
720  }
721  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
722  matching_categories_.Extend(categories_list);
723  if (new_end_oft != matching_categories_.Size()) {
724  throw std::runtime_error("Invariant violated");
725  }
726  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
727  [new_end_oft](std::size_t& x) { x = new_end_oft; });
728  if (!matching_categories_.Empty()) {
729  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
730  }
731 
732  Node& node = nodes_.at(nid);
733  if (default_left) split_index |= (1U << 31U);
734  node.sindex_ = split_index;
735  node.split_type_ = SplitFeatureType::kCategorical;
736  node.categories_list_right_child_ = categories_list_right_child;
737 }
738 
739 template <typename ThresholdType, typename LeafOutputType>
740 inline void
741 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
742  Node& node = nodes_.at(nid);
743  (node.info_).leaf_value = value;
744  node.cleft_ = -1;
745  node.cright_ = -1;
746  node.split_type_ = SplitFeatureType::kNone;
747 }
748 
749 template <typename ThresholdType, typename LeafOutputType>
750 inline void
752  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
753  std::size_t begin = leaf_vector_.Size();
754  std::size_t end = begin + node_leaf_vector.size();
755  leaf_vector_.Extend(node_leaf_vector);
756  leaf_vector_begin_[nid] = begin;
757  leaf_vector_end_[nid] = end;
758  Node &node = nodes_.at(nid);
759  node.cleft_ = -1;
760  node.cright_ = -1;
761  node.split_type_ = SplitFeatureType::kNone;
762 }
763 
764 template <typename ThresholdType, typename LeafOutputType>
765 inline std::unique_ptr<Model>
766 Model::Create() {
767  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
768  model->threshold_type_ = TypeToInfo<ThresholdType>();
769  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
770  return model;
771 }
772 
773 template <typename ThresholdType, typename LeafOutputType>
775  public:
776  inline static std::unique_ptr<Model> Dispatch() {
777  return Model::Create<ThresholdType, LeafOutputType>();
778  }
779 };
780 
781 inline std::unique_ptr<Model>
782 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
783  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
784 }
785 
786 template <typename ThresholdType, typename LeafOutputType>
788  public:
789  template <typename Func>
790  inline static auto Dispatch(Model* model, Func func) {
791  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
792  }
793 
794  template <typename Func>
795  inline static auto Dispatch(const Model* model, Func func) {
796  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
797  }
798 };
799 
800 template <typename Func>
801 inline auto
802 Model::Dispatch(Func func) {
803  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
804 }
805 
806 template <typename Func>
807 inline auto
808 Model::Dispatch(Func func) const {
809  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
810 }
811 
812 template <typename HeaderPrimitiveFieldHandlerFunc>
813 inline void
814 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
815  header_primitive_field_handler(&major_ver_);
816  header_primitive_field_handler(&minor_ver_);
817  header_primitive_field_handler(&patch_ver_);
818  header_primitive_field_handler(&threshold_type_);
819  header_primitive_field_handler(&leaf_output_type_);
820 }
821 
822 template <typename HeaderPrimitiveFieldHandlerFunc>
823 inline void
824 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
825  TypeInfo& threshold_type, TypeInfo& leaf_output_type) {
826  int major_ver, minor_ver, patch_ver;
827  header_primitive_field_handler(&major_ver);
828  header_primitive_field_handler(&minor_ver);
829  header_primitive_field_handler(&patch_ver);
830  if (major_ver != TREELITE_VER_MAJOR || minor_ver != TREELITE_VER_MINOR) {
831  throw std::runtime_error("Cannot deserialize model from a different version of Treelite");
832  }
833  header_primitive_field_handler(&threshold_type);
834  header_primitive_field_handler(&leaf_output_type);
835 }
836 
837 template <typename ThresholdType, typename LeafOutputType>
838 template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
839  typename TreeHandlerFunc>
840 inline void
842  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
843  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
844  TreeHandlerFunc tree_handler) {
845  /* Header */
846  header_primitive_field_handler(&num_feature);
847  header_primitive_field_handler(&task_type);
848  header_primitive_field_handler(&average_tree_output);
849  header_composite_field_handler(&task_param, "T{=B=?xx=I=I}");
850  header_composite_field_handler(
851  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}");
852 
853  /* Body */
854  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
855  tree_handler(tree);
856  }
857 }
858 
859 template <typename ThresholdType, typename LeafOutputType>
860 template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
861 inline void
863  std::size_t num_tree,
864  HeaderFieldHandlerFunc header_field_handler,
865  TreeHandlerFunc tree_handler) {
866  /* Header */
867  header_field_handler(&num_feature);
868  header_field_handler(&task_type);
869  header_field_handler(&average_tree_output);
870  header_field_handler(&task_param);
871  header_field_handler(&param);
872  /* Body */
873  trees.clear();
874  for (std::size_t i = 0; i < num_tree; ++i) {
875  trees.emplace_back();
876  tree_handler(trees.back());
877  }
878 }
879 
880 template <typename ThresholdType, typename LeafOutputType>
881 inline void
882 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
883  auto header_primitive_field_handler = [dest](auto* field) {
884  dest->push_back(GetPyBufferFromScalar(field));
885  };
886  auto header_composite_field_handler = [dest](auto* field, const char* format) {
887  dest->push_back(GetPyBufferFromScalar(field, format));
888  };
889  auto tree_handler = [dest](Tree<ThresholdType, LeafOutputType>& tree) {
890  tree.GetPyBuffer(dest);
891  };
892  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
893 }
894 
895 template <typename ThresholdType, typename LeafOutputType>
896 inline void
898  const auto num_tree = static_cast<uint64_t>(this->trees.size());
899  WriteScalarToFile(&num_tree, dest_fp);
900  auto header_primitive_field_handler = [dest_fp](auto* field) {
901  WriteScalarToFile(field, dest_fp);
902  };
903  auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
904  WriteScalarToFile(field, dest_fp);
905  };
906  auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
907  tree.SerializeToFile(dest_fp);
908  };
909  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
910 }
911 
912 template <typename ThresholdType, typename LeafOutputType>
913 inline void
915  std::vector<PyBufferFrame>::iterator begin, std::vector<PyBufferFrame>::iterator end) {
916  const std::size_t num_frame = std::distance(begin, end);
917  constexpr std::size_t kNumFrameInHeader = 5;
918  if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
919  throw std::runtime_error("Wrong number of frames");
920  }
921  const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
922 
923  auto header_field_handler = [&begin](auto* field) {
924  InitScalarFromPyBuffer(field, *begin++);
925  };
926 
927  auto tree_handler = [&begin](Tree<ThresholdType, LeafOutputType>& tree) {
928  // Read the frames in the range [begin, begin + kNumFramePerTree) into the tree
929  tree.InitFromPyBuffer(begin, begin + kNumFramePerTree);
930  begin += kNumFramePerTree;
931  // Advance the iterator so that the next tree reads the next kNumFramePerTree frames
932  };
933 
934  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
935 }
936 
937 template <typename ThresholdType, typename LeafOutputType>
938 inline void
940  uint64_t num_tree;
941  ReadScalarFromFile(&num_tree, src_fp);
942 
943  auto header_field_handler = [src_fp](auto* field) {
944  ReadScalarFromFile(field, src_fp);
945  };
946 
947  auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
948  tree.DeserializeFromFile(src_fp);
949  };
950 
951  DeserializeTemplate(num_tree, header_field_handler, tree_handler);
952 }
953 
954 inline void InitParamAndCheck(ModelParam* param,
955  const std::vector<std::pair<std::string, std::string>>& cfg) {
956  auto unknown = param->InitAllowUnknown(cfg);
957  if (!unknown.empty()) {
958  std::ostringstream oss;
959  for (const auto& kv : unknown) {
960  oss << kv.first << ", ";
961  }
962  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
963  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
964  }
965 }
966 
967 } // namespace treelite
968 #endif // TREELITE_TREE_IMPL_H_
SplitFeatureType
feature split type
Definition: base.h:22
void Init()
initialize the model with a single root node
Definition: tree_impl.h:637
TaskType
Enum type representing the task type.
Definition: tree.h:98
tree node
Definition: tree.h:216
in-memory representation of a decision tree
Definition: tree.h:213
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:22
thin wrapper for tree ensemble model
Definition: tree.h:647
Operator
comparison operators
Definition: base.h:26