7 #ifndef TREELITE_COMMON_H_ 8 #define TREELITE_COMMON_H_ 27 #include <dmlc/logging.h> 28 #include <dmlc/json.h> 29 #include <dmlc/data.h> 30 #include <fmt/format.h> 44 #define CLONEABLE_BOILERPLATE(className) \ 45 explicit className(const className& other) = default; \ 46 explicit className(className&& other) = default; \ 47 Cloneable* clone() const override { \ 48 return new className(*this); \ 50 Cloneable* move_clone() override { \ 51 return new className(std::move(*this)); \ 61 static_assert(std::is_base_of<Cloneable, T>::value,
62 "DeepCopyUniquePtr requires a Cloneable type");
66 : ptr(dynamic_cast<T*>(other.clone())) {}
69 : ptr(dynamic_cast<T*>(other.move_clone())) {}
71 : ptr(dynamic_cast<T*>(other.ptr->clone())) {}
73 : ptr(std::move(other.ptr)) {}
75 inline T& operator*() {
78 const inline T& operator*()
const {
82 return ptr.operator->();
84 const T* operator->()
const {
85 return ptr.operator->();
89 std::unique_ptr<T> ptr;
101 template<
typename T,
typename ...Args>
102 std::unique_ptr<T> make_unique(Args&& ...args) {
103 return std::unique_ptr<T>(
new T(std::forward<Args>(args)...));
119 template <
typename T>
120 inline T&& MoveUniquePtr(
const std::unique_ptr<T>& ptr) {
121 return std::move(*ptr.get());
135 : oss_(), text_width_(text_width), indent_(indent), delimiter_(delimiter),
136 default_precision_(oss_.precision()), line_length_(indent),
143 template <
typename T>
147 oss_ << std::string(indent_,
' ');
149 std::ostringstream tmp;
150 tmp << std::setprecision(GetPrecision<T>()) << e << delimiter_ <<
" ";
151 const std::string token = tmp.str();
152 if (line_length_ + token.length() <= text_width_) {
154 line_length_ += token.length();
156 oss_ <<
"\n" << std::string(indent_,
' ') << token;
157 line_length_ = token.length() + indent_;
166 inline std::string
str() {
171 std::ostringstream oss_;
172 const size_t indent_;
173 const size_t text_width_;
174 const char delimiter_;
175 const int default_precision_;
179 template <
typename T>
180 inline int GetPrecision() {
181 return default_precision_;
186 inline int ArrayFormatter::GetPrecision<float>() {
187 return std::numeric_limits<float>::digits10 + 2;
190 inline int ArrayFormatter::GetPrecision<double>() {
191 return std::numeric_limits<double>::digits10 + 2;
201 inline std::string IndentMultiLineString(
const std::string& str,
203 std::ostringstream oss;
204 if (str[0] !=
'\n') {
205 oss << std::string(indent,
' ');
207 bool need_indent =
false;
212 }
else if (need_indent) {
213 oss << std::string(indent,
' ');
230 template<
class Iter,
class T>
231 Iter binary_search(Iter begin, Iter end,
const T& val) {
232 Iter i = std::lower_bound(begin, end, val);
233 if (i != end && !(val < *i)) {
246 template <
typename T>
247 inline std::string ToStringHighPrecision(T value) {
248 return fmt::format(
"{:.{}g}", value, std::numeric_limits<T>::digits10 + 2);
258 inline void WriteToFile(
const std::string& filename,
259 const std::string& content) {
260 std::ofstream of(filename);
269 inline void WriteToFile(
const std::string& filename,
270 const std::vector<char>& content) {
271 std::ofstream of(filename, std::ios::out | std::ios::binary);
272 of.write(content.data(), content.size());
282 inline void TransformPushBack(std::vector<std::string>* p_dest,
283 const std::vector<std::string>& lines,
284 std::function<std::string(std::string)> func) {
285 auto& dest = *p_dest;
286 std::transform(lines.begin(), lines.end(), std::back_inserter(dest), func);
295 template <
typename T>
296 inline T TextToNumber(
const std::string& str) {
297 static_assert(std::is_same<T, float>::value
298 || std::is_same<T, double>::value
299 || std::is_same<T, int>::value
300 || std::is_same<T, int8_t>::value
301 || std::is_same<T, uint32_t>::value
302 || std::is_same<T, uint64_t>::value,
303 "unsupported data type for TextToNumber; use float, double, " 304 "int, int8_t, uint32_t, or uint64_t.");
308 inline float TextToNumber(
const std::string& str) {
311 float val = std::strtof(str.c_str(), &endptr);
312 if (errno == ERANGE) {
313 LOG(FATAL) <<
"Range error while converting string to double";
314 }
else if (errno != 0) {
315 LOG(FATAL) <<
"Unknown error";
316 }
else if (*endptr !=
'\0') {
317 LOG(FATAL) <<
"String does not represent a valid floating-point number";
323 inline double TextToNumber(
const std::string& str) {
326 double val = std::strtod(str.c_str(), &endptr);
327 if (errno == ERANGE) {
328 LOG(FATAL) <<
"Range error while converting string to double";
329 }
else if (errno != 0) {
330 LOG(FATAL) <<
"Unknown error";
331 }
else if (*endptr !=
'\0') {
332 LOG(FATAL) <<
"String does not represent a valid floating-point number";
338 inline int TextToNumber(
const std::string& str) {
341 auto val = std::strtol(str.c_str(), &endptr, 10);
342 if (errno == ERANGE || val < INT_MIN || val > INT_MAX) {
343 LOG(FATAL) <<
"Range error while converting string to int";
344 }
else if (errno != 0) {
345 LOG(FATAL) <<
"Unknown error";
346 }
else if (*endptr !=
'\0') {
347 LOG(FATAL) <<
"String does not represent a valid integer";
349 return static_cast<int>(val);
353 inline int8_t TextToNumber(
const std::string& str) {
356 auto val = std::strtol(str.c_str(), &endptr, 10);
357 if (errno == ERANGE || val < INT8_MIN || val > INT8_MAX) {
358 LOG(FATAL) <<
"Range error while converting string to int8_t";
359 }
else if (errno != 0) {
360 LOG(FATAL) <<
"Unknown error";
361 }
else if (*endptr !=
'\0') {
362 LOG(FATAL) <<
"String does not represent a valid integer";
364 return static_cast<int8_t
>(val);
368 inline uint32_t TextToNumber(
const std::string& str) {
371 auto val = std::strtoul(str.c_str(), &endptr, 10);
372 if (errno == ERANGE || val > UINT32_MAX) {
373 LOG(FATAL) <<
"Range error while converting string to uint32_t";
374 }
else if (errno != 0) {
375 LOG(FATAL) <<
"Unknown error";
376 }
else if (*endptr !=
'\0') {
377 LOG(FATAL) <<
"String does not represent a valid integer";
379 return static_cast<uint32_t
>(val);
383 inline uint64_t TextToNumber(
const std::string& str) {
386 auto val = std::strtoull(str.c_str(), &endptr, 10);
387 if (errno == ERANGE || val > std::numeric_limits<uint64_t>::max()) {
388 LOG(FATAL) <<
"Range error while converting string to size_t";
389 }
else if (errno != 0) {
390 LOG(FATAL) <<
"Unknown error";
391 }
else if (*endptr !=
'\0') {
392 LOG(FATAL) <<
"String does not represent a valid integer";
394 return static_cast<uint64_t
>(val);
404 template <
typename T>
405 inline std::vector<T> TextToArray(
const std::string& text,
int num_entry) {
406 if (text.empty() && num_entry > 0) {
407 LOG(FATAL) <<
"Cannot convert empty text into array";
409 std::vector<T> array;
410 std::istringstream ss(text);
412 for (
int i = 0; i < num_entry; ++i) {
413 std::getline(ss, token,
' ');
414 array.push_back(TextToNumber<T>(token));
425 inline std::vector<std::string> Split(
const std::string& text,
char delim) {
426 std::vector<std::string> array;
427 std::istringstream ss(text);
429 while (std::getline(ss, token, delim)) {
430 array.push_back(token);
446 case treelite::Operator::kEQ:
return lhs == rhs;
447 case treelite::Operator::kLT:
return lhs < rhs;
448 case treelite::Operator::kLE:
return lhs <= rhs;
449 case treelite::Operator::kGT:
return lhs > rhs;
450 case treelite::Operator::kGE:
return lhs >= rhs;
452 LOG(FATAL) <<
"operator undefined";
459 #endif // TREELITE_COMMON_H_
abstract interface for classes that can be cloned
a wrapper around std::unique_ptr that supports deep copying and moving.
double tl_float
float type to be used internally
defines configuration macros of treelite
Operator
comparison operators