Treelite
tree_impl.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_TREE_IMPL_H_
8 #define TREELITE_TREE_IMPL_H_
9 
10 #include <treelite/error.h>
11 #include <treelite/version.h>
12 #include <treelite/logging.h>
13 #include <algorithm>
14 #include <limits>
15 #include <memory>
16 #include <map>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 #include <unordered_map>
21 #include <sstream>
22 #include <iomanip>
23 #include <typeinfo>
24 #include <stdexcept>
25 #include <iostream>
26 #include <cstddef>
27 
namespace {

/*!
 * \brief Convert a value to text using std::to_string
 * \param x value to convert
 * \return decimal representation of x
 */
template <typename T>
inline std::string GetString(T x) {
  return std::to_string(x);
}

/*!
 * \brief Convert a float to text with max_digits10 significant digits
 *
 * max_digits10 is the number of decimal digits needed so that parsing the text back
 * yields exactly the same float.
 */
template <>
inline std::string GetString<float>(float x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<float>::max_digits10);
  stream << x;
  return stream.str();
}

/*! \brief Double-precision counterpart of GetString<float>, using double's max_digits10 */
template <>
inline std::string GetString<double>(double x) {
  std::ostringstream stream;
  stream.precision(std::numeric_limits<double>::max_digits10);
  stream << x;
  return stream.str();
}

}  // anonymous namespace
50 
51 namespace treelite {
52 
53 template <typename T>
54 ContiguousArray<T>::ContiguousArray()
55  : buffer_(nullptr), size_(0), capacity_(0), owned_buffer_(true) {}
56 
57 template <typename T>
58 ContiguousArray<T>::~ContiguousArray() {
59  if (buffer_ && owned_buffer_) {
60  std::free(buffer_);
61  }
62 }
63 
64 template <typename T>
65 ContiguousArray<T>::ContiguousArray(ContiguousArray&& other) noexcept
66  : buffer_(other.buffer_), size_(other.size_), capacity_(other.capacity_),
67  owned_buffer_(other.owned_buffer_) {
68  other.buffer_ = nullptr;
69  other.size_ = other.capacity_ = 0;
70 }
71 
72 template <typename T>
73 ContiguousArray<T>&
74 ContiguousArray<T>::operator=(ContiguousArray&& other) noexcept {
75  if (buffer_ && owned_buffer_) {
76  std::free(buffer_);
77  }
78  buffer_ = other.buffer_;
79  size_ = other.size_;
80  capacity_ = other.capacity_;
81  owned_buffer_ = other.owned_buffer_;
82  other.buffer_ = nullptr;
83  other.size_ = other.capacity_ = 0;
84  return *this;
85 }
86 
87 template <typename T>
88 inline ContiguousArray<T>
89 ContiguousArray<T>::Clone() const {
90  ContiguousArray clone;
91  if (buffer_) {
92  clone.buffer_ = static_cast<T*>(std::malloc(sizeof(T) * capacity_));
93  if (!clone.buffer_) {
94  throw Error("Could not allocate memory for the clone");
95  }
96  std::memcpy(clone.buffer_, buffer_, sizeof(T) * size_);
97  } else {
98  TREELITE_CHECK_EQ(size_, 0);
99  TREELITE_CHECK_EQ(capacity_, 0);
100  clone.buffer_ = nullptr;
101  }
102  clone.size_ = size_;
103  clone.capacity_ = capacity_;
104  clone.owned_buffer_ = true;
105  return clone;
106 }
107 
108 template <typename T>
109 inline void
110 ContiguousArray<T>::UseForeignBuffer(void* prealloc_buf, std::size_t size) {
111  if (buffer_ && owned_buffer_) {
112  std::free(buffer_);
113  }
114  buffer_ = static_cast<T*>(prealloc_buf);
115  size_ = size;
116  capacity_ = size;
117  owned_buffer_ = false;
118 }
119 
120 template <typename T>
121 inline T*
122 ContiguousArray<T>::Data() {
123  return buffer_;
124 }
125 
126 template <typename T>
127 inline const T*
128 ContiguousArray<T>::Data() const {
129  return buffer_;
130 }
131 
132 template <typename T>
133 inline T*
134 ContiguousArray<T>::End() {
135  return &buffer_[Size()];
136 }
137 
138 template <typename T>
139 inline const T*
140 ContiguousArray<T>::End() const {
141  return &buffer_[Size()];
142 }
143 
144 template <typename T>
145 inline T&
146 ContiguousArray<T>::Back() {
147  return buffer_[Size() - 1];
148 }
149 
150 template <typename T>
151 inline const T&
152 ContiguousArray<T>::Back() const {
153  return buffer_[Size() - 1];
154 }
155 
156 template <typename T>
157 inline std::size_t
158 ContiguousArray<T>::Size() const {
159  return size_;
160 }
161 
162 template <typename T>
163 inline bool
164 ContiguousArray<T>::Empty() const {
165  return (Size() == 0);
166 }
167 
168 template <typename T>
169 inline void
170 ContiguousArray<T>::Reserve(std::size_t newsize) {
171  if (!owned_buffer_) {
172  throw Error("Cannot resize when using a foreign buffer; clone first");
173  }
174  T* newbuf = static_cast<T*>(std::realloc(static_cast<void*>(buffer_), sizeof(T) * newsize));
175  if (!newbuf) {
176  throw Error("Could not expand buffer");
177  }
178  buffer_ = newbuf;
179  capacity_ = newsize;
180 }
181 
182 template <typename T>
183 inline void
184 ContiguousArray<T>::Resize(std::size_t newsize) {
185  if (!owned_buffer_) {
186  throw Error("Cannot resize when using a foreign buffer; clone first");
187  }
188  if (newsize > capacity_) {
189  std::size_t newcapacity = capacity_;
190  if (newcapacity == 0) {
191  newcapacity = 1;
192  }
193  while (newcapacity <= newsize) {
194  newcapacity *= 2;
195  }
196  Reserve(newcapacity);
197  }
198  size_ = newsize;
199 }
200 
201 template <typename T>
202 inline void
203 ContiguousArray<T>::Resize(std::size_t newsize, T t) {
204  if (!owned_buffer_) {
205  throw Error("Cannot resize when using a foreign buffer; clone first");
206  }
207  std::size_t oldsize = Size();
208  Resize(newsize);
209  for (std::size_t i = oldsize; i < newsize; ++i) {
210  buffer_[i] = t;
211  }
212 }
213 
214 template <typename T>
215 inline void
216 ContiguousArray<T>::Clear() {
217  if (!owned_buffer_) {
218  throw Error("Cannot clear when using a foreign buffer; clone first");
219  }
220  Resize(0);
221 }
222 
223 template <typename T>
224 inline void
225 ContiguousArray<T>::PushBack(T t) {
226  if (!owned_buffer_) {
227  throw Error("Cannot add element when using a foreign buffer; clone first");
228  }
229  if (size_ == capacity_) {
230  Reserve(capacity_ * 2);
231  }
232  buffer_[size_++] = t;
233 }
234 
235 template <typename T>
236 inline void
237 ContiguousArray<T>::Extend(const std::vector<T>& other) {
238  if (!owned_buffer_) {
239  throw Error("Cannot add elements when using a foreign buffer; clone first");
240  }
241  if (other.empty()) {
242  return; // appending an empty vector is a no-op
243  }
244  std::size_t newsize = size_ + other.size();
245  if (newsize > capacity_) {
246  std::size_t newcapacity = capacity_;
247  if (newcapacity == 0) {
248  newcapacity = 1;
249  }
250  while (newcapacity <= newsize) {
251  newcapacity *= 2;
252  }
253  Reserve(newcapacity);
254  }
255  std::memcpy(&buffer_[size_], static_cast<const void*>(other.data()), sizeof(T) * other.size());
256  size_ = newsize;
257 }
258 
259 template <typename T>
260 inline T&
261 ContiguousArray<T>::operator[](std::size_t idx) {
262  return buffer_[idx];
263 }
264 
265 template <typename T>
266 inline const T&
267 ContiguousArray<T>::operator[](std::size_t idx) const {
268  return buffer_[idx];
269 }
270 
271 template <typename T>
272 inline T&
273 ContiguousArray<T>::at(std::size_t idx) {
274  if (idx >= Size()) {
275  throw Error("nid out of range");
276  }
277  return buffer_[idx];
278 }
279 
280 template <typename T>
281 inline const T&
282 ContiguousArray<T>::at(std::size_t idx) const {
283  if (idx >= Size()) {
284  throw Error("nid out of range");
285  }
286  return buffer_[idx];
287 }
288 
289 template <typename T>
290 inline T&
291 ContiguousArray<T>::at(int idx) {
292  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
293  throw Error("nid out of range");
294  }
295  return buffer_[static_cast<std::size_t>(idx)];
296 }
297 
298 template <typename T>
299 inline const T&
300 ContiguousArray<T>::at(int idx) const {
301  if (idx < 0 || static_cast<std::size_t>(idx) >= Size()) {
302  throw Error("nid out of range");
303  }
304  return buffer_[static_cast<std::size_t>(idx)];
305 }
306 
307 template<typename Container>
308 inline std::vector<std::pair<std::string, std::string> >
309 ModelParam::InitAllowUnknown(const Container& kwargs) {
310  std::vector<std::pair<std::string, std::string>> unknowns;
311  for (const auto& e : kwargs) {
312  if (e.first == "pred_transform") {
313  std::strncpy(this->pred_transform, e.second.c_str(),
314  TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1);
315  this->pred_transform[TREELITE_MAX_PRED_TRANSFORM_LENGTH - 1] = '\0';
316  } else if (e.first == "sigmoid_alpha") {
317  this->sigmoid_alpha = std::stof(e.second, nullptr);
318  } else if (e.first == "ratio_c") {
319  this->ratio_c = std::stof(e.second, nullptr);
320  } else if (e.first == "global_bias") {
321  this->global_bias = std::stof(e.second, nullptr);
322  }
323  }
324  return unknowns;
325 }
326 
327 inline std::map<std::string, std::string>
328 ModelParam::__DICT__() const {
329  std::map<std::string, std::string> ret;
330  ret.emplace("pred_transform", std::string(this->pred_transform));
331  ret.emplace("sigmoid_alpha", GetString(this->sigmoid_alpha));
332  ret.emplace("ratio_c", GetString(this->ratio_c));
333  ret.emplace("global_bias", GetString(this->global_bias));
334  return ret;
335 }
336 
337 inline PyBufferFrame GetPyBufferFromArray(void* data, const char* format,
338  std::size_t itemsize, std::size_t nitem) {
339  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
340 }
341 
342 // Infer format string from data type
343 template <typename T>
344 inline const char* InferFormatString() {
345  switch (sizeof(T)) {
346  case 1:
347  return (std::is_unsigned<T>::value ? "=B" : "=b");
348  case 2:
349  return (std::is_unsigned<T>::value ? "=H" : "=h");
350  case 4:
351  if (std::is_integral<T>::value) {
352  return (std::is_unsigned<T>::value ? "=L" : "=l");
353  } else {
354  if (!std::is_floating_point<T>::value) {
355  throw Error("Could not infer format string");
356  }
357  return "=f";
358  }
359  case 8:
360  if (std::is_integral<T>::value) {
361  return (std::is_unsigned<T>::value ? "=Q" : "=q");
362  } else {
363  if (!std::is_floating_point<T>::value) {
364  throw Error("Could not infer format string");
365  }
366  return "=d";
367  }
368  default:
369  throw Error("Unrecognized type");
370  }
371  return nullptr;
372 }
373 
374 template <typename T>
375 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec, const char* format) {
376  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
377 }
378 
379 template <typename T>
380 inline PyBufferFrame GetPyBufferFromArray(ContiguousArray<T>* vec) {
381  static_assert(std::is_arithmetic<T>::value,
382  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
383  return GetPyBufferFromArray(vec, InferFormatString<T>());
384 }
385 
386 inline PyBufferFrame GetPyBufferFromScalar(void* data, const char* format, std::size_t itemsize) {
387  return GetPyBufferFromArray(data, format, itemsize, 1);
388 }
389 
390 template <typename T>
391 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) {
392  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
393  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
394 }
395 
396 inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) {
397  using T = std::underlying_type<TypeInfo>::type;
398  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
399 }
400 
401 inline PyBufferFrame GetPyBufferFromScalar(TaskType* scalar) {
402  using T = std::underlying_type<TaskType>::type;
403  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
404 }
405 
406 template <typename T>
407 inline PyBufferFrame GetPyBufferFromScalar(T* scalar) {
408  static_assert(std::is_arithmetic<T>::value,
409  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
410  "specify format string manually");
411  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
412 }
413 
414 template <typename T>
415 inline void InitArrayFromPyBuffer(ContiguousArray<T>* vec, PyBufferFrame frame) {
416  if (sizeof(T) != frame.itemsize) {
417  throw Error("Incorrect itemsize");
418  }
419  vec->UseForeignBuffer(frame.buf, frame.nitem);
420 }
421 
422 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) {
423  using T = std::underlying_type<TypeInfo>::type;
424  if (sizeof(T) != buffer.itemsize) {
425  throw Error("Incorrect itemsize");
426  }
427  if (buffer.nitem != 1) {
428  throw Error("nitem must be 1 for a scalar");
429  }
430  T* t = static_cast<T*>(buffer.buf);
431  *scalar = static_cast<TypeInfo>(*t);
432 }
433 
434 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame buffer) {
435  using T = std::underlying_type<TaskType>::type;
436  if (sizeof(T) != buffer.itemsize) {
437  throw Error("Incorrect itemsize");
438  }
439  if (buffer.nitem != 1) {
440  throw Error("nitem must be 1 for a scalar");
441  }
442  T* t = static_cast<T*>(buffer.buf);
443  *scalar = static_cast<TaskType>(*t);
444 }
445 
446 template <typename T>
447 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) {
448  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
449  if (sizeof(T) != buffer.itemsize) {
450  throw Error("Incorrect itemsize");
451  }
452  if (buffer.nitem != 1) {
453  throw Error("nitem must be 1 for a scalar");
454  }
455  T* t = static_cast<T*>(buffer.buf);
456  *scalar = *t;
457 }
458 
459 template <typename T>
460 inline void ReadScalarFromFile(T* scalar, FILE* fp) {
461  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
462  if (std::fread(scalar, sizeof(T), 1, fp) < 1) {
463  throw Error("Could not read a scalar");
464  }
465 }
466 
467 template <typename T>
468 inline void WriteScalarToFile(T* scalar, FILE* fp) {
469  static_assert(std::is_standard_layout<T>::value, "T must be in the standard layout");
470  if (std::fwrite(scalar, sizeof(T), 1, fp) < 1) {
471  throw Error("Could not write a scalar");
472  }
473 }
474 
475 template <typename T>
476 inline void ReadArrayFromFile(ContiguousArray<T>* vec, FILE* fp) {
477  uint64_t nelem;
478  if (std::fread(&nelem, sizeof(nelem), 1, fp) < 1) {
479  throw Error("Could not read the number of elements");
480  }
481  vec->Clear();
482  vec->Resize(nelem);
483  if (nelem == 0) {
484  return; // handle empty arrays
485  }
486  const auto nelem_size_t = static_cast<std::size_t>(nelem);
487  if (std::fread(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
488  throw Error("Could not read an array");
489  }
490 }
491 
492 template <typename T>
493 inline void WriteArrayToFile(ContiguousArray<T>* vec, FILE* fp) {
494  static_assert(sizeof(uint64_t) >= sizeof(size_t), "size_t too large");
495  const auto nelem = static_cast<uint64_t>(vec->Size());
496  if (std::fwrite(&nelem, sizeof(nelem), 1, fp) < 1) {
497  throw Error("Could not write the number of elements");
498  }
499  if (nelem == 0) {
500  return; // handle empty arrays
501  }
502  const auto nelem_size_t = vec->Size();
503  if (std::fwrite(vec->Data(), sizeof(T), nelem_size_t, fp) < nelem_size_t) {
504  throw Error("Could not write an array");
505  }
506 }
507 
508 inline void SkipOptFieldInFile(FILE* fp) {
509  uint16_t elem_size;
510  uint64_t nelem;
511  ReadScalarFromFile(&elem_size, fp);
512  ReadScalarFromFile(&nelem, fp);
513 
514  const uint64_t nbytes = elem_size * nelem;
515  TREELITE_CHECK_LE(nbytes, std::numeric_limits<long>::max()); // NOLINT
516  if (std::fseek(fp, static_cast<long>(nbytes), SEEK_CUR) != 0) { // NOLINT
517  throw Error("Reached end of file");
518  }
519 }
520 
521 template <typename ThresholdType, typename LeafOutputType>
522 Tree<ThresholdType, LeafOutputType>::Tree(bool use_opt_field)
523  : use_opt_field_(use_opt_field)
524 {}
525 
526 template <typename ThresholdType, typename LeafOutputType>
527 inline Tree<ThresholdType, LeafOutputType>
528 Tree<ThresholdType, LeafOutputType>::Clone() const {
529  Tree<ThresholdType, LeafOutputType> tree;
530  tree.num_nodes = num_nodes;
531  tree.nodes_ = nodes_.Clone();
532  tree.leaf_vector_ = leaf_vector_.Clone();
533  tree.leaf_vector_begin_ = leaf_vector_begin_.Clone();
534  tree.leaf_vector_end_ = leaf_vector_end_.Clone();
535  tree.matching_categories_ = matching_categories_.Clone();
536  tree.matching_categories_offset_ = matching_categories_offset_.Clone();
537  return tree;
538 }
539 
540 template <typename ThresholdType, typename LeafOutputType>
541 inline const char*
542 Tree<ThresholdType, LeafOutputType>::GetFormatStringForNode() {
543  if (std::is_same<ThresholdType, float>::value) {
544  return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}";
545  } else {
546  return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}";
547  }
548 }
549 
/*!
 * \brief Visit every serialized field of this tree, in serialization order
 *
 * Both GetPyBuffer() and SerializeToFile() drive this template. The call order below
 * therefore defines the serialized layout and must stay in sync with DeserializeTemplate().
 * \param scalar_handler called with a pointer to each scalar field
 * \param primitive_array_handler called with a pointer to each ContiguousArray of a
 *        primitive element type
 * \param composite_array_handler called with a pointer to nodes_ and the format string
 *        describing one Node
 */
