treelite
serializer.h
Go to the documentation of this file.
1 
8 #ifndef TREELITE_DETAIL_SERIALIZER_H_
9 #define TREELITE_DETAIL_SERIALIZER_H_
10 
12 #include <treelite/enum/operator.h>
15 #include <treelite/enum/typeinfo.h>
16 #include <treelite/logging.h>
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <istream>
22 #include <limits>
23 #include <ostream>
24 #include <string>
25 #include <type_traits>
26 #include <utility>
27 
29 
31  void* data, char const* format, std::size_t itemsize, std::size_t nitem) {
32  return PyBufferFrame{data, const_cast<char*>(format), itemsize, nitem};
33 }
34 
35 // Infer format string from data type
36 template <typename T>
37 inline char const* InferFormatString() {
38  switch (sizeof(T)) {
39  case 1:
40  return (std::is_unsigned_v<T> ? "=B" : "=b");
41  case 2:
42  return (std::is_unsigned_v<T> ? "=H" : "=h");
43  case 4:
44  if (std::is_integral_v<T>) {
45  return (std::is_unsigned_v<T> ? "=L" : "=l");
46  } else {
47  TREELITE_CHECK(std::is_floating_point_v<T>) << "Could not infer format string";
48  return "=f";
49  }
50  case 8:
51  if (std::is_integral_v<T>) {
52  return (std::is_unsigned_v<T> ? "=Q" : "=q");
53  } else {
54  TREELITE_CHECK(std::is_floating_point_v<T>) << "Could not infer format string";
55  return "=d";
56  }
57  default:
58  TREELITE_LOG(FATAL) << "Unrecognized type";
59  }
60  return nullptr;
61 }
62 
63 template <typename T>
65  return GetPyBufferFromArray(static_cast<void*>(vec->Data()), format, sizeof(T), vec->Size());
66 }
67 
68 template <typename T>
70  static_assert(std::is_arithmetic_v<T> || std::is_enum_v<T>,
71  "Use GetPyBufferFromArray(vec, format) for composite types; specify format string manually");
72  return GetPyBufferFromArray(vec, InferFormatString<T>());
73 }
74 
75 inline PyBufferFrame GetPyBufferFromScalar(void* data, char const* format, std::size_t itemsize) {
76  return GetPyBufferFromArray(data, format, itemsize, 1);
77 }
78 
79 inline PyBufferFrame GetPyBufferFromString(std::string* str) {
80  return GetPyBufferFromArray(str->data(), "=c", 1, str->length());
81 }
82 
83 template <typename T>
84 inline PyBufferFrame GetPyBufferFromScalar(T* scalar, char const* format) {
85  static_assert(std::is_standard_layout_v<T>, "T must be in the standard layout");
86  return GetPyBufferFromScalar(static_cast<void*>(scalar), format, sizeof(T));
87 }
88 
90  using T = std::underlying_type_t<TypeInfo>;
91  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
92 }
93 
95  using T = std::underlying_type_t<TaskType>;
96  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
97 }
98 
100  using T = std::underlying_type_t<TreeNodeType>;
101  return GetPyBufferFromScalar(reinterpret_cast<T*>(scalar), InferFormatString<T>());
102 }
103 
104 template <typename T>
106  static_assert(std::is_arithmetic_v<T> || std::is_enum_v<T>,
107  "Use GetPyBufferFromScalar(scalar, format) for composite types; "
108  "specify format string manually");
109  return GetPyBufferFromScalar(scalar, InferFormatString<T>());
110 }
111 
112 template <typename T>
114  TREELITE_CHECK_EQ(sizeof(T), frame.itemsize) << "Incorrect itemsize";
115  vec->UseForeignBuffer(frame.buf, frame.nitem);
116 }
117 
118 template <typename T>
120  TREELITE_CHECK_EQ(sizeof(T), frame.itemsize) << "Incorrect itemsize";
121  ContiguousArray<T> new_vec;
122  new_vec.UseForeignBuffer(frame.buf, frame.nitem);
123  *vec = std::move(new_vec);
124 }
125 
126 inline void InitStringFromPyBuffer(std::string* str, PyBufferFrame frame) {
127  TREELITE_CHECK_EQ(sizeof(char), frame.itemsize) << "Incorrect itemsize";
128  *str = std::string(static_cast<char*>(frame.buf), frame.nitem);
129 }
130 
131 inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame frame) {
132  using T = std::underlying_type_t<TypeInfo>;
133  TREELITE_CHECK_EQ(sizeof(T), frame.itemsize) << "Incorrect itemsize";
134  TREELITE_CHECK_EQ(frame.nitem, 1) << "nitem must be 1 for a scalar";
135  T* t = static_cast<T*>(frame.buf);
136  *scalar = static_cast<TypeInfo>(*t);
137 }
138 
139 inline void InitScalarFromPyBuffer(TaskType* scalar, PyBufferFrame frame) {
140  using T = std::underlying_type_t<TaskType>;
141  TREELITE_CHECK_EQ(sizeof(T), frame.itemsize) << "Incorrect itemsize";
142  TREELITE_CHECK_EQ(frame.nitem, 1) << "nitem must be 1 for a scalar";
143  T* t = static_cast<T*>(frame.buf);
144  *scalar = static_cast<TaskType>(*t);
145 }
146 
147 template <typename T>
148 inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame frame) {
149  static_assert(std::is_standard_layout_v<T>, "T must be in the standard layout");
150  TREELITE_CHECK_EQ(sizeof(T), frame.itemsize) << "Incorrect itemsize";
151  TREELITE_CHECK_EQ(frame.nitem, 1) << "nitem must be 1 for a scalar";
152  T* t = static_cast<T*>(frame.buf);
153  *scalar = *t;
154 }
155 
// Fill `*scalar` with the next sizeof(T) raw bytes from `is`.
// T must be standard-layout. The stream state is not inspected here, so
// callers that care about short reads must check `is` themselves.
template <typename T>
inline void ReadScalarFromStream(T* scalar, std::istream& is) {
  static_assert(std::is_standard_layout_v<T>, "T must be in the standard layout");
  char* const dest = reinterpret_cast<char*>(scalar);
  constexpr std::streamsize kByteCount = sizeof(T);
  is.read(dest, kByteCount);
}
161 
// Emit the sizeof(T) raw bytes of `*scalar` to `os`.
// T must be standard-layout. The bytes are written exactly as they sit in
// memory; no byte-order conversion is performed.
template <typename T>
inline void WriteScalarToStream(T* scalar, std::ostream& os) {
  static_assert(std::is_standard_layout_v<T>, "T must be in the standard layout");
  char const* const src = reinterpret_cast<char const*>(scalar);
  constexpr std::streamsize kByteCount = sizeof(T);
  os.write(src, kByteCount);
}
167 
168 template <typename T>
169 inline void ReadArrayFromStream(ContiguousArray<T>* vec, std::istream& is) {
170  std::uint64_t nelem;
171  is.read(reinterpret_cast<char*>(&nelem), sizeof(nelem));
172  vec->Clear();
173  vec->Resize(nelem);
174  if (nelem == 0) {
175  return; // handle empty arrays
176  }
177  is.read(reinterpret_cast<char*>(vec->Data()), sizeof(T) * nelem);
178 }
179 
180 template <typename T>
181 inline void WriteArrayToStream(ContiguousArray<T>* vec, std::ostream& os) {
182  static_assert(sizeof(std::uint64_t) >= sizeof(std::size_t), "size_t too large");
183  auto const nelem = static_cast<std::uint64_t>(vec->Size());
184  os.write(reinterpret_cast<char const*>(&nelem), sizeof(nelem));
185  if (nelem == 0) {
186  return; // handle empty arrays
187  }
188  os.write(reinterpret_cast<char const*>(vec->Data()), sizeof(T) * vec->Size());
189 }
190 
// Read a length-prefixed string from `is` into `*str`.
// Layout: a 64-bit unsigned length, followed by that many characters.
// `*str` is always overwritten: a zero length (or a length header that could
// not be read) leaves it empty, so stale contents from a previous use of the
// same string object never survive the call.
inline void ReadStringFromStream(std::string* str, std::istream& is) {
  std::uint64_t str_len = 0;
  bool const header_ok
      = static_cast<bool>(is.read(reinterpret_cast<char*>(&str_len), sizeof(str_len)));
  if (!header_ok || str_len == 0) {
    str->clear();  // handle empty string and unreadable header
    return;
  }
  str->assign(str_len, '\0');
  is.read(str->data(), sizeof(char) * str_len);
}
200 
// Write `*str` to `os` as a length-prefixed string: first the length as a
// 64-bit unsigned integer, then the characters. An empty string produces only
// the length.
inline void WriteStringToStream(std::string* str, std::ostream& os) {
  static_assert(sizeof(std::uint64_t) >= sizeof(std::size_t), "size_t too large");
  std::size_t const nchars = str->length();
  std::uint64_t const header = static_cast<std::uint64_t>(nchars);
  os.write(reinterpret_cast<char const*>(&header), sizeof(header));
  if (header != 0) {
    os.write(str->data(), sizeof(char) * nchars);
  }
}
210 
211 inline void SkipOptionalFieldInStream(std::istream& is) {
212  std::string field_name;
213  ReadStringFromStream(&field_name, is);
214 
215  std::uint64_t elem_size, nelem;
216  ReadScalarFromStream(&elem_size, is);
217  ReadScalarFromStream(&nelem, is);
218 
219  std::uint64_t const nbytes = elem_size * nelem;
220  TREELITE_CHECK_LE(nbytes, std::numeric_limits<std::streamoff>::max()); // NOLINT
221  is.seekg(static_cast<std::streamoff>(nbytes), std::ios::cur);
222 }
223 
224 } // namespace treelite::detail::serializer
225 
226 #endif // TREELITE_DETAIL_SERIALIZER_H_
Definition: contiguous_array.h:17
T * Data()
Definition: contiguous_array.h:108
std::size_t Size() const
Definition: contiguous_array.h:138
void Resize(std::size_t newsize)
Definition: contiguous_array.h:157
void UseForeignBuffer(void *prealloc_buf, std::size_t size)
Definition: contiguous_array.h:97
void Clear()
Definition: contiguous_array.h:182
A simple array container, with owned or non-owned (externally allocated) buffer.
size_t nitem
Definition: c_api.h:57
char * format
Definition: c_api.h:55
size_t itemsize
Definition: c_api.h:56
logging facility for Treelite
#define TREELITE_CHECK_LE(x, y)
Definition: logging.h:75
#define TREELITE_LOG(severity)
Definition: logging.h:84
#define TREELITE_CHECK(x)
Definition: logging.h:70
#define TREELITE_CHECK_EQ(x, y)
Definition: logging.h:77
Definition: serializer.h:28
void InitArrayFromPyBuffer(ContiguousArray< T > *vec, PyBufferFrame frame)
Definition: serializer.h:113
void InitScalarFromPyBuffer(TypeInfo *scalar, PyBufferFrame frame)
Definition: serializer.h:131
PyBufferFrame GetPyBufferFromScalar(void *data, char const *format, std::size_t itemsize)
Definition: serializer.h:75
void WriteStringToStream(std::string *str, std::ostream &os)
Definition: serializer.h:201
void WriteScalarToStream(T *scalar, std::ostream &os)
Definition: serializer.h:163
PyBufferFrame GetPyBufferFromArray(void *data, char const *format, std::size_t itemsize, std::size_t nitem)
Definition: serializer.h:30
void InitArrayFromPyBufferWithCopy(ContiguousArray< T > *vec, PyBufferFrame frame)
Definition: serializer.h:119
void ReadScalarFromStream(T *scalar, std::istream &is)
Definition: serializer.h:157
PyBufferFrame GetPyBufferFromString(std::string *str)
Definition: serializer.h:79
void InitStringFromPyBuffer(std::string *str, PyBufferFrame frame)
Definition: serializer.h:126
void ReadStringFromStream(std::string *str, std::istream &is)
Definition: serializer.h:191
void WriteArrayToStream(ContiguousArray< T > *vec, std::ostream &os)
Definition: serializer.h:181
void ReadArrayFromStream(ContiguousArray< T > *vec, std::istream &is)
Definition: serializer.h:169
char const * InferFormatString()
Definition: serializer.h:37
void SkipOptionalFieldInStream(std::istream &is)
Definition: serializer.h:211
TreelitePyBufferFrame PyBufferFrame
Definition: pybuffer_frame.h:18
TypeInfo
Types used by thresholds and leaf outputs.
Definition: typeinfo.h:21
TaskType
Enum type representing the task type.
Definition: task_type.h:19
TreeNodeType
Tree node type.
Definition: tree_node_type.h:17
Define enum type Operator.
Data structure to enable zero-copy exchange in Python.
Define enum type TaskType.
Define enum type TreeNodeType.
Defines enum type TypeInfo.