Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <treelite/error.h>
11 #include <treelite/version.h>
12 #include <treelite/logging.h>
13 #include <algorithm>
14 #include <limits>
15 #include <memory>
16 #include <map>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 #include <unordered_map>
21 #include <sstream>
22 #include <iomanip>
23 #include <typeinfo>
24 #include <stdexcept>
25 #include <iostream>
26 #include <cstddef>
27 
28 namespace {
29 
// Generic fallback: render a value with std::to_string.
template <typename T>
inline std::string GetString(T x) {
  std::string repr = std::to_string(x);
  return repr;
}

// float: stream the value with max_digits10 significant digits.
template <>
inline std::string GetString<float>(float x) {
  constexpr int kDigits = std::numeric_limits<float>::max_digits10;
  std::ostringstream stream;
  stream << std::setprecision(kDigits);
  stream << x;
  return stream.str();
}

// double: stream the value with max_digits10 significant digits.
template <>
inline std::string GetString<double>(double x) {
  constexpr int kDigits = std::numeric_limits<double>::max_digits10;
  std::ostringstream stream;
  stream << std::setprecision(kDigits);
  stream << x;
  return stream.str();
}
48 
49 } // anonymous namespace
50 
51 namespace treelite {
52 
53 template <typename T>
54 ContiguousArray<T>::ContiguousArray()
55  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
56 
57 template <typename T>
58 ContiguousArray<T>::~ContiguousArray() {
59  if (buffer_ && owned_buffer_) {
60  std::free(buffer_);
61  }
62 }
63 
64 template <typename T>
65 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
66  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
67  owned_buffer_(other.owned_buffer_) {
68  other.buffer_ = nullptr;
69  other.size_ = other.capacity_ = 0;
70 }
71 
72 template <typename T>
73 ContiguousArray<T>&
74 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
75  if (buffer_ && owned_buffer_) {
76  std::free(buffer_);
77  }
78  buffer_ = other.buffer_;
79  size_ = other.size_;
80  capacity_ = other.capacity_;
81  owned_buffer_ = other.owned_buffer_;
82  other.buffer_ = nullptr;
83  other.size_ = other.capacity_ = 0;
84  return *this;
85 }
86 
87 template <typename T>
88 inline ContiguousArray<T>
89 ContiguousArray<T>::Clone() const {
90  ContiguousArray clone;
91  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
92  if (!clone.buffer_) {
93  throw Error("Could not allocate memory for the clone");
94  }
95  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
96  clone.size_ = size_;
97  clone.capacity_ = capacity_;
98  clone.owned_buffer_ = true;
99  return clone;
100 }
101 
102 template <typename T>
103 inline void
104 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, std::size_t size) {
105  if (buffer_ && owned_buffer_) {
106  std::free(buffer_);
107  }
108  buffer_ = static_cast<T*>(prealloc_buf);
109  size_ = size;
110  capacity_ = size;
111  owned_buffer_ = false;
112 }
113 
114 template <typename T>
115 inline T*
116 ContiguousArray<T>::Data() {
117  return buffer_;
118 }
119 
120 template <typename T>
121 inline const T*
122 ContiguousArray<T>::Data() const {
123  return buffer_;
124 }
125 
126 template <typename T>
127 inline T*
128 ContiguousArray<T>::End() {
129  return &buffer_[Size()];
130 }
131 
132 template <typename T>
133 inline const T*
134 ContiguousArray<T>::End() const {
135  return &buffer_[Size()];
136 }
137 
138 template <typename T>
139 inline T&
140 ContiguousArray<T>::Back() {
141  return buffer_[Size() - 1];
142 }
143 
144 template <typename T>
145 inline const T&
146 ContiguousArray<T>::Back() const {
147  return buffer_[Size() - 1];
148 }
149 
150 template <typename T>
151 inline std::size_t
152 ContiguousArray<T>::Size() const {
153  return size_;
154 }
155 
156 template <typename T>
157 inline bool
158 ContiguousArray<T>::Empty() const {
159  return (Size() == 0);
160 }
161 
162 template <typename T>
163 inline void
164 ContiguousArray<T>::Reserve(std::size_t newsize) {
165  if (!owned_buffer_) {
166  throw Error("Cannot resize when using a foreign buffer; clone first");
167  }
168  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
169  if (!newbuf) {
170  throw Error("Could not expand buffer");
171  }
172  buffer_ = newbuf;
173  capacity_ = newsize;
174 }
175 
176 template <typename T>
177 inline void
178 ContiguousArray<T>::Resize(std::size_t newsize) {
179  if (!owned_buffer_) {
180  throw Error("Cannot resize when using a foreign buffer; clone first");
181  }
182  if (newsize > capacity_) {
183  std::size_t newcapacity = capacity_;
184  if (newcapacity == 0) {
185  newcapacity = 1;
186  }
187  while (newcapacity <= newsize) {
188  newcapacity *= 2;
189  }
190  Reserve(newcapacity);
191  }
192  size_ = newsize;
193 }
194 
195 template <typename T>
196 inline void
197 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
198  if (!owned_buffer_) {
199  throw Error("Cannot resize when using a foreign buffer; clone first");
200  }
201  std::size_t oldsize = Size();
202  Resize(newsize);
203  for (std::size_t i = oldsize; i < newsize; ++i) {
204  buffer_[i] = t;
205  }
206 }
207 
208 template <typename T>
209 inline void
210 ContiguousArray<T>::Clear() {
211  if (!owned_buffer_) {
212  throw Error("Cannot clear when using a foreign buffer; clone first");
213  }
214  Resize(0);
215 }
216 
217 template <typename T>
218 inline void
219 ContiguousArray<T>::PushBack(T t) {
220  if (!owned_buffer_) {
221  throw Error("Cannot add element when using a foreign buffer; clone first");
222  }
223  if (size_ == capacity_) {
224  Reserve(capacity_ * 2);
225  }
226  buffer_[size_++] = t;
227 }
228 
229 template <typename T>
230 inline void
231 ContiguousArray<T>::Extend(const std::vector<T>& other) {
232  if (!owned_buffer_) {
233  throw Error("Cannot add elements when using a foreign buffer; clone first");
234  }
235  if (other.empty()) {
236  return; // appending an empty vector is a no-op
237  }
238  std::size_t newsize = size_ + other.size();
239  if (newsize > capacity_) {
240  std::size_t newcapacity = capacity_;
241  if (newcapacity == 0) {
242  newcapacity = 1;
243  }
244  while (newcapacity <= newsize) {
245  newcapacity *= 2;
246  }
247  Reserve(newcapacity);
248  }
249  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
250  size_ = newsize;
251 }
252 
253 template <typename T>
254 inline T&
255 ContiguousArray<T>::operator[](std::size_t idx) {
256  return buffer_[idx];
257 }
258 
259 template <typename T>
260 inline const T&
261 ContiguousArray<T>::operator[](std::size_t idx) const {
262  return buffer_[idx];
263 }
264 
265 template <typename T>
266 inline T&
267 ContiguousArray<T>::at(std::size_t idx) {
268  if (idx >= Size()) {
269  throw Error("nid out of range");
270  }
271  return buffer_[idx];
272 }
273 
274 template <typename T>
275 inline const T&
276 ContiguousArray<T>::at(std::size_t idx) const {
277  if (idx >= Size()) {
278  throw Error("nid out of range");
279  }
280  return buffer_[idx];
281 }
282 
283 template <typename T>
284 inline T&
285 ContiguousArray<T>::at(int idx) {
286  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
287  throw Error("nid out of range");
288  }
289  return buffer_[static_cast<std::size_t>(idx)];
290 }
291 
292 template <typename T>
293 inline const T&
294 ContiguousArray<T>::at(int idx) const {
295  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
296  throw Error("nid out of range");
297  }
298  return buffer_[static_cast<std::size_t>(idx)];
299 }
300 
301 template<typename Container>
302 inline std::vector<std::pair<std::string, std::string> >
303 ModelParam::InitAllowUnknown(const Container& kwargs) {
304  std::vector<std::pair<std::string, std::string>> unknowns;
305  for (const auto& e : kwargs) {
306  if (e.first == "pred_transform") {
307  std::strncpy(this->pred_transform, e.second.c_str(),
308  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
309  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
310  } else if (e.first == "sigmoid_alpha") {
311  this->sigmoid_alpha = std::stof(e.second, nullptr);
312  } else if (e.first == "ratio_c") {
313  this->ratio_c = std::stof(e.second, nullptr);
314  } else if (e.first == "global_bias") {
315  this->global_bias = std::stof(e.second, nullptr);
316  }
317  }
318  return unknowns;
319 }
320 
321 inline std::map<std::string, std::string>
322 ModelParam::__DICT__() const {
323  std::map<std::string, std::string> ret;
324  ret.emplace("pred_transform", std::string(this->pred_transform));
325  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
326  ret.emplace("ratio_c", GetString(this->ratio_c));
327  ret.emplace("global_bias", GetString(this->global_bias));
328  return ret;
329 }
330 
331 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
332  std::size_t itemsize, std::size_t nitem) {
333  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
334 }
335 
336 // Infer format string from data type
337 template <typename T>
338 inline const char* InferFormatString() {
339  switch (sizeof(T)) {
340  case 1:
341  return (std::is_unsigned<T>::value ? "=B" : "=b");
342  case 2:
343  return (std::is_unsigned<T>::value ? "=H" : "=h");
344  case 4:
345  if (std::is_integral<T>::value) {
346  return (std::is_unsigned<T>::value ? "=L" : "=l");
347  } else {
348  if (!std::is_floating_point<T>::value) {
349  throw Error("Could not infer format string");
350  }
351  return "=f";
352  }
353  case 8:
354  if (std::is_integral<T>::value) {
355  return (std::is_unsigned<T>::value ? "=Q" : "=q");
356  } else {
357  if (!std::is_floating_point<T>::value) {
358  throw Error("Could not infer format string");
359  }
360  return "=d";
361  }
362  default:
363  throw Error("Unrecognized type");
364  }
365  return nullptr;
366 }
367 
368 template <typename T>
369 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
370  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
371 }
372 
373 template <typename T>
374 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
375  static_assert(std::is_arithmetic<T>::value,
376  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
377  return GetPyBufferFromArray(vec, InferFormatString<T>());
378 }
379 
380 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, std::size_t itemsize) {
381  return GetPyBufferFromArray(data, format, itemsize, 1);
382 }
383 
384 template <typename T>
385 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
386  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
387  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
388 }
389 
390 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
391  using T = std::underlying_type<TypeInfo>::type;
392  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
393 }
394 
395 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
396  using T = std::underlying_type<TaskType>::type;
397  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
398 }
399 
400 template <typename T>
401 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
402  static_assert(std::is_arithmetic<T>::value,
403  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
404  "specify format string manually");
405  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
406 }
407 
408 template <typename T>
409 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
410  if (sizeof(T) != frame.itemsize) {
411  throw Error("Incorrect itemsize");
412  }
413  vec->UseForeignBuffer(frame.buf, frame.nitem);
414 }
415 
416 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
417  using T = std::underlying_type<TypeInfo>::type;
418  if (sizeof(T) != buffer.itemsize) {
419  throw Error("Incorrect itemsize");
420  }
421  if (buffer.nitem != 1) {
422  throw Error("nitem must be 1 for a scalar");
423  }
424  T* t = static_cast<T*>(buffer.buf);
425  *scalar = static_cast<TypeInfo>(*t);
426 }
427 
428 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
429  using T = std::underlying_type<TaskType>::type;
430  if (sizeof(T) != buffer.itemsize) {
431  throw Error("Incorrect itemsize");
432  }
433  if (buffer.nitem != 1) {
434  throw Error("nitem must be 1 for a scalar");
435  }
436  T* t = static_cast<T*>(buffer.buf);
437  *scalar = static_cast<TaskType>(*t);
438 }
439 
440 template <typename T>
441 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
442  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
443  if (sizeof(T) != buffer.itemsize) {
444  throw Error("Incorrect itemsize");
445  }
446  if (buffer.nitem != 1) {
447  throw Error("nitem must be 1 for a scalar");
448  }
449  T* t = static_cast<T*>(buffer.buf);
450  *scalar = *t;
451 }
452 
453 template <typename T>
454 inline void ReadScalarFromFile(T* scalar, FILE* fp) {
455  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
456  if (std::fread(scalar, sizeof(T), 1, fp) < 1) {
457  throw Error("Could not read a scalar");
458  }
459 }
460 
461 template <typename T>
462 inline void WriteScalarToFile(T* scalar, FILE* fp) {
463  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
464  if (std::fwrite(scalar, sizeof(T), 1, fp) < 1) {
465  throw Error("Could not write a scalar");
466  }
467 }
468 
469 template <typename T>
470 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
471  uint64_t nelem;
472  if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
473  throw Error("Could not read the number of elements");
474  }
475  vec->Clear();
476  vec->Resize(nelem);
477  if (nelem == 0) {
478  return; // handle empty arrays
479  }
480  const auto nelem_size_t = static_cast<std::size_t>(nelem);
481  if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
482  throw Error("Could not read an array");
483  }
484 }
485 
486 template <typename T>
487 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
488  static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
489  const auto nelem = static_cast<uint64_t>(vec->Size());
490  if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
491  throw Error("Could not write the number of elements");
492  }
493  if (nelem == 0) {
494  return; // handle empty arrays
495  }
496  const auto nelem_size_t = vec->Size();
497  if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
498  throw Error("Could not write an array");
499  }
500 }
501 
502 inline void SkipOptFieldInFile(FILE* fp) {
503  uint16_t elem_size;
504  uint64_t nelem;
505  ReadScalarFromFile(&elem_size, fp);
506  ReadScalarFromFile(&nelem, fp);
507 
508  const uint64_t nbytes = elem_size * nelem;
509  TREELITE_CHECK_LE(nbytes, std::numeric_limits<long>::max()); // NOLINT
510  if (std::fseek(fp, static_cast<long>(nbytes), SEEK_CUR) != 0) { // NOLINT
511  throw Error("Reached end of file");
512  }
513 }
514 
515 template <typename ThresholdType, typename LeafOutputType>
516 Tree<ThresholdType, LeafOutputType>::Tree(bool use_opt_field)
517  : use_opt_field_(use_opt_field)
518 {}
519 
520 template <typename ThresholdType, typename LeafOutputType>
521 inline Tree<ThresholdType, LeafOutputType>
522 Tree<ThresholdType, LeafOutputType>::Clone() const {
523  Tree<ThresholdType, LeafOutputType> tree;
524  tree.num_nodes = num_nodes;
525  tree.nodes_ = nodes_.Clone();
526  tree.leaf_vector_ = leaf_vector_.Clone();
527  tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
528  tree.leaf_vector_end_ = leaf_vector_end_.Clone();
529  tree.matching_categories_ = matching_categories_.Clone();
530  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
531  return tree;
532 }
533 
534 template <typename ThresholdType, typename LeafOutputType>
535 inline const char*
536 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
537  if (std::is_same<ThresholdType, float>::value) {
538  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
539  } else {
540  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
541  }
542 }
543 
544 template <typename ThresholdType, typename LeafOutputType>
545 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
546 inline void
547 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
548  ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
549  CompositeArrayHandler composite_array_handler) {
550  scalar_handler(&num_nodes);
551  scalar_handler(&has_categorical_split_);
552  composite_array_handler(&nodes_, GetFormatStringForNode());
553  primitive_array_handler(&leaf_vector_);
554  primitive_array_handler(&leaf_vector_begin_);
555  primitive_array_handler(&leaf_vector_end_);
556  primitive_array_handler(&matching_categories_);
557  primitive_array_handler(&matching_categories_offset_);
558 
559  /* Extension slot 2: Per-tree optional fields -- to be added later */
560  num_opt_field_per_tree_ = 0;
561  scalar_handler(&num_opt_field_per_tree_);
562 
563  /* Extension slot 3: Per-node optional fields -- to be added later */
564  num_opt_field_per_node_ = 0;
565  scalar_handler(&num_opt_field_per_node_);
566 }
567 
568 template <typename ThresholdType, typename LeafOutputType>
569 template <typename ScalarHandler, typename ArrayHandler, typename SkipOptFieldHandlerFunc>
570 inline void
571 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
572  ScalarHandler scalar_handler,
573  ArrayHandler array_handler,
574  SkipOptFieldHandlerFunc skip_opt_field_handler) {
575  scalar_handler(&num_nodes);
576  scalar_handler(&has_categorical_split_);
577  array_handler(&nodes_);
578  if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
579  throw Error("Could not load the correct number of nodes");
580  }
581  array_handler(&leaf_vector_);
582  array_handler(&leaf_vector_begin_);
583  array_handler(&leaf_vector_end_);
584  array_handler(&matching_categories_);
585  array_handler(&matching_categories_offset_);
586 
587  if (use_opt_field_) {
588  /* Extension slot 2: Per-tree optional fields -- to be added later */
589  scalar_handler(&num_opt_field_per_tree_);
590  // Ignore extra fields; the input is likely from a later version of Treelite
591  for (int32_t i = 0; i < num_opt_field_per_tree_; ++i) {
592  skip_opt_field_handler();
593  }
594 
595  /* Extension slot 3: Per-node optional fields -- to be added later */
596  scalar_handler(&num_opt_field_per_node_);
597  // Ignore extra fields; the input is likely from a later version of Treelite
598  for (int32_t i = 0; i < num_opt_field_per_node_; ++i) {
599  skip_opt_field_handler();
600  }
601  } else {
602  // Legacy (version 2.4)
603  num_opt_field_per_tree_ = 0;
604  num_opt_field_per_node_ = 0;
605  }
606 }
607 
608 template <typename ThresholdType, typename LeafOutputType>
609 inline void
610 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
611  auto scalar_handler = [dest](auto* field) {
612  dest->push_back(GetPyBufferFromScalar(field));
613  };
614  auto primitive_array_handler = [dest](auto* field) {
615  dest->push_back(GetPyBufferFromArray(field));
616  };
617  auto composite_array_handler = [dest](auto* field, const char* format) {
618  dest->push_back(GetPyBufferFromArray(field, format));
619  };
620  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
621 }
622 
623 template <typename ThresholdType, typename LeafOutputType>
624 inline void
625 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
626  auto scalar_handler = [dest_fp](auto* field) {
627  WriteScalarToFile(field, dest_fp);
628  };
629  auto primitive_array_handler = [dest_fp](auto* field) {
630  WriteArrayToFile(field, dest_fp);
631  };
632  auto composite_array_handler = [dest_fp](auto* field, const char* format) {
633  WriteArrayToFile(field, dest_fp);
634  };
635  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
636 }
637 
638 template <typename ThresholdType, typename LeafOutputType>
639 inline std::vector<PyBufferFrame>::iterator
640 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it) {
641  std::vector<PyBufferFrame>::iterator new_it = it;
642  auto scalar_handler = [&new_it](auto* field) {
643  InitScalarFromPyBuffer(field, *(new_it++));
644  };
645  auto array_handler = [&new_it](auto* field) {
646  InitArrayFromPyBuffer(field, *(new_it++));
647  };
648  auto skip_opt_field_handler = [&new_it]() {
649  ++new_it;
650  };
651  DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
652  return new_it;
653 }
654 
655 template <typename ThresholdType, typename LeafOutputType>
656 inline void
657 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
658  auto scalar_handler = [src_fp](auto* field) {
659  ReadScalarFromFile(field, src_fp);
660  };
661  auto array_handler = [src_fp](auto* field) {
662  ReadArrayFromFile(field, src_fp);
663  };
664  auto skip_opt_field_handler = [src_fp]() {
665  SkipOptFieldInFile(src_fp);
666  };
667  DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
668 }
669 
670 template <typename ThresholdType, typename LeafOutputType>
672  std::memset(this, 0, sizeof(Node));
673  cleft_ = cright_ = -1;
674  sindex_ = 0;
675  info_.leaf_value = static_cast<LeafOutputType>(0);
676  info_.threshold = static_cast<ThresholdType>(0);
677  data_count_ = 0;
678  sum_hess_ = gain_ = 0.0;
679  data_count_present_ = sum_hess_present_ = gain_present_ = false;
680  categories_list_right_child_ = false;
681  split_type_ = SplitFeatureType::kNone;
682  cmp_ = Operator::kNone;
683 }
684 
685 template <typename ThresholdType, typename LeafOutputType>
686 inline int
688  int nd = num_nodes++;
689  if (nodes_.Size() != static_cast<std::size_t>(nd)) {
690  throw Error("Invariant violated: nodes_ contains incorrect number of nodes");
691  }
692  for (int nid = nd; nid < num_nodes; ++nid) {
693  leaf_vector_begin_.PushBack(0);
694  leaf_vector_end_.PushBack(0);
695  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
696  nodes_.Resize(nodes_.Size() + 1);
697  nodes_.Back().Init();
698  }
699  return nd;
700 }
701 
702 template <typename ThresholdType, typename LeafOutputType>
703 inline void
705  num_nodes = 1;
706  has_categorical_split_ = false;
707  leaf_vector_.Clear();
708  leaf_vector_begin_.Resize(1, {});
709  leaf_vector_end_.Resize(1, {});
710  matching_categories_.Clear();
711  matching_categories_offset_.Resize(2, 0);
712  nodes_.Resize(1);
713  nodes_.at(0).Init();
714  SetLeaf(0, static_cast<LeafOutputType>(0));
715 }
716 
717 template <typename ThresholdType, typename LeafOutputType>
718 inline void
720  const int cleft = this->AllocNode();
721  const int cright = this->AllocNode();
722  nodes_.at(nid).cleft_ = cleft;
723  nodes_.at(nid).cright_ = cright;
724 }
725 
726 template <typename ThresholdType, typename LeafOutputType>
727 inline void
729  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
730  Node& node = nodes_.at(nid);
731  if (split_index >= ((1U << 31U) - 1)) {
732  throw Error("split_index too big");
733  }
734  if (default_left) split_index |= (1U << 31U);
735  node.sindex_ = split_index;
736  (node.info_).threshold = threshold;
737  node.cmp_ = cmp;
738  node.split_type_ = SplitFeatureType::kNumerical;
739  node.categories_list_right_child_ = false;
740 }
741 
742 template <typename ThresholdType, typename LeafOutputType>
743 inline void
745  int nid, unsigned split_index, bool default_left,
746  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
747  if (split_index >= ((1U << 31U) - 1)) {
748  throw Error("split_index too big");
749  }
750 
751  const std::size_t end_oft = matching_categories_offset_.Back();
752  const std::size_t new_end_oft = end_oft + categories_list.size();
753  if (end_oft != matching_categories_.Size()) {
754  throw Error("Invariant violated");
755  }
756  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
757  [end_oft](std::size_t x) { return (x == end_oft); })) {
758  throw Error("Invariant violated");
759  }
760  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
761  matching_categories_.Extend(categories_list);
762  if (new_end_oft != matching_categories_.Size()) {
763  throw Error("Invariant violated");
764  }
765  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
766  [new_end_oft](std::size_t& x) { x = new_end_oft; });
767  if (!matching_categories_.Empty()) {
768  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
769  }
770 
771  Node& node = nodes_.at(nid);
772  if (default_left) split_index |= (1U << 31U);
773  node.sindex_ = split_index;
774  node.split_type_ = SplitFeatureType::kCategorical;
775  node.categories_list_right_child_ = categories_list_right_child;
776 
777  has_categorical_split_ = true;
778 }
779 
780 template <typename ThresholdType, typename LeafOutputType>
781 inline void
782 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
783  Node& node = nodes_.at(nid);
784  (node.info_).leaf_value = value;
785  node.cleft_ = -1;
786  node.cright_ = -1;
787  node.split_type_ = SplitFeatureType::kNone;
788 }
789 
790 template <typename ThresholdType, typename LeafOutputType>
791 inline void
793  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
794  std::size_t begin = leaf_vector_.Size();
795  std::size_t end = begin + node_leaf_vector.size();
796  leaf_vector_.Extend(node_leaf_vector);
797  leaf_vector_begin_[nid] = begin;
798  leaf_vector_end_[nid] = end;
799  Node &node = nodes_.at(nid);
800  node.cleft_ = -1;
801  node.cright_ = -1;
802  node.split_type_ = SplitFeatureType::kNone;
803 }
804 
805 template <typename ThresholdType, typename LeafOutputType>
806 inline std::unique_ptr<Model>
807 Model::Create() {
808  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
809  model->threshold_type_ = TypeToInfo<ThresholdType>();
810  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
811  return model;
812 }
813 
814 template <typename ThresholdType, typename LeafOutputType>
816  public:
817  inline static std::unique_ptr<Model> Dispatch() {
818  return Model::Create<ThresholdType, LeafOutputType>();
819  }
820 };
821 
822 inline std::unique_ptr<Model>
823 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
824  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
825 }
826 
827 template <typename ThresholdType, typename LeafOutputType>
829  public:
830  template <typename Func>
831  inline static auto Dispatch(Model* model, Func func) {
832  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
833  }
834 
835  template <typename Func>
836  inline static auto Dispatch(const Model* model, Func func) {
837  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
838  }
839 };
840 
841 template <typename Func>
842 inline auto
843 Model::Dispatch(Func func) {
844  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
845 }
846 
847 template <typename Func>
848 inline auto
849 Model::Dispatch(Func func) const {
850  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
851 }
852 
853 template <typename HeaderPrimitiveFieldHandlerFunc>
854 inline void
855 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
856  major_ver_ = TREELITE_VER_MAJOR;
857  minor_ver_ = TREELITE_VER_MINOR;
858  patch_ver_ = TREELITE_VER_PATCH;
859  header_primitive_field_handler(&major_ver_);
860  header_primitive_field_handler(&minor_ver_);
861  header_primitive_field_handler(&patch_ver_);
862  header_primitive_field_handler(&threshold_type_);
863  header_primitive_field_handler(&leaf_output_type_);
864 }
865 
866 template <typename HeaderPrimitiveFieldHandlerFunc>
867 inline void
868 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
869  int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
870  TypeInfo& threshold_type, TypeInfo& leaf_output_type) {
871  header_primitive_field_handler(&major_ver);
872  header_primitive_field_handler(&minor_ver);
873  header_primitive_field_handler(&patch_ver);
874  if (major_ver != TREELITE_VER_MAJOR && !(major_ver == 2 && minor_ver == 4)) {
875  std::ostringstream oss;
876  oss << "Cannot deserialize model from a different major Treelite version or "
877  << "a version before 2.4.0." << std::endl
878  << "Currently running Treelite version " << TREELITE_VER_MAJOR << "."
879  << TREELITE_VER_MINOR << "." << TREELITE_VER_PATCH << std::endl
880  << "The model checkpoint was generated from Treelite version " << major_ver << "."
881  << minor_ver << "." << patch_ver;
882  throw Error(oss.str());
883  }
884  header_primitive_field_handler(&threshold_type);
885  header_primitive_field_handler(&leaf_output_type);
886 }
887 
888 template <typename ThresholdType, typename LeafOutputType>
889 template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
890  typename TreeHandlerFunc>
891 inline void
893  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
894  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
895  TreeHandlerFunc tree_handler) {
896  /* Header */
897  header_primitive_field_handler(&num_feature);
898  header_primitive_field_handler(&task_type);
899  header_primitive_field_handler(&average_tree_output);
900  header_composite_field_handler(&task_param, "T{=B=?xx=I=I}");
901  header_composite_field_handler(
902  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f=f}");
903 
904  /* Extension Slot 1: Per-model optional fields -- to be added later */
905  num_opt_field_per_model_ = 0;
906  header_primitive_field_handler(&num_opt_field_per_model_);
907 
908  /* Body */
909  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
910  tree_handler(tree);
911  }
912 }
913 
914 template <typename ThresholdType, typename LeafOutputType>
915 template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc,
916  typename SkipOptFieldHandlerFunc>
917 inline void
919  std::size_t num_tree,
920  HeaderFieldHandlerFunc header_field_handler,
921  TreeHandlerFunc tree_handler,
922  SkipOptFieldHandlerFunc skip_opt_field_handler) {
923  /* Header */
924  header_field_handler(&num_feature);
925  header_field_handler(&task_type);
926  header_field_handler(&average_tree_output);
927  header_field_handler(&task_param);
928  header_field_handler(&param);
929 
930  /* Extension Slot 1: Per-model optional fields -- to be added later */
931  const bool use_opt_field = (major_ver_ >= 3);
932  if (use_opt_field) {
933  header_field_handler(&num_opt_field_per_model_);
934  // Ignore extra fields; the input is likely from a later version of Treelite
935  for (int32_t i = 0; i < num_opt_field_per_model_; ++i) {
936  skip_opt_field_handler();
937  }
938  } else {
939  // Legacy (version 2.4)
940  num_opt_field_per_model_ = 0;
941  }
942 
943  /* Body */
944  trees.clear();
945  for (std::size_t i = 0; i < num_tree; ++i) {
946  trees.emplace_back(use_opt_field);
947  tree_handler(trees.back());
948  }
949 }
950 
951 template <typename ThresholdType, typename LeafOutputType>
952 inline void
953 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
954  num_tree_ = static_cast<uint64_t>(this->trees.size());
955  auto header_primitive_field_handler = [dest](auto* field) {
956  dest->push_back(GetPyBufferFromScalar(field));
957  };
958  auto header_composite_field_handler = [dest](auto* field, const char* format) {
959  dest->push_back(GetPyBufferFromScalar(field, format));
960  };
961  auto tree_handler = [dest](Tree<ThresholdType, LeafOutputType>& tree) {
962  tree.GetPyBuffer(dest);
963  };
964  header_primitive_field_handler(&num_tree_);
965  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
966 }
967 
968 template <typename ThresholdType, typename LeafOutputType>
969 inline void
971  num_tree_ = static_cast<uint64_t>(this->trees.size());
972  auto header_primitive_field_handler = [dest_fp](auto* field) {
973  WriteScalarToFile(field, dest_fp);
974  };
975  auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
976  WriteScalarToFile(field, dest_fp);
977  };
978  auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
979  tree.SerializeToFile(dest_fp);
980  };
981  header_primitive_field_handler(&num_tree_);
982  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
983 }
984 
985 template <typename ThresholdType, typename LeafOutputType>
986 inline std::vector<PyBufferFrame>::iterator
988  std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) {
989  std::vector<PyBufferFrame>::iterator new_it = it;
990  auto header_field_handler = [&new_it](auto* field) {
991  InitScalarFromPyBuffer(field, *(new_it++));
992  };
993 
994  auto skip_opt_field_handler = [&new_it]() {
995  ++new_it;
996  };
997 
998  auto tree_handler = [&new_it](Tree<ThresholdType, LeafOutputType>& tree) {
999  new_it = tree.InitFromPyBuffer(new_it);
1000  };
1001 
1002  if (major_ver_ == 2) {
1003  // From version 2.4 (legacy)
1004  // num_tree has to be inferred from the number of frames received
1005  constexpr std::size_t kNumFrameInHeader = 5;
1006  constexpr std::size_t kNumFramePerTree = 8;
1007  num_tree_ = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
1008  } else {
1009  // From version 3.x
1010  // num_tree is now explicitly stored
1011  header_field_handler(&num_tree_);
1012  }
1013 
1014  DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1015  TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1016 
1017  return new_it;
1018 }
1019 
1020 template <typename ThresholdType, typename LeafOutputType>
1021 inline void
1023  ReadScalarFromFile(&num_tree_, src_fp);
1024 
1025  auto header_field_handler = [src_fp](auto* field) {
1026  ReadScalarFromFile(field, src_fp);
1027  };
1028 
1029  auto skip_opt_field_handler = [src_fp]() {
1030  SkipOptFieldInFile(src_fp);
1031  };
1032 
1033  auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
1034  tree.DeserializeFromFile(src_fp);
1035  };
1036 
1037  DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1038  TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1039 }
1040 
1041 inline void InitParamAndCheck(ModelParam* param,
1042  const std::vector<std::pair<std::string, std::string>>& cfg) {
1043  auto unknown = param->InitAllowUnknown(cfg);
1044  if (!unknown.empty()) {
1045  std::ostringstream oss;
1046  for (const auto& kv : unknown) {
1047  oss << kv.first << ", ";
1048  }
1049  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
1050  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
1051  }
1052 }
1053 
1054 } // namespace treelite
1055 #endif // TREELITE_TREE_IMPL_H_
void Init()
initialize the model with a single root node
Definition: tree_impl.h:704
Exception class that will be thrown by Treelite.
Definition: error.h:18
TaskType
Enum type representing the task type.
Definition: tree.h:107
tree node
Definition: tree.h:225
in-memory representation of a decision tree
Definition: tree.h:222
logging facility for Treelite
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
thin wrapper for tree ensemble model
Definition: tree.h:734
Operator
comparison operators
Definition: base.h:27