550 template <typename ThresholdType, typename LeafOutputType>
551 template <typename ScalarHandler, typename PrimitiveArrayHandler, typename CompositeArrayHandler>
552 inline void
553 Tree<ThresholdType, LeafOutputType>::SerializeTemplate(
554  ScalarHandler scalar_handler, PrimitiveArrayHandler primitive_array_handler,
555  CompositeArrayHandler composite_array_handler) {
556  scalar_handler(&num_nodes);
557  scalar_handler(&has_categorical_split_);
558  composite_array_handler(&nodes_, GetFormatStringForNode());
559  primitive_array_handler(&leaf_vector_);
560  primitive_array_handler(&leaf_vector_begin_);
561  primitive_array_handler(&leaf_vector_end_);
562  primitive_array_handler(&matching_categories_);
563  primitive_array_handler(&matching_categories_offset_);
564 
  // No optional fields are defined yet, so both counts are always written as 0
565  /* Extension slot 2: Per-tree optional fields -- to be added later */
566  num_opt_field_per_tree_ = 0;
567  scalar_handler(&num_opt_field_per_tree_);
568 
569  /* Extension slot 3: Per-node optional fields -- to be added later */
570  num_opt_field_per_node_ = 0;
571  scalar_handler(&num_opt_field_per_node_);
572 }
573 
/*!
 * \brief Read every serialized field of this tree, in the order SerializeTemplate() writes it
 * \param scalar_handler called with a pointer to each scalar field to fill in
 * \param array_handler called with a pointer to each ContiguousArray to fill in
 * \param skip_opt_field_handler called once per optional field, to skip over it
 * \throw Error if the number of nodes read disagrees with num_nodes
 *
 * When use_opt_field_ is false (legacy 2.4 input), the optional-field counts are not read
 * at all; both are set to 0.
 * NOTE(review): only nodes_ is checked against num_nodes. The per-node arrays
 * (leaf_vector_begin_/end_, matching_categories_offset_) are not size-checked here.
 */
