treelite
threading_utils.h
Go to the documentation of this file.
1 
8 #ifndef TREELITE_DETAIL_THREADING_UTILS_H_
9 #define TREELITE_DETAIL_THREADING_UTILS_H_
10 
12 #include <treelite/logging.h>
13 
14 #include <algorithm>
15 #include <cstddef>
16 #include <cstdint>
17 #include <exception>
18 #include <limits>
19 #include <mutex>
20 #include <type_traits>
21 
22 #if TREELITE_OPENMP_SUPPORT
23 #include <omp.h>
24 #else
25 
26 // Stubs for OpenMP functions
27 
/*! \brief Serial fallback (no OpenMP): only a single thread is ever available. */
inline int omp_get_max_threads() {
  constexpr int kSerialThreadCount = 1;
  return kSerialThreadCount;
}
31 
/*! \brief Serial fallback (no OpenMP): report a single processor. */
inline int omp_get_num_procs() {
  constexpr int kSingleProcessor = 1;
  return kSingleProcessor;
}
35 
/*! \brief Serial fallback (no OpenMP): the calling thread is always thread 0. */
inline int omp_get_thread_num() {
  constexpr int kMainThreadId = 0;
  return kMainThreadId;
}
39 
/*! \brief Serial fallback (no OpenMP): there is no configured limit, so report the largest int. */
inline int omp_get_thread_limit() {
  constexpr int kNoLimit = std::numeric_limits<int>::max();
  return kNoLimit;
}
43 
44 #endif // TREELITE_OPENMP_SUPPORT
45 
47 
48 inline int OmpGetThreadLimit() {
49 // MSVC doesn't implement the thread limit.
50 #if defined(_MSC_VER)
51  int limit = std::numeric_limits<int>::max();
52 #else
53  int limit = omp_get_thread_limit();
54 #endif
55  TREELITE_CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
56  return limit;
57 }
58 
59 inline int MaxNumThread() {
60  return std::min(std::min(omp_get_num_procs(), omp_get_max_threads()), OmpGetThreadLimit());
61 }
62 
66 struct ThreadConfig {
67  std::uint32_t nthread;
74  explicit ThreadConfig(int nthread) {
75  if (nthread <= 0) {
77  TREELITE_CHECK_GE(nthread, 1) << "Invalid number of threads configured in OpenMP";
78  } else {
80  << "nthread cannot exceed " << MaxNumThread() << " (configured by OpenMP).";
81  }
82  this->nthread = static_cast<std::uint32_t>(nthread);
83  }
84 };
85 
// OpenMP schedule
/*!
 * \brief OpenMP loop schedule consumed by ParallelFor().
 *
 * Build instances with the named factories Auto(), Dynamic(), Static() and Guided().
 * A value-initialized ParallelSchedule{} uses kStatic with chunk 0.
 */
struct ParallelSchedule {
  enum {
    kAuto,     // ParallelFor emits no schedule clause
    kDynamic,  // ParallelFor emits schedule(dynamic[, chunk])
    kStatic,   // ParallelFor emits schedule(static[, chunk])
    kGuided,   // ParallelFor emits schedule(guided)
  } sched{kStatic};
  /*! \brief Chunk size for kDynamic/kStatic; 0 means the chunk size is omitted. */
  std::size_t chunk{0};

  static ParallelSchedule Auto() {
    return ParallelSchedule{kAuto};
  }
  static ParallelSchedule Dynamic(std::size_t n = 0) {
    return ParallelSchedule{kDynamic, n};
  }
  static ParallelSchedule Static(std::size_t n = 0) {
    return ParallelSchedule{kStatic, n};
  }
  static ParallelSchedule Guided() {
    return ParallelSchedule{kGuided};
  }
};
109 
110 template <typename IndexType, typename FuncType>
111 inline void ParallelFor(IndexType begin, IndexType end, ThreadConfig const& thread_config,
112  ParallelSchedule sched, FuncType func) {
113  if (begin == end) {
114  return;
115  }
116 
117 #if defined(_MSC_VER)
118  // msvc doesn't support unsigned integer as openmp index.
119  using OmpInd = std::conditional_t<std::is_signed<IndexType>::value, IndexType, std::int64_t>;
120 #else
121  using OmpInd = IndexType;
122 #endif
123 
124  OMPException exc;
125  switch (sched.sched) {
127 #pragma omp parallel for num_threads(thread_config.nthread)
128  for (OmpInd i = begin; i < end; ++i) {
129  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
130  }
131  break;
132  }
134  if (sched.chunk == 0) {
135 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic)
136  for (OmpInd i = begin; i < end; ++i) {
137  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
138  }
139  } else {
140 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic, sched.chunk)
141  for (OmpInd i = begin; i < end; ++i) {
142  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
143  }
144  }
145  break;
146  }
148  if (sched.chunk == 0) {
149 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static)
150  for (OmpInd i = begin; i < end; ++i) {
151  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
152  }
153  } else {
154 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static, sched.chunk)
155  for (OmpInd i = begin; i < end; ++i) {
156  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
157  }
158  }
159  break;
160  }
162 #pragma omp parallel for num_threads(thread_config.nthread) schedule(guided)
163  for (OmpInd i = begin; i < end; ++i) {
164  exc.Run(func, static_cast<IndexType>(i), omp_get_thread_num());
165  }
166  break;
167  }
168  }
169  exc.Rethrow();
170 }
171 
172 } // namespace treelite::detail::threading_utils
173 
174 #endif // TREELITE_DETAIL_THREADING_UTILS_H_
OMPException class: catches, saves, and rethrows exceptions thrown from OpenMP blocks.
Definition: omp_exception.h:20
void Run(Function f, Parameters... params)
Code inside parallel OpenMP blocks should be wrapped in Run() so that exceptions are saved.
Definition: omp_exception.h:32
void Rethrow()
Should be called from the main thread to rethrow the saved exception.
Definition: omp_exception.h:51
logging facility for Treelite
#define TREELITE_CHECK_LE(x, y)
Definition: logging.h:75
#define TREELITE_CHECK_GE(x, y)
Definition: logging.h:76
Definition: threading_utils.h:46
int OmpGetThreadLimit()
Definition: threading_utils.h:48
void ParallelFor(IndexType begin, IndexType end, ThreadConfig const &thread_config, ParallelSchedule sched, FuncType func)
Definition: threading_utils.h:111
int MaxNumThread()
Definition: threading_utils.h:59
Utility to propagate exceptions thrown inside an OpenMP block.
@ kGuided
Definition: threading_utils.h:92
@ kAuto
Definition: threading_utils.h:89
@ kStatic
Definition: threading_utils.h:91
@ kDynamic
Definition: threading_utils.h:90
static ParallelSchedule Dynamic(std::size_t n=0)
Definition: threading_utils.h:99
static ParallelSchedule Static(std::size_t n=0)
Definition: threading_utils.h:102
enum treelite::detail::threading_utils::ParallelSchedule::@0 kStatic
static ParallelSchedule Auto()
Definition: threading_utils.h:96
static ParallelSchedule Guided()
Definition: threading_utils.h:105
std::size_t chunk
Definition: threading_utils.h:94
Represents a thread configuration, to be used with parallel loops.
Definition: threading_utils.h:66
std::uint32_t nthread
Definition: threading_utils.h:67
ThreadConfig(int nthread)
Create thread configuration.
Definition: threading_utils.h:74
int omp_get_num_procs()
Definition: threading_utils.h:32
int omp_get_max_threads()
Definition: threading_utils.h:28
int omp_get_thread_limit()
Definition: threading_utils.h:40
int omp_get_thread_num()
Definition: threading_utils.h:36