treelite
common.h
1 
7 #ifndef TREELITE_COMMON_H_
8 #define TREELITE_COMMON_H_
9 
10 #include <algorithm>
11 #include <vector>
12 #include <limits>
13 #include <string>
14 #include <memory>
15 #include <string>
16 #include <sstream>
17 #include <iterator>
18 #include <fstream>
19 #include <functional>
20 #include <limits>
21 #include <utility>
22 #include <stdexcept>
23 #include <iomanip>
24 #include <cerrno>
25 #include <climits>
26 #include <treelite/base.h>
27 #include <dmlc/logging.h>
28 #include <dmlc/json.h>
29 #include <dmlc/data.h>
30 #include <fmt/format.h>
31 
32 namespace treelite {
33 namespace common {
34 
/*!
 * \brief abstract interface for classes that can be cloned
 *
 * Implementations hand back a pointer to a newly created object; the caller
 * is expected to take ownership of it (DeepCopyUniquePtr stores the result in
 * a std::unique_ptr).
 */
class Cloneable {
 public:
  /*! \brief virtual destructor, so clones can be destroyed via Cloneable* */
  virtual ~Cloneable() = default;

  /*! \brief create a copy of this object; used for copy operations */
  virtual Cloneable* clone() const = 0;

  /*! \brief create a new object from this one, which may be left in a
   *         moved-from state; used for move operations */
  virtual Cloneable* move_clone() = 0;
};
42 
/*!
 * \brief boilerplate for a class implementing Cloneable. Place inside the
 *        class body. Expands to defaulted (explicit) copy and move
 *        constructors, plus clone() / move_clone() overrides that construct
 *        a new object from *this with those constructors.
 * \param className name of the class in which the macro is expanded
 */
#define CLONEABLE_BOILERPLATE(className)                  \
  explicit className(const className& other) = default;  \
  explicit className(className&& other) = default;       \
  Cloneable* clone() const override {                     \
    return new className(*this);                          \
  }                                                       \
  Cloneable* move_clone() override {                      \
    return new className(std::move(*this));               \
  }
53 
58 template <typename T>
60  public:
61  static_assert(std::is_base_of<Cloneable, T>::value,
62  "DeepCopyUniquePtr requires a Cloneable type");
63  ~DeepCopyUniquePtr() {}
64 
65  explicit DeepCopyUniquePtr(const T& other)
66  : ptr(dynamic_cast<T*>(other.clone())) {}
67  // downcasting is okay here because the other object is certainly of type T
68  explicit DeepCopyUniquePtr(T&& other)
69  : ptr(dynamic_cast<T*>(other.move_clone())) {}
70  explicit DeepCopyUniquePtr(const DeepCopyUniquePtr<T>& other)
71  : ptr(dynamic_cast<T*>(other.ptr->clone())) {}
72  explicit DeepCopyUniquePtr(DeepCopyUniquePtr<T>&& other)
73  : ptr(std::move(other.ptr)) {}
74 
75  inline T& operator*() {
76  return *ptr;
77  }
78  const inline T& operator*() const {
79  return *ptr;
80  }
81  T* operator->() {
82  return ptr.operator->();
83  }
84  const T* operator->() const {
85  return ptr.operator->();
86  }
87 
88  private:
89  std::unique_ptr<T> ptr;
90 };
91 
/*!
 * \brief construct a new object of type T and wrap it in a std::unique_ptr
 * \tparam T type of the object to construct
 * \tparam Args types of the constructor arguments
 * \param args arguments perfectly forwarded to the constructor of T
 * \return unique pointer owning the newly constructed object
 */
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args) {
  std::unique_ptr<T> owner(new T(std::forward<Args>(args)...));
  return owner;
}
105 
/*!
 * \brief obtain an rvalue reference to the object held by a unique pointer,
 *        so that the object can be moved from. The pointer itself is not
 *        modified and keeps owning the object.
 * \tparam T type of the object held by the pointer
 * \param ptr unique pointer; it is dereferenced, so it must not be null
 * \return rvalue reference to the pointed-to object
 */