574 template <typename ThresholdType, typename LeafOutputType>
575 template <typename ScalarHandler, typename ArrayHandler, typename SkipOptFieldHandlerFunc>
576 inline void
577 Tree<ThresholdType, LeafOutputType>::DeserializeTemplate(
578  ScalarHandler scalar_handler,
579  ArrayHandler array_handler,
580  SkipOptFieldHandlerFunc skip_opt_field_handler) {
581  scalar_handler(&num_nodes);
582  scalar_handler(&has_categorical_split_);
583  array_handler(&nodes_);
584  if (static_cast<std::size_t>(num_nodes) != nodes_.Size()) {
585  throw Error("Could not load the correct number of nodes");
586  }
587  array_handler(&leaf_vector_);
588  array_handler(&leaf_vector_begin_);
589  array_handler(&leaf_vector_end_);
590  array_handler(&matching_categories_);
591  array_handler(&matching_categories_offset_);
592 
593  if (use_opt_field_) {
594  /* Extension slot 2: Per-tree optional fields -- to be added later */
595  scalar_handler(&num_opt_field_per_tree_);
596  // Ignore extra fields; the input is likely from a later version of Treelite
597  for (int32_t i = 0; i < num_opt_field_per_tree_; ++i) {
598  skip_opt_field_handler();
599  }
600 
601  /* Extension slot 3: Per-node optional fields -- to be added later */
602  scalar_handler(&num_opt_field_per_node_);
  // One skip per optional field (not per node); each field presumably covers all nodes
603  // Ignore extra fields; the input is likely from a later version of Treelite
604  for (int32_t i = 0; i < num_opt_field_per_node_; ++i) {
605  skip_opt_field_handler();
606  }
607  } else {
608  // Legacy (version 2.4)
609  num_opt_field_per_tree_ = 0;
610  num_opt_field_per_node_ = 0;
611  }
612 }
613 
614 template <typename ThresholdType, typename LeafOutputType>
615 inline void
616 Tree<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
617  auto scalar_handler = [dest](auto* field) {
618  dest->push_back(GetPyBufferFromScalar(field));
619  };
620  auto primitive_array_handler = [dest](auto* field) {
621  dest->push_back(GetPyBufferFromArray(field));
622  };
623  auto composite_array_handler = [dest](auto* field, const char* format) {
624  dest->push_back(GetPyBufferFromArray(field, format));
625  };
626  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
627 }
628 
629 template <typename ThresholdType, typename LeafOutputType>
630 inline void
631 Tree<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {
632  auto scalar_handler = [dest_fp](auto* field) {
633  WriteScalarToFile(field, dest_fp);
634  };
635  auto primitive_array_handler = [dest_fp](auto* field) {
636  WriteArrayToFile(field, dest_fp);
637  };
638  auto composite_array_handler = [dest_fp](auto* field, const char* format) {
639  WriteArrayToFile(field, dest_fp);
640  };
641  SerializeTemplate(scalar_handler, primitive_array_handler, composite_array_handler);
642 }
643 
644 template <typename ThresholdType, typename LeafOutputType>
645 inline std::vector<PyBufferFrame>::iterator
646 Tree<ThresholdType, LeafOutputType>::InitFromPyBuffer(std::vector<PyBufferFrame>::iterator it) {
647  std::vector<PyBufferFrame>::iterator new_it = it;
648  auto scalar_handler = [&new_it](auto* field) {
649  InitScalarFromPyBuffer(field, *(new_it++));
650  };
651  auto array_handler = [&new_it](auto* field) {
652  InitArrayFromPyBuffer(field, *(new_it++));
653  };
654  auto skip_opt_field_handler = [&new_it]() {
655  ++new_it;
656  };
657  DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
658  return new_it;
659 }
660 
661 template <typename ThresholdType, typename LeafOutputType>
662 inline void
663 Tree<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {
664  auto scalar_handler = [src_fp](auto* field) {
665  ReadScalarFromFile(field, src_fp);
666  };
667  auto array_handler = [src_fp](auto* field) {
668  ReadArrayFromFile(field, src_fp);
669  };
670  auto skip_opt_field_handler = [src_fp]() {
671  SkipOptFieldInFile(src_fp);
672  };
673  DeserializeTemplate(scalar_handler, array_handler, skip_opt_field_handler);
674 }
675 
/*!
 * \brief Reset a node to a blank state
 *
 * After the call the node has:
 *   - no children (cleft_ = cright_ = -1);
 *   - split index 0;
 *   - zero leaf value / threshold;
 *   - zeroed statistics, all marked as absent;
 *   - no split type and no comparison operator.
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `inline void Tree<ThresholdType, LeafOutputType>::Node::Init() {` -- confirm.
 */
