treelite
common.h
1 
7 #ifndef TREELITE_COMMON_H_
8 #define TREELITE_COMMON_H_
9 
10 #include <treelite/base.h>
11 #include <dmlc/logging.h>
12 #include <dmlc/json.h>
13 #include <dmlc/data.h>
14 #include <algorithm>
15 #include <vector>
16 #include <limits>
17 #include <string>
18 #include <memory>
19 #include <string>
20 #include <sstream>
21 #include <iterator>
22 #include <fstream>
23 #include <functional>
24 #include <limits>
25 #include <utility>
26 #include <stdexcept>
27 #include <iomanip>
28 #include <cerrno>
29 #include <climits>
30 
31 namespace treelite {
32 namespace common {
33 
/*!
 * \brief abstract interface for classes that can be cloned
 *
 * Both factory functions hand back a raw pointer to the base interface; the
 * implementations generated in this file allocate the result with `new`.
 */
class Cloneable {
 public:
  /*! \brief create a duplicate of this object (copy semantics) */
  virtual Cloneable* clone() const = 0;
  /*! \brief create a new object by moving out of this one (move semantics) */
  virtual Cloneable* move_clone() = 0;
  /*! \brief virtual, so that a clone may be destroyed through this interface */
  virtual ~Cloneable() = default;
};
41 
/*!
 * \brief boilerplate for a concrete class that implements Cloneable.
 *
 * Expands, inside the body of className, to defaulted explicit copy and move
 * constructors plus clone() / move_clone() overrides that allocate a new
 * className by copying from, respectively moving from, *this.
 */
#define CLONEABLE_BOILERPLATE(className)                             \
  explicit className(const className& other) = default;              \
  explicit className(className&& other) = default;                   \
  Cloneable* clone() const override { return new className(*this); } \
  Cloneable* move_clone() override {                                 \
    return new className(std::move(*this));                          \
  }
52 
57 template <typename T>
59  public:
60  static_assert(std::is_base_of<Cloneable, T>::value,
61  "DeepCopyUniquePtr requires a Cloneable type");
62  ~DeepCopyUniquePtr() {}
63 
64  explicit DeepCopyUniquePtr(const T& other)
65  : ptr(dynamic_cast<T*>(other.clone())) {}
66  // downcasting is okay here because the other object is certainly of type T
67  explicit DeepCopyUniquePtr(T&& other)
68  : ptr(dynamic_cast<T*>(other.move_clone())) {}
69  explicit DeepCopyUniquePtr(const DeepCopyUniquePtr<T>& other)
70  : ptr(dynamic_cast<T*>(other.ptr->clone())) {}
71  explicit DeepCopyUniquePtr(DeepCopyUniquePtr<T>&& other)
72  : ptr(std::move(other.ptr)) {}
73 
74  inline T& operator*() {
75  return *ptr;
76  }
77  const inline T& operator*() const {
78  return *ptr;
79  }
80  T* operator->() {
81  return ptr.operator->();
82  }
83  const T* operator->() const {
84  return ptr.operator->();
85  }
86 
87  private:
88  std::unique_ptr<T> ptr;
89 };
90 
/*!
 * \brief construct a new object of type T and wrap it in a std::unique_ptr;
 *        a stand-in for std::make_unique from C++14.
 * \param args arguments forwarded to the constructor of T
 * \return unique pointer that owns the newly constructed object
 * \tparam T type of object to be constructed
 * \tparam Args types of the constructor arguments
 */
template<typename T, typename ...Args>
std::unique_ptr<T> make_unique(Args&& ...args) {
  std::unique_ptr<T> owner(new T(std::forward<Args>(args)...));
  return owner;
}
104 
/*!
 * \brief obtain an rvalue reference to the object managed by a unique pointer,
 *        so that the caller can move-construct or move-assign from it.
 *        The unique pointer keeps owning the (moved-from) object.
 * \param ptr unique pointer; must not be empty
 * \return rvalue reference to the managed object
 * \tparam T type of the managed object
 */
template <typename T>
inline T&& MoveUniquePtr(const std::unique_ptr<T>& ptr) {
  return static_cast<T&&>(*ptr);
}
122 
/*!
 * \brief format array as text, wrapped to a given maximum text width. Uses
 *        high precision to render floating-point values.
 */
class ArrayFormatter {
 public:
  /*!
   * \brief constructor
   * \param text_width maximum length of each line
   * \param indent number of spaces placed at the start of each line
   * \param delimiter character appended after each entry (defaults to comma)
   */
  ArrayFormatter(size_t text_width, size_t indent, char delimiter = ',')
    // initializers are listed in member declaration order, which is the order
    // in which they actually run
    : oss_(), indent_(indent), text_width_(text_width), delimiter_(delimiter),
      default_precision_(static_cast<int>(oss_.precision())),
      line_length_(indent), is_empty_(true) {}

  /*!
   * \brief add an entry (will use high precision for floating-point values)
   * \param e entry to be added; rendered as the value, the delimiter and one
   *          space
   * \return reference to this formatter, so that insertions can be chained
   */
  template <typename T>
  inline ArrayFormatter& operator<<(const T& e) {
    if (is_empty_) {
      // the very first entry opens the first line, which needs its indent
      is_empty_ = false;
      oss_ << std::string(indent_, ' ');
    }
    std::ostringstream tmp;
    tmp << std::setprecision(GetPrecision<T>()) << e << delimiter_ << " ";
    const std::string token = tmp.str();  // token to be added to wrapped text
    if (line_length_ + token.length() <= text_width_) {
      oss_ << token;
      line_length_ += token.length();
    } else {
      // token does not fit on the current line: start a new, indented line
      oss_ << "\n" << std::string(indent_, ' ') << token;
      line_length_ = token.length() + indent_;
    }
    return *this;
  }

  /*!
   * \brief obtain formatted text containing the rendered array
   * \return text accumulated so far (empty if no entry has been added)
   */
  inline std::string str() const {
    return oss_.str();
  }

 private:
  std::ostringstream oss_;  // string stream to store wrapped text
  const size_t indent_;  // indent level, to indent each line
  const size_t text_width_;  // maximum length of each line
  const char delimiter_;  // delimiter (defaults to comma)
  const int default_precision_;  // default precision used by string stream
  size_t line_length_;  // width of current line
  bool is_empty_;  // true if no entry has been added yet

  /*!
   * \brief precision used to render an entry of type T; the generic version
   *        returns the default precision of the string stream
   */
  template <typename T>
  inline int GetPrecision() {
    return default_precision_;
  }
};
183 
184 template <>
185 inline int ArrayFormatter::GetPrecision<float>() {
186  return std::numeric_limits<float>::digits10 + 2;
187 }
188 template <>
189 inline int ArrayFormatter::GetPrecision<double>() {
190  return std::numeric_limits<double>::digits10 + 2;
191 }
192 
/*!
 * \brief apply indentation to a multi-line string by inserting spaces at the
 *        beginning of each line
 * \param str multi-line string (UNIX-style line endings are assumed)
 * \param indent number of spaces to insert in front of each line
 * \return indented string; lines that consist of a newline only are left
 *         without indentation
 */
inline std::string IndentMultiLineString(const std::string& str,
                                         size_t indent = 2) {
  const std::string padding(indent, ' ');
  std::ostringstream out;
  // the first line is indented right away, unless the text opens with a
  // line break
  if (str[0] != '\n') {
    out << padding;
  }
  // after a run of one or more newlines, padding is owed to the next
  // character that is not a newline
  bool padding_owed = false;
  for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) {
    const char ch = *it;
    if (ch == '\n') {
      padding_owed = true;
    } else if (padding_owed) {
      out << padding;
      padding_owed = false;
    }
    out << ch;
  }
  return out.str();
}
219 
/*!
 * \brief perform binary search on the range [begin, end).
 * \param begin beginning of the search range
 * \param end end of the search range
 * \param val value being searched
 * \return iterator pointing to the value if found; end if value not found.
 * \tparam Iter type of iterator
 * \tparam T type of elements
 */
