7 #ifndef TREELITE_THREADING_UTILS_PARALLEL_FOR_H_ 8 #define TREELITE_THREADING_UTILS_PARALLEL_FOR_H_ 18 namespace threading_utils {
/*!
 * \brief Partition the half-open index range [begin, end) into nthread contiguous slices
 * \param begin first index of the range (inclusive)
 * \param end one past the last index of the range (exclusive)
 * \param nthread number of slices to produce
 * \return nthread + 1 boundaries: slice t covers [result[t], result[t + 1]); result[0] is begin
 *         and result[nthread] is end
 */
template <class IndexType>
std::vector<IndexType> ComputeWorkRange(IndexType begin, IndexType end, std::size_t nthread);
23 template <
typename IndexType,
typename FuncType>
24 void ParallelFor(IndexType begin, IndexType end, std::size_t nthread, FuncType func) {
25 TREELITE_CHECK_GT(nthread, 0) <<
"nthread must be positive";
26 TREELITE_CHECK_LE(nthread, std::thread::hardware_concurrency())
27 <<
"nthread cannot exceed " << std::thread::hardware_concurrency();
33 std::vector<IndexType> work_range = ComputeWorkRange(begin, end, nthread);
36 std::vector<std::future<void>> async_tasks;
37 for (std::size_t thread_id = 1; thread_id < nthread; ++thread_id) {
38 async_tasks.push_back(std::async(std::launch::async, [&work_range, &func, thread_id]() {
39 const IndexType begin_ = work_range[thread_id];
40 const IndexType end_ = work_range[thread_id + 1];
41 for (IndexType i = begin_; i < end_; ++i) {
47 const IndexType begin_ = work_range[0];
48 const IndexType end_ = work_range[1];
49 for (IndexType i = begin_; i < end_; ++i) {
54 for (
auto& task : async_tasks) {
59 template <
typename IndexType>
60 std::vector<IndexType> ComputeWorkRange(IndexType begin, IndexType end, std::size_t nthread) {
61 TREELITE_CHECK_GE(end, 0) <<
"end must be 0 or greater";
62 TREELITE_CHECK_GE(begin, 0) <<
"begin must be 0 or greater";
63 TREELITE_CHECK_GE(end, begin) <<
"end cannot be less than begin";
64 TREELITE_CHECK_GT(nthread, 0) <<
"nthread must be positive";
65 IndexType num_elem = end - begin;
66 const IndexType portion = num_elem / nthread + !!(num_elem % nthread);
69 std::vector<IndexType> work_range(nthread + 1);
70 work_range[0] = begin;
71 IndexType acc = begin;
72 for (std::size_t i = 0; i < nthread; ++i) {
74 work_range[i + 1] = std::min(acc, end);
76 TREELITE_CHECK_EQ(work_range[nthread], end);
84 #endif // TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
// Logging facility for Treelite