676 template <typename ThresholdType, typename LeafOutputType>
  // Zero every byte first, including padding (the node format strings contain "xx" padding).
  // Presumably this keeps the raw node bytes emitted by GetPyBuffer/SerializeToFile deterministic.
678  std::memset(this, 0, sizeof(Node));
679  cleft_ = cright_ = -1;
680  sindex_ = 0;
  // info_ holds either the leaf value or the threshold (presumably a union); zero both views
681  info_.leaf_value = static_cast<LeafOutputType>(0);
682  info_.threshold = static_cast<ThresholdType>(0);
683  data_count_ = 0;
684  sum_hess_ = gain_ = 0.0;
685  data_count_present_ = sum_hess_present_ = gain_present_ = false;
686  categories_list_right_child_ = false;
687  split_type_ = SplitFeatureType::kNone;
688  cmp_ = Operator::kNone;
689 }
690 
/*!
 * \brief Append one blank node and return its id (the value of num_nodes before the call)
 *
 * The per-node arrays are kept in step with nodes_:
 *   - leaf_vector_begin_ and leaf_vector_end_ get a 0 entry;
 *   - matching_categories_offset_ repeats its last offset, so the new node has an empty
 *     category list.
 * \throw Error if nodes_ does not already hold exactly num_nodes entries
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::AllocNode() {` -- confirm.
 */