template<class Iter, class T>
Iter binary_search(Iter begin, Iter end, const T& val) {
  // first position whose element is not less than val
  const Iter pos = std::lower_bound(begin, end, val);
  // val is absent if the range is exhausted or the element there is greater
  const bool missing = (pos == end) || (val < *pos);
  return missing ? end : pos;
}
238 
/*!
 * \brief obtain a string representation of a numeric value, rendered with
 *        digits10 + 2 significant digits
 * \param value a numeric value
 * \return string representation
 * \tparam T type of the value; std::numeric_limits<T> supplies digits10
 */
template <typename T>
inline std::string ToStringHighPrecision(T value) {
  const int num_digits = std::numeric_limits<T>::digits10 + 2;
  std::ostringstream buffer;
  buffer.precision(num_digits);
  buffer << value;
  return buffer.str();
}
251 
/*!
 * \brief write a sequence of strings to a text file, with newline character
 *        (\n) inserted between strings. This function is suitable for creating
 *        multi-line text files.
 *
 * NOTE(review): the summary above mirrors the upstream description, but the
 * body visible here takes one string and writes it verbatim -- confirm.
 * \param filename name of text file
 * \param content string to be written
 */
inline void WriteToFile(const std::string& filename,
                        const std::string& content) {
  // the stream is closed when it goes out of scope
  std::ofstream stream(filename, std::ios_base::out);
  stream << content;
}
264 
/*!
 * \brief apply a given transformation to a sequence of strings and append them
 *        to another sequence.
 * \param p_dest pointer to the vector that receives the transformed strings;
 *        existing elements are kept
 * \param lines strings to be transformed
 * \param func transformation applied to each string, in order
 */