template <typename T>
inline T&& MoveUniquePtr(const std::unique_ptr<T>& ptr) {
  T& held = *ptr;
  return std::move(held);
}
123 
/*!
 * \brief format array as text, wrapped to a given maximum text width. Uses
 *        high precision to render floating-point values.
 */
class ArrayFormatter {
 public:
  /*!
   * \brief constructor
   * \param text_width maximum length of each line
   * \param indent number of spaces placed at the beginning of each line
   * \param delimiter character written after each entry (defaults to comma)
   */
  ArrayFormatter(size_t text_width, size_t indent, char delimiter = ',')
    // initializers are listed in member declaration order, which is the
    // order in which they actually run
    : oss_(), indent_(indent), text_width_(text_width), delimiter_(delimiter),
      default_precision_(static_cast<int>(oss_.precision())),
      line_length_(indent), is_empty_(true) {}

  /*!
   * \brief add an entry (will use high precision for floating-point values)
   * \param e entry to append; rendered as "<e><delimiter> "
   * \return reference to this formatter, so that calls can be chained
   */
  template <typename T>
  inline ArrayFormatter& operator<<(const T& e) {
    if (is_empty_) {
      // very first entry: open the first line with the indent
      is_empty_ = false;
      oss_ << std::string(indent_, ' ');
    }
    std::ostringstream tmp;
    tmp << std::setprecision(GetPrecision<T>()) << e << delimiter_ << " ";
    const std::string token = tmp.str();  // token to be added to wrapped text
    if (line_length_ + token.length() <= text_width_) {
      // token still fits on the current line
      oss_ << token;
      line_length_ += token.length();
    } else {
      // token would exceed the text width: start a new, indented line
      oss_ << "\n" << std::string(indent_, ' ') << token;
      line_length_ = token.length() + indent_;
    }
    return *this;
  }

  /*!
   * \brief obtain formatted text containing the rendered array
   * \return text accumulated so far (empty if no entry has been added)
   */
  inline std::string str() const {
    return oss_.str();
  }

 private:
  std::ostringstream oss_;  // string stream to store wrapped text
  const size_t indent_;  // indent level, to indent each line
  const size_t text_width_;  // maximum length of each line
  const char delimiter_;  // delimiter (defaults to comma)
  const int default_precision_;  // default precision used by string stream
  size_t line_length_;  // width of current line
  bool is_empty_;  // true if no entry has been added yet

  /*! \brief stream precision used to render an entry of type T; the stream
   *         default unless a specialization below says otherwise */
  template <typename T>
  inline int GetPrecision() {
    return default_precision_;
  }
};

/*! \brief float entries are rendered with digits10 + 2 significant digits */
template <>
inline int ArrayFormatter::GetPrecision<float>() {
  return std::numeric_limits<float>::digits10 + 2;
}
/*! \brief double entries are rendered with digits10 + 2 significant digits */
template <>
inline int ArrayFormatter::GetPrecision<double>() {
  return std::numeric_limits<double>::digits10 + 2;
}
193 
/*!
 * \brief apply indentation to a multi-line string by inserting spaces at the
 *        beginning of each non-empty line (UNIX-style line endings assumed)
 * \param str multi-line string
 * \param indent number of spaces to insert in front of each line
 * \return indented string
 */
inline std::string IndentMultiLineString(const std::string& str,
                                         size_t indent = 2) {
  const std::string padding(indent, ' ');
  std::string indented;
  // The text opens with an indent unless its first character is a newline;
  // an empty input therefore yields just the indent.
  if (str.empty() || str.front() != '\n') {
    indented += padding;
  }
  // A run of one or more newlines causes a single indent to be inserted right
  // before the next non-newline character, so blank lines stay blank.
  bool after_newline = false;
  for (const char ch : str) {
    const bool is_newline = (ch == '\n');
    if (!is_newline && after_newline) {
      indented += padding;
    }
    after_newline = is_newline;
    indented += ch;
  }
  return indented;
}
220 
/*!
 * \brief perform binary search on the sorted range [begin, end)
 * \tparam Iter iterator type
 * \tparam T type of the value being searched for
 * \param begin beginning of the range
 * \param end end of the range
 * \param val value to look for
 * \return iterator to the first element equivalent to val if there is one;
 *         end otherwise
 */