691 template <typename ThresholdType, typename LeafOutputType>
692 inline int
694  int nd = num_nodes++;
695  if (nodes_.Size() != static_cast<std::size_t>(nd)) {
696  throw Error("Invariant violated: nodes_ contains incorrect number of nodes");
697  }
  // num_nodes grew by exactly one, so this loop body runs once.
  // Back() requires a non-empty offset array; Init() sizes it to 2.
698  for (int nid = nd; nid < num_nodes; ++nid) {
699  leaf_vector_begin_.PushBack(0);
700  leaf_vector_end_.PushBack(0);
701  matching_categories_offset_.PushBack(matching_categories_offset_.Back());
702  nodes_.Resize(nodes_.Size() + 1);
703  nodes_.Back().Init();
704  }
705  return nd;
706 }
707 
/*!
 * \brief Reset the tree to a single root node, which is a leaf with output 0
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::Init() {` -- confirm.
 * NOTE(review): Resize(n, value) only fills elements added beyond the old size. If this tree
 * already held nodes, the following keep stale values:
 *   - leaf_vector_begin_[0] and leaf_vector_end_[0];
 *   - matching_categories_offset_[0..1].
 * A stale offset would then trip SetCategoricalSplit's invariant checks. Presumably Init()
 * is only called on fresh trees -- confirm.
 */
708 template <typename ThresholdType, typename LeafOutputType>
709 inline void
711  num_nodes = 1;
712  has_categorical_split_ = false;
713  leaf_vector_.Clear();
714  leaf_vector_begin_.Resize(1, {});
715  leaf_vector_end_.Resize(1, {});
716  matching_categories_.Clear();
  // One offset more than the number of nodes: node i's categories are [offset[i], offset[i+1])
717  matching_categories_offset_.Resize(2, 0);
718  nodes_.Resize(1);
719  nodes_.at(0).Init();
720  SetLeaf(0, static_cast<LeafOutputType>(0));
721 }
722 
/*!
 * \brief Allocate two new nodes and link them as the left and right children of node nid
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::AddChilds(int nid) {` -- confirm.
 */
723 template <typename ThresholdType, typename LeafOutputType>
724 inline void
726  const int cleft = this->AllocNode();
727  const int cright = this->AllocNode();
  // Look nid up again only after both allocations: AllocNode() resizes nodes_, which may
  // reallocate it, so a Node& taken earlier could dangle.
728  nodes_.at(nid).cleft_ = cleft;
729  nodes_.at(nid).cright_ = cright;
730 }
731 
/*!
 * \brief Make node nid a numerical split: compare feature split_index against threshold
 *        using cmp
 *
 * The top bit of sindex_ stores default_left; the low 31 bits store split_index.
 * Child links are not touched (presumably AddChilds() is used to create the children).
 * \throw Error if split_index >= 2^31 - 1, or if nid is out of range
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::SetNumericalSplit(` -- confirm.
 */
732 template <typename ThresholdType, typename LeafOutputType>
733 inline void
735  int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) {
736  Node& node = nodes_.at(nid);
737  if (split_index >= ((1U << 31U) - 1)) {
738  throw Error("split_index too big");
739  }
740  if (default_left) split_index |= (1U << 31U);
741  node.sindex_ = split_index;
742  (node.info_).threshold = threshold;
743  node.cmp_ = cmp;
744  node.split_type_ = SplitFeatureType::kNumerical;
745  node.categories_list_right_child_ = false;
746 }
747 
/*!
 * \brief Make node nid a categorical split on feature split_index
 *
 * The categories of node nid are appended to matching_categories_ and sorted, occupying
 * [offset[nid], offset[nid + 1]). categories_list_right_child records which side the listed
 * categories route to.
 * \throw Error in any of these cases:
 *        - split_index >= 2^31 - 1;
 *        - the offsets after nid are not all equal to the current end (categories must be
 *          appended in node order);
 *        - an internal invariant fails.
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::SetCategoricalSplit(` -- confirm.
 * NOTE(review): if categories_list is empty while matching_categories_ is not, end_oft ==
 * Size(), so at(end_oft) below throws "nid out of range". The sort guard should likely test
 * !categories_list.empty() instead.
 */