inline void TransformPushBack(std::vector<std::string>* p_dest,
                              const std::vector<std::string>& lines,
                              std::function<std::string(std::string)> func) {
  for (const std::string& line : lines) {
    p_dest->push_back(func(line));
  }
}
278 
/*!
 * \brief convert text to number
 *
 * This is the generic version. The conversions themselves are provided by the
 * explicit specializations for float, double, int, int8_t, uint32_t and
 * uint64_t; requesting any other type is rejected at compile time by the
 * static assertion below.
 * \param str string containing number
 * \return number converted to type T
 * \tparam T type of value (should be a floating-point or integer type)
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  static_assert(std::is_same<T, float>::value
                || std::is_same<T, double>::value
                || std::is_same<T, int>::value
                || std::is_same<T, int8_t>::value
                || std::is_same<T, uint32_t>::value
                || std::is_same<T, uint64_t>::value,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, uint32_t, or uint64_t.");
  (void)str;  // the generic version performs no conversion
  // every path of a non-void function must return a value; without this
  // statement control would flow off the end of the function
  return T();
}
296 
297 template <>
298 inline float TextToNumber(const std::string& str) {
299  errno = 0;
300  char *endptr;
301  float val = std::strtof(str.c_str(), &endptr);
302  if (errno == ERANGE) {
303  LOG(FATAL) << "Range error while converting string to double";
304  } else if (errno != 0) {
305  LOG(FATAL) << "Unknown error";
306  } else if (*endptr != '\0') {
307  LOG(FATAL) << "String does not represent a valid floating-point number";
308  }
309  return val;
310 }
311 
312 template <>
313 inline double TextToNumber(const std::string& str) {
314  errno = 0;
315  char *endptr;
316  double val = std::strtod(str.c_str(), &endptr);
317  if (errno == ERANGE) {
318  LOG(FATAL) << "Range error while converting string to double";
319  } else if (errno != 0) {
320  LOG(FATAL) << "Unknown error";
321  } else if (*endptr != '\0') {
322  LOG(FATAL) << "String does not represent a valid floating-point number";
323  }
324  return val;
325 }
326 
327 template <>
328 inline int TextToNumber(const std::string& str) {
329  errno = 0;
330  char *endptr;
331  auto val = std::strtol(str.c_str(), &endptr, 10);
332  if (errno == ERANGE || val < INT_MIN || val > INT_MAX) {
333  LOG(FATAL) << "Range error while converting string to int";
334  } else if (errno != 0) {
335  LOG(FATAL) << "Unknown error";
336  } else if (*endptr != '\0') {
337  LOG(FATAL) << "String does not represent a valid integer";
338  }
339  return static_cast<int>(val);
340 }
341 
342 template <>
343 inline int8_t TextToNumber(const std::string& str) {
344  errno = 0;
345  char *endptr;
346  auto val = std::strtol(str.c_str(), &endptr, 10);
347  if (errno == ERANGE || val < INT8_MIN || val > INT8_MAX) {
348  LOG(FATAL) << "Range error while converting string to int8_t";
349  } else if (errno != 0) {
350  LOG(FATAL) << "Unknown error";
351  } else if (*endptr != '\0') {
352  LOG(FATAL) << "String does not represent a valid integer";
353  }
354  return static_cast<int8_t>(val);
355 }
356 
357 template <>
358 inline uint32_t TextToNumber(const std::string& str) {
359  errno = 0;
360  char *endptr;
361  auto val = std::strtoul(str.c_str(), &endptr, 10);
362  if (errno == ERANGE || val > UINT32_MAX) {
363  LOG(FATAL) << "Range error while converting string to uint32_t";
364  } else if (errno != 0) {
365  LOG(FATAL) << "Unknown error";
366  } else if (*endptr != '\0') {
367  LOG(FATAL) << "String does not represent a valid integer";
368  }
369  return static_cast<uint32_t>(val);
370 }
371 
372 template <>
373 inline uint64_t TextToNumber(const std::string& str) {
374  errno = 0;
375  char *endptr;
376  auto val = std::strtoull(str.c_str(), &endptr, 10);
377  if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
378  LOG(FATAL) << "Range error while converting string to size_t";
379  } else if (errno != 0) {
380  LOG(FATAL) << "Unknown error";
381  } else if (*endptr != '\0') {
382  LOG(FATAL) << "String does not represent a valid integer";
383  }
384  return static_cast<uint64_t>(val);
385 }
386 
394 template <typename T>
395 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
396  if (text.empty() && num_entry > 0) {
397  LOG(FATAL) << "Cannot convert empty text into array";
398  }
399  std::vector<T> array;
400  std::istringstream ss(text);
401  std::string token;
402  for (int i = 0; i < num_entry; ++i) {
403  std::getline(ss, token, ' ');
404  array.push_back(TextToNumber<T>(token));
405  }
406  return array;
407 }
408 
/*!
 * \brief split text using a delimiter
 * \param text text to be split
 * \param delim delimiter character
 * \return vector of tokens, in order of appearance. Adjacent delimiters yield
 *         an empty token; a delimiter at the very end of the text does not.
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> tokens;
  std::istringstream stream(text);
  for (std::string piece; std::getline(stream, piece, delim); ) {
    tokens.push_back(piece);
  }
  return tokens;
}
424 
433 inline bool CompareWithOp(treelite::tl_float lhs, treelite::Operator op,
434  treelite::tl_float rhs) {
435  switch (op) {
436  case treelite::Operator::kEQ: return lhs == rhs;
437  case treelite::Operator::kLT: return lhs < rhs;
438  case treelite::Operator::kLE: return lhs <= rhs;
439  case treelite::Operator::kGT: return lhs > rhs;
440  case treelite::Operator::kGE: return lhs >= rhs;
441  default: LOG(FATAL) << "operator undefined";
442  }
443 }
444 
445 } // namespace common
446 } // namespace treelite
447 #endif // TREELITE_COMMON_H_
ArrayFormatter & operator<<(const T &e)
add an entry (will use high precision for floating-point values)
Definition: common.h:143
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:165
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-p...
Definition: common.h:125
abstract interface for classes that can be cloned
Definition: common.h:35
a wrapper around std::unique_ptr that supports deep copying and moving.
Definition: common.h:58
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
ArrayFormatter(size_t text_width, size_t indent, char delimiter=',')
constructor
Definition: common.h:133
Operator
comparison operators
Definition: base.h:23