Treelite
thread_pool.h
Go to the documentation of this file.
1 
7 #ifndef TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_
8 #define TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_
9 
#include <cstdlib>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#ifdef _WIN32
#include <windows.h>
#else
#include <pthread.h>
#include <sched.h>
#endif

#if defined(__APPLE__) && defined(__MACH__)
#include <TargetConditionals.h>
#include <mach/mach.h>
#include <mach/mach_init.h>
#include <mach/thread_policy.h>
#endif
#include "spsc_queue.h"
26 
27 namespace treelite {
28 namespace predictor {
29 
30 template <typename InputToken, typename OutputToken, typename TaskContext>
31 class ThreadPool {
32  public:
33  using TaskFunc = void(*)(SpscQueue<InputToken>*, SpscQueue<OutputToken>*,
34  const TaskContext*);
35 
36  ThreadPool(int num_worker, const TaskContext* context, TaskFunc task)
37  : num_worker_(num_worker), context_(context), task_(task) {
38  TREELITE_CHECK(num_worker_ >= 0
39  && static_cast<unsigned>(num_worker_) < std::thread::hardware_concurrency())
40  << "Number of worker threads must be between 0 and "
41  << (std::thread::hardware_concurrency() - 1);
42  for (int i = 0; i < num_worker_; ++i) {
43  incoming_queue_.emplace_back(new SpscQueue<InputToken>());
44  outgoing_queue_.emplace_back(new SpscQueue<OutputToken>());
45  }
46  thread_.resize(num_worker_);
47  for (int i = 0; i < num_worker_; ++i) {
48  thread_[i] = std::thread(task_, incoming_queue_[i].get(),
49  outgoing_queue_[i].get(),
50  context_);
51  }
52  /* bind threads to cores */
53  const char* bind_flag = getenv("TREELITE_BIND_THREADS");
54  if (bind_flag == nullptr || std::stoi(bind_flag) == 1) {
55  SetAffinity();
56  }
57  }
58  ~ThreadPool() {
59  for (int i = 0; i < num_worker_; ++i) {
60  incoming_queue_[i]->SignalForKill();
61  outgoing_queue_[i]->SignalForKill();
62  thread_[i].join();
63  }
64  }
65 
66  void SubmitTask(int tid, InputToken request) {
67  incoming_queue_[tid]->Push(request);
68  }
69 
70  bool WaitForTask(int tid, OutputToken* response) {
71  return outgoing_queue_[tid]->Pop(response);
72  }
73 
74  private:
75  int num_worker_;
76  std::vector<std::thread> thread_;
77  std::vector<std::unique_ptr<SpscQueue<InputToken>>> incoming_queue_;
78  std::vector<std::unique_ptr<SpscQueue<OutputToken>>> outgoing_queue_;
79  const TaskContext* context_;
80  TaskFunc task_;
81 
82  inline void SetAffinity() {
83 #ifdef _WIN32
84  /* Windows */
85  SetThreadAffinityMask(GetCurrentThread(), 0x1);
86  for (int i = 0; i < num_worker_; ++i) {
87  const int core_id = i + 1;
88  SetThreadAffinityMask(thread_[i].native_handle(), (1ULL << core_id));
89  }
90 #elif defined(__APPLE__) && defined(__MACH__)
91 #include <TargetConditionals.h>
92 #if TARGET_OS_MAC == 1
93  /* Mac OSX */
94  thread_port_t mach_thread = pthread_mach_thread_np(pthread_self());
95  thread_affinity_policy_data_t policy = {0};
96  thread_policy_set(mach_thread, THREAD_AFFINITY_POLICY,
97  (thread_policy_t)&policy, THREAD_AFFINITY_POLICY_COUNT);
98  for (int i = 0; i < num_worker_; ++i) {
99  const int core_id = i + 1;
100  mach_thread = pthread_mach_thread_np(thread_[i].native_handle());
101  policy = {core_id};
102  thread_policy_set(mach_thread, THREAD_AFFINITY_POLICY,
103  (thread_policy_t)&policy, THREAD_AFFINITY_POLICY_COUNT);
104  }
105 #else
106  #error "iPhone not supported yet"
107 #endif
108 #else
109  /* Linux and others */
110  cpu_set_t cpuset;
111  CPU_ZERO(&cpuset);
112  CPU_SET(0, &cpuset);
113 #if defined(__ANDROID__)
114  sched_setaffinity(pthread_self(), sizeof(cpu_set_t), &cpuset);
115 #else
116  pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
117 #endif
118  for (int i = 0; i < num_worker_; ++i) {
119  const int core_id = i + 1;
120  CPU_ZERO(&cpuset);
121  CPU_SET(core_id, &cpuset);
122 #if defined(__ANDROID__)
123  sched_setaffinity(thread_[i].native_handle(),
124  sizeof(cpu_set_t), &cpuset);
125 #else
126  pthread_setaffinity_np(thread_[i].native_handle(),
127  sizeof(cpu_set_t), &cpuset);
128 #endif
129  }
130 #endif
131  }
132 };
133 
134 } // namespace predictor
135 } // namespace treelite
136 
137 #endif // TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_
Lock-free single-producer-single-consumer queue.