748 template <typename ThresholdType, typename LeafOutputType>
749 inline void
751  int nid, unsigned split_index, bool default_left,
752  const std::vector<uint32_t>& categories_list, bool categories_list_right_child) {
753  if (split_index >= ((1U << 31U) - 1)) {
754  throw Error("split_index too big");
755  }
756 
757  const std::size_t end_oft = matching_categories_offset_.Back();
758  const std::size_t new_end_oft = end_oft + categories_list.size();
759  if (end_oft != matching_categories_.Size()) {
760  throw Error("Invariant violated");
761  }
  // Every node after nid must still have an empty category range ending at the current end
762  if (!std::all_of(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
763  [end_oft](std::size_t x) { return (x == end_oft); })) {
764  throw Error("Invariant violated");
765  }
766  // Hopefully we won't have to move any element as we add node_matching_categories for node nid
767  matching_categories_.Extend(categories_list);
768  if (new_end_oft != matching_categories_.Size()) {
769  throw Error("Invariant violated");
770  }
771  std::for_each(&matching_categories_offset_.at(nid + 1), matching_categories_offset_.End(),
772  [new_end_oft](std::size_t& x) { x = new_end_oft; });
773  if (!matching_categories_.Empty()) {
774  std::sort(&matching_categories_.at(end_oft), matching_categories_.End());
775  }
776 
777  Node& node = nodes_.at(nid);
778  if (default_left) split_index |= (1U << 31U);
779  node.sindex_ = split_index;
780  node.split_type_ = SplitFeatureType::kCategorical;
781  node.categories_list_right_child_ = categories_list_right_child;
782 
783  has_categorical_split_ = true;
784 }
785 
786 template <typename ThresholdType, typename LeafOutputType>
787 inline void
788 Tree<ThresholdType, LeafOutputType>::SetLeaf(int nid, LeafOutputType value) {
789  Node& node = nodes_.at(nid);
790  (node.info_).leaf_value = value;
791  node.cleft_ = -1;
792  node.cright_ = -1;
793  node.split_type_ = SplitFeatureType::kNone;
794 }
795 
/*!
 * \brief Turn node nid into a leaf whose output is the vector node_leaf_vector
 *
 * The values are appended to leaf_vector_, and [begin, end) of node nid is recorded in
 * leaf_vector_begin_/leaf_vector_end_.
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `Tree<ThresholdType, LeafOutputType>::SetLeafVector(` -- confirm.
 * NOTE(review): leaf_vector_begin_[nid] and leaf_vector_end_[nid] use the unchecked
 * operator[]; only the later nodes_.at(nid) validates nid. An invalid nid therefore writes
 * out of bounds before any Error is thrown.
 */
796 template <typename ThresholdType, typename LeafOutputType>
797 inline void
799  int nid, const std::vector<LeafOutputType>& node_leaf_vector) {
800  std::size_t begin = leaf_vector_.Size();
801  std::size_t end = begin + node_leaf_vector.size();
802  leaf_vector_.Extend(node_leaf_vector);
803  leaf_vector_begin_[nid] = begin;
804  leaf_vector_end_[nid] = end;
805  Node &node = nodes_.at(nid);
806  node.cleft_ = -1;
807  node.cright_ = -1;
808  node.split_type_ = SplitFeatureType::kNone;
809 }
810 
811 template <typename ThresholdType, typename LeafOutputType>
812 inline std::unique_ptr<Model>
813 Model::Create() {
814  std::unique_ptr<Model> model = std::make_unique<ModelImpl<ThresholdType, LeafOutputType>>();
815  model->threshold_type_ = TypeToInfo<ThresholdType>();
816  model->leaf_output_type_ = TypeToInfo<LeafOutputType>();
817  return model;
818 }
819 
/*!
 * \brief Dispatch target used by Model::Create(TypeInfo, TypeInfo)
 *
 * DispatchWithModelTypes selects the ThresholdType/LeafOutputType instantiation from runtime
 * type tags. Dispatch() then forwards to Model::Create<ThresholdType, LeafOutputType>().
 * NOTE(review): the class-head line is missing from this listing. Presumably it is
 * `class ModelCreateImpl {` -- confirm.
 */
820 template <typename ThresholdType, typename LeafOutputType>
822  public:
823  inline static std::unique_ptr<Model> Dispatch() {
824  return Model::Create<ThresholdType, LeafOutputType>();
825  }
826 };
827 
828 inline std::unique_ptr<Model>
829 Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) {
830  return DispatchWithModelTypes<ModelCreateImpl>(threshold_type, leaf_output_type);
831 }
832 
/*!
 * \brief Dispatch target used by Model::Dispatch()
 *
 * Downcasts the model to ModelImpl<ThresholdType, LeafOutputType> and passes it to func by
 * reference; the const overload passes a const reference.
 * NOTE(review): the class-head line is missing from this listing. Presumably it is
 * `class ModelDispatchImpl {` -- confirm.
 * NOTE(review): dynamic_cast to a pointer yields nullptr when the dynamic type does not match,
 * and the result is dereferenced unconditionally. Correctness therefore relies on
 * threshold_type_/leaf_output_type_ matching the real type (Model::Create<> sets them).
 * Casting to a reference would throw std::bad_cast instead.
 */
833 template <typename ThresholdType, typename LeafOutputType>
835  public:
836  template <typename Func>
837  inline static auto Dispatch(Model* model, Func func) {
838  return func(*dynamic_cast<ModelImpl<ThresholdType, LeafOutputType>*>(model));
839  }
840 
841  template <typename Func>
842  inline static auto Dispatch(const Model* model, Func func) {
843  return func(*dynamic_cast<const ModelImpl<ThresholdType, LeafOutputType>*>(model));
844  }
845 };
846 
847 template <typename Func>
848 inline auto
849 Model::Dispatch(Func func) {
850  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
851 }
852 
853 template <typename Func>
854 inline auto
855 Model::Dispatch(Func func) const {
856  return DispatchWithModelTypes<ModelDispatchImpl>(threshold_type_, leaf_output_type_, this, func);
857 }
858 
/*!
 * \brief Visit the model-level header fields in serialization order
 *
 * The fields are the version triple, then the threshold and leaf-output type tags. The
 * version is first overwritten with the running Treelite version, so the output is always
 * stamped with the writer's version. The order must match Model::DeserializeTemplate().
 * \param header_primitive_field_handler called with a pointer to each field
 */
