treelite
common.h
1 
7 #ifndef TREELITE_COMMON_H_
8 #define TREELITE_COMMON_H_
9 
10 #include <treelite/base.h>
11 #include <dmlc/logging.h>
12 #include <dmlc/json.h>
13 #include <dmlc/data.h>
14 #include <algorithm>
15 #include <vector>
16 #include <limits>
17 #include <string>
18 #include <memory>
19 #include <string>
20 #include <sstream>
21 #include <iterator>
22 #include <fstream>
23 #include <functional>
24 #include <limits>
25 #include <stdexcept>
26 #include <iomanip>
27 #include <cerrno>
28 #include <climits>
29 
30 namespace treelite {
31 namespace common {
32 
/*!
 * \brief abstract interface for classes that can be cloned.
 *
 * Implementers provide two factory functions that return a new object through
 * a Cloneable pointer: one that leaves the source untouched (copy) and one
 * that is allowed to consume the source (move). In this file,
 * CLONEABLE_BOILERPLATE implements both with `new`, and DeepCopyUniquePtr
 * takes ownership of the returned pointer.
 */
class Cloneable {
 public:
  /*! \brief virtual destructor, so clones can be destroyed via Cloneable* */
  virtual ~Cloneable() = default;

  /*!
   * \brief create a new object from this one; used for copy operation
   * \return pointer to the newly created object
   */
  virtual Cloneable* clone() const = 0;

  /*!
   * \brief create a new object from this one; used for move operation
   * \return pointer to the newly created object
   */
  virtual Cloneable* move_clone() = 0;
};
40 
/*!
 * \brief boilerplate to paste into the body of a class derived from Cloneable.
 *
 * Expands to explicit, defaulted copy and move constructors for `className`,
 * plus overrides of clone() and move_clone() that allocate the new object with
 * `new`, copy-constructing or move-constructing it from `*this` respectively.
 * \param className name of the class in which the macro is expanded
 */
#define CLONEABLE_BOILERPLATE(className) \
  explicit className(const className&) = default; \
  explicit className(className&&) = default; \
  Cloneable* clone() const override { return new className(*this); } \
  Cloneable* move_clone() override { return new className(std::move(*this)); }
51 
56 template <typename T>
58  public:
59  static_assert(std::is_base_of<Cloneable, T>::value,
60  "DeepCopyUniquePtr requires a Cloneable type");
61  ~DeepCopyUniquePtr() {}
62 
63  explicit DeepCopyUniquePtr(const T& other)
64  : ptr(dynamic_cast<T*>(other.clone())) {}
65  // downcasting is okay here because the other object is certainly of type T
66  explicit DeepCopyUniquePtr(T&& other)
67  : ptr(dynamic_cast<T*>(other.move_clone())) {}
68  explicit DeepCopyUniquePtr(const DeepCopyUniquePtr<T>& other)
69  : ptr(dynamic_cast<T*>(other.ptr->clone())) {}
70  explicit DeepCopyUniquePtr(DeepCopyUniquePtr<T>&& other)
71  : ptr(std::move(other.ptr)) {}
72 
73  inline T& operator*() {
74  return *ptr;
75  }
76  const inline T& operator*() const {
77  return *ptr;
78  }
79  T* operator->() {
80  return ptr.operator->();
81  }
82  const T* operator->() const {
83  return ptr.operator->();
84  }
85 
86  private:
87  std::unique_ptr<T> ptr;
88 };
89 
/*!
 * \brief construct a new object of type T and wrap it with a std::unique_ptr.
 *        The arguments are perfectly forwarded to the constructor of T.
 * \param args list of arguments with which an instance of T will be
 *             constructed
 * \return unique_ptr owning the newly constructed object
 * \tparam T type of object to be constructed
 * \tparam Args variadic template for forwarded arguments
 */
template<typename T, typename ...Args>
std::unique_ptr<T> make_unique(Args&& ...args) {
  T* fresh_object = new T(std::forward<Args>(args)...);
  return std::unique_ptr<T>(fresh_object);
}
103 
/*!
 * \brief utility function to move an object pointed to by a std::unique_ptr.
 *        Returns an rvalue reference to the pointee so that the caller can
 *        move-construct or move-assign from it; the unique_ptr itself keeps
 *        pointing at the (moved-from) object.
 * \param ptr unique_ptr whose pointee is to be moved; must not be empty
 * \return rvalue reference to the object being pointed to by ptr
 * \tparam T type of object being pointed to
 */
template <typename T>
inline T&& MoveUniquePtr(const std::unique_ptr<T>& ptr) {
  T& pointee = *ptr;
  return std::move(pointee);
}
121 
/*!
 * \brief format array as text, wrapped to a given maximum text width. Uses
 *        high precision to render floating-point values.
 */
class ArrayFormatter {
 public:
  /*!
   * \brief constructor
   * \param text_width maximum length of each line
   * \param indent indent level, to indent each line
   * \param delimiter delimiter written after every entry (defaults to comma)
   */
  ArrayFormatter(size_t text_width, size_t indent, char delimiter = ',')
    // initializers are listed in member declaration order
    : oss_(), indent_(indent), text_width_(text_width), delimiter_(delimiter),
      default_precision_(static_cast<int>(oss_.precision())),
      line_length_(indent), is_empty_(true) {}

  /*!
   * \brief add an entry (will use high precision for floating-point values)
   *
   * The entry is rendered as "<entry><delimiter><space>". It is appended to
   * the current line if the line then stays within text_width; otherwise a
   * new indented line is started with it.
   * \param e entry to be added
   * \return reference to this formatter, so that calls can be chained
   */
  template <typename T>
  inline ArrayFormatter& operator<<(const T& e) {
    if (is_empty_) {
      // very first entry: emit the indent of the first line
      is_empty_ = false;
      oss_ << std::string(indent_, ' ');
    }
    std::ostringstream tmp;
    tmp << std::setprecision(GetPrecision<T>()) << e << delimiter_ << " ";
    const std::string token = tmp.str();  // token to be added to wrapped text
    if (line_length_ + token.length() <= text_width_) {
      oss_ << token;
      line_length_ += token.length();
    } else {
      // token does not fit: wrap, then restart the line-length count
      oss_ << "\n" << std::string(indent_, ' ') << token;
      line_length_ = token.length() + indent_;
    }
    return *this;
  }

  /*!
   * \brief obtain formatted text containing the rendered array
   * \return string representing the rendered array
   */
  inline std::string str() const {
    return oss_.str();
  }

 private:
  std::ostringstream oss_;  // string stream to store wrapped text
  const size_t indent_;  // indent level, to indent each line
  const size_t text_width_;  // maximum length of each line
  const char delimiter_;  // delimiter (defaults to comma)
  const int default_precision_;  // default precision used by string stream
  size_t line_length_;  // width of current line
  bool is_empty_;  // true if no entry has been added yet

  /*!
   * \brief precision with which entries of type T are rendered
   * \return the string stream's default precision, unless T is float or
   *         double (see the specializations below the class)
   */
  template <typename T>
  inline int GetPrecision() const {
    return default_precision_;
  }
};

/*! \brief float entries are rendered with digits10 + 2 significant digits */
template <>
inline int ArrayFormatter::GetPrecision<float>() const {
  return std::numeric_limits<float>::digits10 + 2;
}
/*! \brief double entries are rendered with digits10 + 2 significant digits */
template <>
inline int ArrayFormatter::GetPrecision<double>() const {
  return std::numeric_limits<double>::digits10 + 2;
}
191 
/*!
 * \brief apply indentation to a multi-line string by inserting spaces at
 *        the beginning of each line
 *
 * Spaces are inserted before the first character unless the string starts
 * with a newline, and before every non-newline character that directly
 * follows one or more newlines. Blank lines therefore receive no padding.
 * \param str multi-line string (UNIX-style line endings are assumed)
 * \param indent indent level to be applied (number of spaces)
 * \return indented string
 */
inline std::string IndentMultiLineString(const std::string& str,
                                         size_t indent = 2) {
  const std::string padding(indent, ' ');
  std::string result;
  // an empty input also receives the leading padding
  if (str.empty() || str.front() != '\n') {
    result += padding;
  }
  bool pending_indent = false;  // set after a newline, cleared once padded
  for (const char ch : str) {
    if (ch == '\n') {
      pending_indent = true;
    } else if (pending_indent) {
      result += padding;
      pending_indent = false;
    }
    result += ch;
  }
  return result;
}
218 
/*!
 * \brief perform binary search on the range [begin, end).
 * \param begin beginning of the search range
 * \param end end of the search range
 * \param val value being searched
 * \return iterator pointing to the value if found; end if value not found.
 * \tparam Iter type of iterator
 * \tparam T type of elements
 */
template<class Iter, class T>
Iter binary_search(Iter begin, Iter end, const T& val) {
  const Iter candidate = std::lower_bound(begin, end, val);
  // lower_bound yields the first element not less than val; it is a match
  // only if val is not less than that element either
  const bool found = (candidate != end) && !(val < *candidate);
  return found ? candidate : end;
}
237 
/*!
 * \brief obtain a string representation of a numeric value, rendered with
 *        std::numeric_limits<T>::digits10 + 2 significant digits
 * \param value value to be converted to string
 * \return string representation
 * \tparam T type of value
 */
template <typename T>
inline std::string ToStringHighPrecision(T value) {
  const int num_digits = std::numeric_limits<T>::digits10 + 2;
  std::ostringstream stream;
  stream.precision(num_digits);
  stream << value;
  return stream.str();
}
250 
/*!
 * \brief write a sequence of strings to a text file. The file is opened with
 *        the default std::ofstream mode and closed when the stream is
 *        destroyed. Failure to open or write is not reported.
 * \param filename name of text file
 * \param content string to be written
 */
inline void WriteToFile(const std::string& filename,
                        const std::string& content) {
  std::ofstream{filename} << content;
}
263 
/*!
 * \brief apply a given transformation to a sequence of strings and append
 *        the results to another sequence.
 * \param p_dest pointer to the vector that receives the transformed strings;
 *               existing contents are kept
 * \param lines list of strings to be transformed, in order
 * \param func transformation applied to each string
 */
inline void TransformPushBack(std::vector<std::string>* p_dest,
                              const std::vector<std::string>& lines,
                              std::function<std::string(std::string)> func) {
  for (const std::string& line : lines) {
    p_dest->push_back(func(line));
  }
}
277 
/*!
 * \brief convert text to number
 *
 * This primary template only rejects unsupported types at compile time; the
 * actual conversions are implemented by the explicit specializations for
 * float, double, int, int8_t and uint32_t that follow it.
 * \param str string containing number
 * \return number converted to type T
 * \tparam T type of value (should be a floating-point or integer type)
 */
template <typename T>
inline T TextToNumber(const std::string& str) {
  static_assert(std::is_same<T, float>::value
                || std::is_same<T, double>::value
                || std::is_same<T, int>::value
                || std::is_same<T, int8_t>::value
                || std::is_same<T, uint32_t>::value,
                "unsupported data type for TextToNumber; use float, double, "
                "int, int8_t, or uint32_t.");
  // Every supported type has its own specialization, so this body is not
  // expected to run; the return keeps the function well-formed instead of
  // flowing off the end of a non-void function.
  return T();
}
294 
295 template <>
296 inline float TextToNumber(const std::string& str) {
297  errno = 0;
298  char *endptr;
299  float val = std::strtof(str.c_str(), &endptr);
300  if (errno == ERANGE) {
301  LOG(FATAL) << "Range error while converting string to double";
302  } else if (errno != 0) {
303  LOG(FATAL) << "Unknown error";
304  } else if (*endptr != '\0') {
305  LOG(FATAL) << "String does not represent a valid floating-point number";
306  }
307  return val;
308 }
309 
310 template <>
311 inline double TextToNumber(const std::string& str) {
312  errno = 0;
313  char *endptr;
314  double val = std::strtod(str.c_str(), &endptr);
315  if (errno == ERANGE) {
316  LOG(FATAL) << "Range error while converting string to double";
317  } else if (errno != 0) {
318  LOG(FATAL) << "Unknown error";
319  } else if (*endptr != '\0') {
320  LOG(FATAL) << "String does not represent a valid floating-point number";
321  }
322  return val;
323 }
324 
325 template <>
326 inline int TextToNumber(const std::string& str) {
327  errno = 0;
328  char *endptr;
329  auto val = std::strtol(str.c_str(), &endptr, 10);
330  if (errno == ERANGE || val < INT_MIN || val > INT_MAX) {
331  LOG(FATAL) << "Range error while converting string to int";
332  } else if (errno != 0) {
333  LOG(FATAL) << "Unknown error";
334  } else if (*endptr != '\0') {
335  LOG(FATAL) << "String does not represent a valid integer";
336  }
337  return static_cast<int>(val);
338 }
339 
340 template <>
341 inline int8_t TextToNumber(const std::string& str) {
342  errno = 0;
343  char *endptr;
344  auto val = std::strtol(str.c_str(), &endptr, 10);
345  if (errno == ERANGE || val < INT8_MIN || val > INT8_MAX) {
346  LOG(FATAL) << "Range error while converting string to int8_t";
347  } else if (errno != 0) {
348  LOG(FATAL) << "Unknown error";
349  } else if (*endptr != '\0') {
350  LOG(FATAL) << "String does not represent a valid integer";
351  }
352  return static_cast<int8_t>(val);
353 }
354 
355 template <>
356 inline uint32_t TextToNumber(const std::string& str) {
357  errno = 0;
358  char *endptr;
359  auto val = std::strtoul(str.c_str(), &endptr, 10);
360  if (errno == ERANGE || val > UINT32_MAX) {
361  LOG(FATAL) << "Range error while converting string to uint32_t";
362  } else if (errno != 0) {
363  LOG(FATAL) << "Unknown error";
364  } else if (*endptr != '\0') {
365  LOG(FATAL) << "String does not represent a valid integer";
366  }
367  return static_cast<uint32_t>(val);
368 }
369 
377 template <typename T>
378 inline std::vector<T> TextToArray(const std::string& text, int num_entry) {
379  if (text.empty() && num_entry > 0) {
380  LOG(FATAL) << "Cannot convert empty text into array";
381  }
382  std::vector<T> array;
383  std::istringstream ss(text);
384  std::string token;
385  for (int i = 0; i < num_entry; ++i) {
386  std::getline(ss, token, ' ');
387  array.push_back(TextToNumber<T>(token));
388  }
389  return array;
390 }
391 
/*!
 * \brief split text using a delimiter
 *
 * Tokens are extracted with std::getline, so empty tokens between adjacent
 * delimiters are kept, while a delimiter at the very end of the text does not
 * produce a trailing empty token.
 * \param text text
 * \param delim delimiter
 * \return std::vector of strings, split by the delimiter
 */
inline std::vector<std::string> Split(const std::string& text, char delim) {
  std::vector<std::string> pieces;
  std::istringstream stream(text);
  for (std::string piece; std::getline(stream, piece, delim); ) {
    pieces.push_back(piece);
  }
  return pieces;
}
407 
416 inline bool CompareWithOp(treelite::tl_float lhs, treelite::Operator op,
417  treelite::tl_float rhs) {
418  switch (op) {
419  case treelite::Operator::kEQ: return lhs == rhs;
420  case treelite::Operator::kLT: return lhs < rhs;
421  case treelite::Operator::kLE: return lhs <= rhs;
422  case treelite::Operator::kGT: return lhs > rhs;
423  case treelite::Operator::kGE: return lhs >= rhs;
424  default: LOG(FATAL) << "operator undefined";
425  }
426 }
427 
428 } // namespace common
429 } // namespace treelite
430 #endif // TREELITE_COMMON_H_
ArrayFormatter & operator<<(const T &e)
add an entry (will use high precision for floating-point values)
Definition: common.h:142
std::string str()
obtain formatted text containing the rendered array
Definition: common.h:164
format array as text, wrapped to a given maximum text width. Uses high precision to render floating-point values.
Definition: common.h:124
abstract interface for classes that can be cloned
Definition: common.h:34
a wrapper around std::unique_ptr that supports deep copying and moving.
Definition: common.h:57
ArrayFormatter(size_t text_width, size_t indent, char delimiter= ',')
constructor
Definition: common.h:132
double tl_float
float type to be used internally
Definition: base.h:17
defines configuration macros of treelite
Operator
comparison operators
Definition: base.h:23