template <class Iter, class T>
Iter binary_search(Iter begin, Iter end, const T& val) {
  Iter pos = std::lower_bound(begin, end, val);
  // lower_bound yields the first element not less than val; it is a match
  // only if val is not less than that element either
  const bool found = (pos != end) && !(val < *pos);
  return found ? pos : end;
}
239 
246 template <typename T>
247 inline std::string ToStringHighPrecision(T value) {
248  return fmt::format("{:.{}g}", value, std::numeric_limits<T>::digits10 + 2);
249 }
250 
/*!
 * \brief write a sequence of strings to a text file; here, the whole content
 *        is given as a single string. The file is opened in the default
 *        (text, output) mode. Failure to open or write is not reported.
 * \param filename name of the file to write
 * \param content text to write into the file
 */
inline void WriteToFile(const std::string& filename,
                        const std::string& content) {
  std::ofstream output;
  output.open(filename);
  output << content;
}
263 
/*!
 * \brief write a sequence of bytes to a file, opened in binary mode.
 *        Failure to open or write is not reported.
 * \param filename name of the file to write
 * \param content bytes to write into the file
 */
inline void WriteToFile(const std::string& filename,
                        const std::vector<char>& content) {
  std::ofstream output;
  output.open(filename, std::ios::out | std::ios::binary);
  output.write(content.data(), content.size());
}
274 
/*!
 * \brief apply a function to each string in a list and append the results to
 *        another list
 * \param p_dest pointer to the list that receives the transformed strings;
 *               existing elements are kept
 * \param lines list of input strings; must not be the list behind p_dest
 * \param func function applied to each input string
 */
