7 #ifndef TREELITE_THREADING_UTILS_PARALLEL_FOR_H_ 8 #define TREELITE_THREADING_UTILS_PARALLEL_FOR_H_ 12 #include <type_traits> 20 namespace threading_utils {
28 std::exception_ptr omp_exception_;
// Run: accepts a callable `f` and an argument pack `params`, both by value, and
// converts a std::exception escaping the guarded region into a stored
// std::exception_ptr instead of letting it propagate.
// NOTE(review): the guarded statements (listing lines 38-39) are not present in
// this extract; presumably they invoke f(params...) — confirm against the full header.
36 template <
typename Function,
typename... Parameters>
37 void Run(Function f, Parameters... params) {
40 }
catch (std::exception& ex) {
// Hold mutex_ for the rest of the handler so the check-and-store below is serialized.
41 std::lock_guard<std::mutex> lock(mutex_);
// Store only when nothing has been saved yet, so an already-saved exception is kept.
42 if (!omp_exception_) {
43 omp_exception_ = std::current_exception();
52 if (this->omp_exception_) {
53 std::rethrow_exception(this->omp_exception_);
// OmpGetThreadLimit: reads the thread limit reported by omp_get_thread_limit()
// and validates that it is at least 1.
// NOTE(review): the return statement is outside this extract; presumably the
// validated `limit` is returned — confirm.
58 inline int OmpGetThreadLimit() {
59 int limit = omp_get_thread_limit();
// TREELITE_CHECK_GE is defined elsewhere; the streamed text is the message
// attached to the check (presumably reported when limit < 1 — confirm).
60 TREELITE_CHECK_GE(limit, 1) <<
"Invalid thread limit for OpenMP.";
// MaxNumThread: returns the smallest of three values: omp_get_num_procs(),
// omp_get_max_threads(), and OmpGetThreadLimit().
64 inline int MaxNumThread() {
65 return std::min(std::min(omp_get_num_procs(), omp_get_max_threads()), OmpGetThreadLimit());
72 std::uint32_t nthread;
83 nthread = MaxNumThread();
84 TREELITE_CHECK_GE(nthread, 1) <<
"Invalid number of threads configured in OpenMP";
86 TREELITE_CHECK_LE(nthread, MaxNumThread())
87 <<
"nthread cannot exceed " << MaxNumThread() <<
" (configured by OpenMP).";
89 return ThreadConfig{
static_cast<std::uint32_t
>(nthread)};
100 std::size_t chunk{0};
// ParallelFor: runs `func` once for every index i in [begin, end) inside an
// OpenMP `parallel for`, using thread_config.nthread threads and the schedule
// selected by `sched`. Each call receives the index (cast back to IndexType)
// and the value of omp_get_thread_num(), and is routed through exc.Run(...).
// NOTE(review): the remaining parameters (`sched`, `func`) and the declaration
// of `exc` are not visible in this extract — confirm against the full header.
108 template <
typename IndexType,
typename FuncType>
109 inline void ParallelFor(IndexType begin, IndexType end,
const ThreadConfig& thread_config,
// OmpInd is the type of the loop counter. Under _MSC_VER it is IndexType when
// IndexType is signed and std::int64_t otherwise; the second alias below is
// presumably the non-MSVC branch (the #else is not visible — confirm).
115 #if defined(_MSC_VER) 117 using OmpInd = std::conditional_t<std::is_signed<IndexType>::value, IndexType, std::int64_t>;
119 using OmpInd = IndexType;
// Dispatch on the requested schedule kind; every case runs the same loop body
// and differs only in the schedule clause of the pragma.
123 switch (sched.sched) {
// kAuto: no schedule clause is given.
124 case ParallelSchedule::kAuto: {
125 #pragma omp parallel for num_threads(thread_config.nthread) 126 for (OmpInd i = begin; i < end; ++i) {
127 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
// kDynamic: schedule(dynamic); sched.chunk == 0 selects the form without an
// explicit chunk size, otherwise sched.chunk is passed as the chunk size.
131 case ParallelSchedule::kDynamic: {
132 if (sched.chunk == 0) {
133 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic) 134 for (OmpInd i = begin; i < end; ++i) {
135 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
138 #pragma omp parallel for num_threads(thread_config.nthread) schedule(dynamic, sched.chunk) 139 for (OmpInd i = begin; i < end; ++i) {
140 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
// kStatic: schedule(static); same chunk == 0 convention as the dynamic case.
145 case ParallelSchedule::kStatic: {
146 if (sched.chunk == 0) {
147 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static) 148 for (OmpInd i = begin; i < end; ++i) {
149 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
152 #pragma omp parallel for num_threads(thread_config.nthread) schedule(static, sched.chunk) 153 for (OmpInd i = begin; i < end; ++i) {
154 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
// kGuided: schedule(guided); sched.chunk is not consulted in this case.
159 case ParallelSchedule::kGuided: {
160 #pragma omp parallel for num_threads(thread_config.nthread) schedule(guided) 161 for (OmpInd i = begin; i < end; ++i) {
162 exc.
Run(func, static_cast<IndexType>(i), omp_get_thread_num());
173 #endif // TREELITE_THREADING_UTILS_PARALLEL_FOR_H_ void Rethrow()
should be called from the main thread to rethrow the exception
Represent thread configuration, to be used with parallel loops.
void Run(Function f, Parameters... params)
Parallel OMP blocks should be placed within Run to save exception.
logging facility for Treelite
ThreadConfig ConfigureThreadConfig(int nthread)
Create thread configuration.
OMP Exception class catches, saves and rethrows exception from OMP blocks.
compatibility wrapper for systems that don't support OpenMP