Treelite
parallel_for.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
8 #define TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
9 
10 #include <treelite/omp.h>
11 #include <treelite/logging.h>
12 #include <type_traits>
13 #include <algorithm>
14 #include <exception>
15 #include <mutex>
16 #include <cstddef>
17 #include <cstdint>
18 
19 namespace treelite {
20 namespace threading_utils {
21 
/*!
 * \brief OMP Exception class catches, saves and rethrows exception from OMP blocks.
 *
 * Only the first exception captured by Run() is kept; later ones are dropped.
 */
class OMPException {
 private:
  // exception_ptr member to store the exception
  std::exception_ptr omp_exception_;
  // mutex to be acquired during catch to set the exception_ptr
  std::mutex mutex_;

 public:
  /*!
   * \brief Parallel OMP blocks should be placed within Run to save exception.
   * \param f Callable to invoke (taken by value)
   * \param params Arguments passed to f (taken by value)
   */
  template <typename Function, typename... Parameters>
  void Run(Function f, Parameters... params) {
    try {
      f(params...);
    } catch (...) {
      // Catch every exception type, not only std::exception: OpenMP requires an
      // exception thrown inside a parallel region to be caught by the same thread
      // within that region, so nothing may propagate out of this call.
      std::lock_guard<std::mutex> lock(mutex_);
      if (!omp_exception_) {
        omp_exception_ = std::current_exception();
      }
    }
  }

  /*!
   * \brief Should be called from the main thread to rethrow the exception.
   *
   * Does nothing when no exception has been captured.
   */
  void Rethrow() {
    if (omp_exception_) {
      std::rethrow_exception(omp_exception_);
    }
  }
};
57 
58 inline int OmpGetThreadLimit() {
59  int limit = omp_get_thread_limit();
60  TREELITE_CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
61  return limit;
62 }
63 
64 inline int MaxNumThread() {
65  return std::min(std::min(omp_get_num_procs(), omp_get_max_threads()), OmpGetThreadLimit());
66 }
67 
/*!
 * \brief Represent thread configuration, to be used with parallel loops.
 */
struct ThreadConfig {
  /*! \brief Thread count passed to the num_threads clause of parallel loops */
  std::uint32_t nthread;
};
74 
81 inline ThreadConfig ConfigureThreadConfig(int nthread) {
82  if (nthread <= 0) {
83  nthread = MaxNumThread();
84  TREELITE_CHECK_GE(nthread, 1) << "Invalid number of threads configured in OpenMP";
85  } else {
86  TREELITE_CHECK_LE(nthread, MaxNumThread())
87  << "nthread cannot exceed " << MaxNumThread() << " (configured by OpenMP).";
88  }
89  return ThreadConfig{static_cast<std::uint32_t>(nthread)};
90 }
91 
/*!
 * \brief OpenMP schedule: a scheduling kind plus an optional chunk size.
 *
 * A chunk of 0 means "no explicit chunk size".
 */
struct ParallelSchedule {
  enum {
    kAuto,
    kDynamic,
    kStatic,
    kGuided,
  } sched;
  std::size_t chunk{0};

  /*! \brief Schedule of kind kAuto, no chunk size */
  static ParallelSchedule Auto() { return ParallelSchedule{kAuto}; }
  /*! \brief Schedule of kind kDynamic with chunk size n (0 = unspecified) */
  static ParallelSchedule Dynamic(std::size_t n = 0) { return ParallelSchedule{kDynamic, n}; }
  /*! \brief Schedule of kind kStatic with chunk size n (0 = unspecified) */
  static ParallelSchedule Static(std::size_t n = 0) { return ParallelSchedule{kStatic, n}; }
  /*! \brief Schedule of kind kGuided, no chunk size */
  static ParallelSchedule Guided() { return ParallelSchedule{kGuided}; }
};
107 
108 template <typename IndexType, typename FuncType>
109 inline void ParallelFor(IndexType begin, IndexType end, const ThreadConfig& thread_config,
110  ParallelSchedule sched, FuncType func) {
111  if (begin == end) {
112  return;
113  }
114 
115 #if defined(_MSC_VER)
116  // msvc doesn't support unsigned integer as openmp index.
117  using OmpInd = std::conditional_t<std::is_signed<IndexType>::value, IndexType, std::int64_t>;
118 #else
119  using OmpInd = IndexType;
120 #endif
121 
122  OMPException exc;
123  switch (sched.sched) {
124  case ParallelSchedule::kAuto: {
125 #pragma omp parallel for num_threads(thread_config.nthread)
126  for (OmpInd i = begin; i < end; ++i) {
127  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
128  }
129  break;
130  }
131  case ParallelSchedule::kDynamic: {
132  if (sched.chunk == 0) {
133 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic)
134  for (OmpInd i = begin; i < end; ++i) {
135  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
136  }
137  } else {
138 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic, sched.chunk)
139  for (OmpInd i = begin; i < end; ++i) {
140  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
141  }
142  }
143  break;
144  }
145  case ParallelSchedule::kStatic: {
146  if (sched.chunk == 0) {
147 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static)
148  for (OmpInd i = begin; i < end; ++i) {
149  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
150  }
151  } else {
152 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static, sched.chunk)
153  for (OmpInd i = begin; i < end; ++i) {
154  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
155  }
156  }
157  break;
158  }
159  case ParallelSchedule::kGuided: {
160 #pragma omp parallel for num_threads(thread_config.nthread) schedule(guided)
161  for (OmpInd i = begin; i < end; ++i) {
162  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
163  }
164  break;
165  }
166  }
167  exc.Rethrow();
168 }
169 
170 } // namespace threading_utils
171 } // namespace treelite
172 
173 #endif // TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
void Rethrow()
should be called from the main thread to rethrow the exception
Definition: parallel_for.h:51
Represent thread configuration, to be used with parallel loops.
Definition: parallel_for.h:71
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
Definition: parallel_for.h:37
logging facility for Treelite
ThreadConfig ConfigureThreadConfig(int nthread)
Create thread configuration.
Definition: parallel_for.h:81
OMP Exception class catches, saves and rethrows exception from OMP blocks.
Definition: parallel_for.h:25
compatibility wrapper for systems that don't support OpenMP