859 template <typename HeaderPrimitiveFieldHandlerFunc>
860 inline void
861 Model::SerializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler) {
862  major_ver_ = TREELITE_VER_MAJOR;
863  minor_ver_ = TREELITE_VER_MINOR;
864  patch_ver_ = TREELITE_VER_PATCH;
865  header_primitive_field_handler(&major_ver_);
866  header_primitive_field_handler(&minor_ver_);
867  header_primitive_field_handler(&patch_ver_);
868  header_primitive_field_handler(&threshold_type_);
869  header_primitive_field_handler(&leaf_output_type_);
870 }
871 
/*!
 * \brief Read the model-level header (version triple, then type tags) into the out-parameters
 *
 * Accepted inputs: the same major version as the running Treelite, or exactly version 2.4.x
 * (legacy format).
 * \param header_primitive_field_handler called with a pointer to each field to fill in
 * \param major_ver out: major version of the writer (likewise minor_ver, patch_ver)
 * \param threshold_type out: threshold type tag
 * \param leaf_output_type out: leaf-output type tag
 * \throw Error for any other version. In that case the type tags are not read.
 */
872 template <typename HeaderPrimitiveFieldHandlerFunc>
873 inline void
874 Model::DeserializeTemplate(HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
875  int32_t& major_ver, int32_t& minor_ver, int32_t& patch_ver,
876  TypeInfo& threshold_type, TypeInfo& leaf_output_type) {
877  header_primitive_field_handler(&major_ver);
878  header_primitive_field_handler(&minor_ver);
879  header_primitive_field_handler(&patch_ver);
880  if (major_ver != TREELITE_VER_MAJOR && !(major_ver == 2 && minor_ver == 4)) {
881  std::ostringstream oss;
882  oss << "Cannot deserialize model from a different major Treelite version or "
883  << "a version before 2.4.0." << std::endl
884  << "Currently running Treelite version " << TREELITE_VER_MAJOR << "."
885  << TREELITE_VER_MINOR << "." << TREELITE_VER_PATCH << std::endl
886  << "The model checkpoint was generated from Treelite version " << major_ver << "."
887  << minor_ver << "." << patch_ver;
888  throw Error(oss.str());
889  }
890  header_primitive_field_handler(&threshold_type);
891  header_primitive_field_handler(&leaf_output_type);
892 }
893 
/*!
 * \brief Visit the ModelImpl header fields and then every tree, in serialization order
 *
 * Neither num_tree_ nor the Model-level header is emitted here. GetPyBuffer() and
 * SerializeToFile() emit num_tree_ themselves before calling this.
 * \param header_primitive_field_handler called with a pointer to each scalar header field
 * \param header_composite_field_handler called with a pointer to task_param / param plus a
 *        format string describing the struct. param's format is a char array of length
 *        TREELITE_MAX_PRED_TRANSFORM_LENGTH followed by three floats (presumably
 *        sigmoid_alpha, ratio_c, global_bias in member order -- confirm).
 * \param tree_handler called once per tree, in order
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `ModelImpl<ThresholdType, LeafOutputType>::SerializeTemplate(` -- confirm.
 */
894 template <typename ThresholdType, typename LeafOutputType>
895 template <typename HeaderPrimitiveFieldHandlerFunc, typename HeaderCompositeFieldHandlerFunc,
896  typename TreeHandlerFunc>
897 inline void
899  HeaderPrimitiveFieldHandlerFunc header_primitive_field_handler,
900  HeaderCompositeFieldHandlerFunc header_composite_field_handler,
901  TreeHandlerFunc tree_handler) {
902  /* Header */
903  header_primitive_field_handler(&num_feature);
904  header_primitive_field_handler(&task_type);
905  header_primitive_field_handler(&average_tree_output);
906  header_composite_field_handler(&task_param, "T{=B=?xx=I=I}");
907  header_composite_field_handler(
908  &param, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f=f}");
909 
  // No per-model optional fields are defined yet, so the count is always written as 0
910  /* Extension Slot 1: Per-model optional fields -- to be added later */
911  num_opt_field_per_model_ = 0;
912  header_primitive_field_handler(&num_opt_field_per_model_);
913 
914  /* Body */
915  for (Tree<ThresholdType, LeafOutputType>& tree : trees) {
916  tree_handler(tree);
917  }
918 }
919 
/*!
 * \brief Read the ModelImpl header fields, then replace `trees` with num_tree freshly loaded trees
 *
 * major_ver_ must already hold the version of the input. For major_ver_ >= 3, the optional-field
 * count is read and any listed fields are skipped. Older (2.4) input has no such section, and
 * the trees are created with use_opt_field = false.
 * \param num_tree number of trees to load
 * \param header_field_handler called with a pointer to each header field to fill in
 * \param tree_handler called once per newly appended tree, to load it
 * \param skip_opt_field_handler called once per optional field, to skip it
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `ModelImpl<ThresholdType, LeafOutputType>::DeserializeTemplate(` -- confirm.
 */
920 template <typename ThresholdType, typename LeafOutputType>
921 template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc,
922  typename SkipOptFieldHandlerFunc>
923 inline void
925  std::size_t num_tree,
926  HeaderFieldHandlerFunc header_field_handler,
927  TreeHandlerFunc tree_handler,
928  SkipOptFieldHandlerFunc skip_opt_field_handler) {
929  /* Header */
930  header_field_handler(&num_feature);
931  header_field_handler(&task_type);
932  header_field_handler(&average_tree_output);
933  header_field_handler(&task_param);
934  header_field_handler(&param);
935 
936  /* Extension Slot 1: Per-model optional fields -- to be added later */
937  const bool use_opt_field = (major_ver_ >= 3);
938  if (use_opt_field) {
939  header_field_handler(&num_opt_field_per_model_);
940  // Ignore extra fields; the input is likely from a later version of Treelite
941  for (int32_t i = 0; i < num_opt_field_per_model_; ++i) {
942  skip_opt_field_handler();
943  }
944  } else {
945  // Legacy (version 2.4)
946  num_opt_field_per_model_ = 0;
947  }
948 
  // Each tree inherits the model's layout flag, so it knows whether its own
  // optional-field sections are present
949  /* Body */
950  trees.clear();
951  for (std::size_t i = 0; i < num_tree; ++i) {
952  trees.emplace_back(use_opt_field);
953  tree_handler(trees.back());
954  }
955 }
956 
957 template <typename ThresholdType, typename LeafOutputType>
958 inline void
959 ModelImpl<ThresholdType, LeafOutputType>::GetPyBuffer(std::vector<PyBufferFrame>* dest) {
960  num_tree_ = static_cast<uint64_t>(this->trees.size());
961  auto header_primitive_field_handler = [dest](auto* field) {
962  dest->push_back(GetPyBufferFromScalar(field));
963  };
964  auto header_composite_field_handler = [dest](auto* field, const char* format) {
965  dest->push_back(GetPyBufferFromScalar(field, format));
966  };
967  auto tree_handler = [dest](Tree<ThresholdType, LeafOutputType>& tree) {
968  tree.GetPyBuffer(dest);
969  };
970  header_primitive_field_handler(&num_tree_);
971  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
972 }
973 
/*!
 * \brief Write num_tree_, the ModelImpl header and every tree to dest_fp as raw bytes
 *
 * Composite header fields (task_param, param) are written as raw struct bytes; their format
 * strings are not needed in the file format.
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `ModelImpl<ThresholdType, LeafOutputType>::SerializeToFile(FILE* dest_fp) {` -- confirm.
 */
