Treelite
parallel_for.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
8 #define TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
9 
10 #include <treelite/logging.h>
11 #include <future>
12 #include <thread>
13 #include <algorithm>
14 #include <vector>
15 #include <cstddef>
16 
17 namespace treelite {
18 namespace threading_utils {
19 
20 template <typename IndexType>
21 std::vector<IndexType> ComputeWorkRange(IndexType begin, IndexType end, std::size_t nthread);
22 
23 template <typename IndexType, typename FuncType>
24 void ParallelFor(IndexType begin, IndexType end, std::size_t nthread, FuncType func) {
25  TREELITE_CHECK_GT(nthread, 0) << "nthread must be positive";
26  TREELITE_CHECK_LE(nthread, std::thread::hardware_concurrency())
27  << "nthread cannot exceed " << std::thread::hardware_concurrency();
28  if (begin == end) {
29  return;
30  }
31  /* Divide the range [begin, end) equally among the threads.
32  * The i-th thread gets the range [work_range[i], work_range[i+1]). */
33  std::vector<IndexType> work_range = ComputeWorkRange(begin, end, nthread);
34 
35  // Launch (nthread - 1) threads, as the main thread should also perform work.
36  std::vector<std::future<void>> async_tasks;
37  for (std::size_t thread_id = 1; thread_id < nthread; ++thread_id) {
38  async_tasks.push_back(std::async(std::launch::async, [&work_range, &func, thread_id]() {
39  const IndexType begin_ = work_range[thread_id];
40  const IndexType end_ = work_range[thread_id + 1];
41  for (IndexType i = begin_; i < end_; ++i) {
42  func(i, thread_id);
43  }
44  }));
45  }
46  {
47  const IndexType begin_ = work_range[0];
48  const IndexType end_ = work_range[1];
49  for (IndexType i = begin_; i < end_; ++i) {
50  func(i, 0);
51  }
52  }
53  // Join threads
54  for (auto& task : async_tasks) {
55  task.get();
56  }
57 }
58 
59 template <typename IndexType>
60 std::vector<IndexType> ComputeWorkRange(IndexType begin, IndexType end, std::size_t nthread) {
61  TREELITE_CHECK_GE(end, 0) << "end must be 0 or greater";
62  TREELITE_CHECK_GE(begin, 0) << "begin must be 0 or greater";
63  TREELITE_CHECK_GE(end, begin) << "end cannot be less than begin";
64  TREELITE_CHECK_GT(nthread, 0) << "nthread must be positive";
65  IndexType num_elem = end - begin;
66  const IndexType portion = num_elem / nthread + !!(num_elem % nthread);
67  // integer division, rounded-up
68 
69  std::vector<IndexType> work_range(nthread + 1);
70  work_range[0] = begin;
71  IndexType acc = begin;
72  for (std::size_t i = 0; i < nthread; ++i) {
73  acc += portion;
74  work_range[i + 1] = std::min(acc, end);
75  }
76  TREELITE_CHECK_EQ(work_range[nthread], end);
77 
78  return work_range;
79 }
80 
81 } // namespace threading_utils
82 } // namespace treelite
83 
84 #endif // TREELITE_THREADING_UTILS_PARALLEL_FOR_H_
logging facility for Treelite