inline void TransformPushBack(std::vector<std::string>* p_dest,
                              const std::vector<std::string>& lines,
                              std::function<std::string(std::string)> func) {
  for (const std::string& line : lines) {
    p_dest->push_back(func(line));
  }
}
288 
/*!
 * \brief convert text to a number type (float, double, int, int8_t, uint32_t
 *        or uint64_t). This primary template only rejects unsupported types
 *        at compile time; the conversions themselves are provided by the
 *        explicit specializations for the supported types.
 * \tparam T type of the number to produce
 * \param str string to convert
 * \return number of type T represented by the string
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  constexpr bool is_floating_point
    = std::is_same<T, float>::value || std::is_same<T, double>::value;
  constexpr bool is_signed_integer
    = std::is_same<T, int>::value || std::is_same<T, int8_t>::value;
  constexpr bool is_unsigned_integer
    = std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value;
  static_assert(is_floating_point || is_signed_integer || is_unsigned_integer,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t.");
}
306 
307 template <>
308 inline float TextToNumber(const std::string& str) {
309  errno = 0;
310  char *endptr;
311  float val = std::strtof(str.c_str(), &endptr);
312  if (errno == ERANGE) {
313  LOG(FATAL) << "Range error while converting string to double";
314  } else if (errno != 0) {
315  LOG(FATAL) << "Unknown error";
316  } else if (*endptr != '\0') {
317  LOG(FATAL) << "String does not represent a valid floating-point number";
318  }
319  return val;
320 }
321 
322 template <>
323 inline double TextToNumber(const std::string& str) {
324  errno = 0;
325  char *endptr;
326  double val = std::strtod(str.c_str(), &endptr);
327  if (errno == ERANGE) {
328  LOG(FATAL) << "Range error while converting string to double";
329  } else if (errno != 0) {
330  LOG(FATAL) << "Unknown error";
331  } else if (*endptr != '\0') {
332  LOG(FATAL) << "String does not represent a valid floating-point number";
333  }
334  return val;
335 }
336 
337 template <>
338 inline int TextToNumber(const std::string& str) {
339  errno = 0;
340  char *endptr;
341  auto val = std::strtol(str.c_str(), &endptr, 10);
342  if (errno == ERANGE || val < INT_MIN || val > INT_MAX) {
343  LOG(FATAL) << "Range error while converting string to int";
344  } else if (errno != 0) {
345  LOG(FATAL) << "Unknown error";
346  } else if (*endptr != '\0') {
347  LOG(FATAL) << "String does not represent a valid integer";
348  }
349  return static_cast<int>(val);
350 }
351 
352 template <>
353 inline int8_t TextToNumber(const std::string& str) {
354  errno = 0;
355  char *endptr;
356  auto val = std::strtol(str.c_str(), &endptr, 10);
357  if (errno == ERANGE || val < INT8_MIN || val > INT8_MAX) {
358  LOG(FATAL) << "Range error while converting string to int8_t";
359  } else if (errno != 0) {
360  LOG(FATAL) << "Unknown error";
361  } else if (*endptr != '\0') {
362  LOG(FATAL) << "String does not represent a valid integer";
363  }
364  return static_cast<int8_t>(val);
365 }
366 
367 template <>
368 inline uint32_t TextToNumber(const std::string& str) {
369  errno = 0;
370  char *endptr;
371  auto val = std::strtoul(str.c_str(), &endptr, 10);
372  if (errno == ERANGE || val > UINT32_MAX) {
373  LOG(FATAL) << "Range error while converting string to uint32_t";
374  } else if (errno != 0) {
375  LOG(FATAL) << "Unknown error";
376  } else if (*endptr != '\0') {
377  LOG(FATAL) << "String does not represent a valid integer";
378  }
379  return static_cast<uint32_t>(val);
380 }
381 
382 template <>
383 inline uint64_t TextToNumber(const std::string& str) {
384  errno = 0;
385  char *endptr;
386  auto val = std::strtoull(str.c_str(), &endptr, 10);
387  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
388  LOG(FATAL) << "Range error while converting string to size_t";
389  } else if (errno != 0) {
390  LOG(FATAL) << "Unknown error";
391  } else if (*endptr != '\0') {
392  LOG(FATAL) << "String does not represent a valid integer";
393  }
394  return static_cast<uint64_t>(val);
395 }
396 
404 template <typename T>
405 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
406  if (text.empty() && num_entry > 0) {
407  LOG(FATAL) << "Cannot convert empty text into array";
408  }
409  std::vector<T> array;
410  std::istringstream ss(text);
411  std::string token;
412  for (int i = 0; i < num_entry; ++i) {
413  std::getline(ss, token, ' ');
414  array.push_back(TextToNumber<T>(token));
415  }
416  return array;
417 }
418 
/*!
 * \brief split text using a delimiter
 * \param text text to split
 * \param delim delimiter character
 * \return list of tokens, in order of appearance. Empty tokens between
 *         adjacent delimiters are kept, but no empty token is produced after
 *         a trailing delimiter; empty text yields an empty list.
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  // stop once the text is used up, so that a trailing delimiter does not
  // contribute an extra empty token
  while (start < text.size()) {
    const std::string::size_type stop = text.find(delim, start);
    if (stop == std::string::npos) {
      pieces.push_back(text.substr(start));  // last token, no delimiter left
      break;
    }
    pieces.push_back(text.substr(start, stop - start));
    start = stop + 1;
  }
  return pieces;
}
434 
443 inline bool CompareWithOp(treelite::tl_float lhs, treelite::Operator op,
444  treelite::tl_float rhs) {
445  switch (op) {
446  case treelite::Operator::kEQ: return lhs == rhs;
447  case treelite::Operator::kLT: return lhs < rhs;
448  case treelite::Operator::kLE: return lhs <= rhs;
449  case treelite::Operator::kGT: return lhs > rhs;
450  case treelite::Operator::kGE: return lhs >= rhs;
451  default:
452  LOG(FATAL) << "operator undefined";
453  return false;
454  }
455 }
456 
457 } // namespace common
458 } // namespace treelite
459 #endif // TREELITE_COMMON_H_
ArrayFormatter & operator<<(const T &e)
add an entry (will use high precision for floating-point values)
Definition: common.h:144
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:166
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-point values.
Definition: common.h:126
abstract interface for classes that can be cloned
Definition: common.h:36
a wrapper around std::unique_ptr that supports deep copying and moving.
Definition: common.h:59
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
ArrayFormatter(size_t text_width, size_t indent, char delimiter=',')
constructor
Definition: common.h:134
Operator
comparison operators
Definition: base.h:23