974 template <typename ThresholdType, typename LeafOutputType>
975 inline void
977  num_tree_ = static_cast<uint64_t>(this->trees.size());
978  auto header_primitive_field_handler = [dest_fp](auto* field) {
979  WriteScalarToFile(field, dest_fp);
980  };
  // `format` is unused here (see above); it exists only to match the handler signature
981  auto header_composite_field_handler = [dest_fp](auto* field, const char* format) {
982  WriteScalarToFile(field, dest_fp);
983  };
984  auto tree_handler = [dest_fp](Tree<ThresholdType, LeafOutputType>& tree) {
985  tree.SerializeToFile(dest_fp);
986  };
  // num_tree_ goes first; DeserializeFromFile reads it before anything else
987  header_primitive_field_handler(&num_tree_);
988  SerializeTemplate(header_primitive_field_handler, header_composite_field_handler, tree_handler);
989 }
990 
/*!
 * \brief Load the ModelImpl header and all trees from PyBufferFrames, starting at it
 *
 * For legacy 2.4 input (major_ver_ == 2), the tree count is not stored and is inferred from
 * num_frame. Otherwise it is read from the first frame.
 * \param it first frame to consume
 * \param num_frame total number of frames received (only used for legacy input)
 * \return iterator to the first frame not consumed
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `ModelImpl<ThresholdType, LeafOutputType>::InitFromPyBuffer(` -- confirm.
 */
991 template <typename ThresholdType, typename LeafOutputType>
992 inline std::vector<PyBufferFrame>::iterator
994  std::vector<PyBufferFrame>::iterator it, std::size_t num_frame) {
995  std::vector<PyBufferFrame>::iterator new_it = it;
996  auto header_field_handler = [&new_it](auto* field) {
997  InitScalarFromPyBuffer(field, *(new_it++));
998  };
999 
1000  auto skip_opt_field_handler = [&new_it]() {
1001  ++new_it;
1002  };
1003 
1004  auto tree_handler = [&new_it](Tree<ThresholdType, LeafOutputType>& tree) {
1005  new_it = tree.InitFromPyBuffer(new_it);
1006  };
1007 
1008  if (major_ver_ == 2) {
  // 5 header frames = the five header fields read by DeserializeTemplate (num_feature,
  // task_type, average_tree_output, task_param, param).
  // 8 frames per tree = num_nodes, has_categorical_split_, nodes_ and five arrays.
  // NOTE(review): if num_frame < 5 the unsigned subtraction wraps and yields a huge tree count;
  // consider validating num_frame first.
1009  // From version 2.4 (legacy)
1010  // num_tree has to be inferred from the number of frames received
1011  constexpr std::size_t kNumFrameInHeader = 5;
1012  constexpr std::size_t kNumFramePerTree = 8;
1013  num_tree_ = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
1014  } else {
1015  // From version 3.x
1016  // num_tree is now explicitly stored
1017  header_field_handler(&num_tree_);
1018  }
1019 
1020  DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1021  TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1022 
1023  return new_it;
1024 }
1025 
/*!
 * \brief Load num_tree_, the ModelImpl header and all trees from src_fp
 *
 * Unlike InitFromPyBuffer(), this always reads an explicit tree count and has no
 * major_ver_ == 2 branch.
 * NOTE(review): confirm that legacy 2.4 files also store this count.
 * NOTE(review): the signature line is missing from this listing. Presumably it is
 * `ModelImpl<ThresholdType, LeafOutputType>::DeserializeFromFile(FILE* src_fp) {` -- confirm.
 */
1026 template <typename ThresholdType, typename LeafOutputType>
1027 inline void
1029  ReadScalarFromFile(&num_tree_, src_fp);
1030 
1031  auto header_field_handler = [src_fp](auto* field) {
1032  ReadScalarFromFile(field, src_fp);
1033  };
1034 
1035  auto skip_opt_field_handler = [src_fp]() {
1036  SkipOptFieldInFile(src_fp);
1037  };
1038 
1039  auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
1040  tree.DeserializeFromFile(src_fp);
1041  };
1042 
1043  DeserializeTemplate(num_tree_, header_field_handler, tree_handler, skip_opt_field_handler);
1044  TREELITE_CHECK_EQ(num_tree_, this->trees.size());
1045 }
1046 
1047 inline void InitParamAndCheck(ModelParam* param,
1048  const std::vector<std::pair<std::string, std::string>>& cfg) {
1049  auto unknown = param->InitAllowUnknown(cfg);
1050  if (!unknown.empty()) {
1051  std::ostringstream oss;
1052  for (const auto& kv : unknown) {
1053  oss << kv.first << ", ";
1054  }
1055  std::cerr << "\033[1;31mWarning: Unknown parameters found; "
1056  << "they have been ignored\u001B[0m: " << oss.str() << std::endl;
1057  }
1058 }
1059 
1060 } // namespace treelite
1061 #endif // TREELITE_TREE_IMPL_H_
void Init()
initialize the model with a single root node
Definition: tree_impl.h:710
Exception class that will be thrown by Treelite.
Definition: error.h:18
TaskType
Enum type representing the task type.
Definition: tree.h:107
tree node
Definition: tree.h:225
in-memory representation of a decision tree
Definition: tree.h:222
logging facility for Treelite
Exception class used throughout the Treelite codebase.
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:23
thin wrapper for tree ensemble model
Definition: tree.h:734
Operator
comparison operators
Definition: base.h:27