Skip to content

Commit 331c6cb

Browse files
committed
Round Robin Queue Servicing support
Changes from inkooboo#24 are incorporated. Conditional variables instead spin-lock are kept as well.
1 parent ce2a969 commit 331c6cb

File tree

2 files changed

+58
-23
lines changed

2 files changed

+58
-23
lines changed

include/thread_pool/thread_pool.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ static bool v_affinity = false; /* Default: disabled */
3232

3333
template <typename Task, template<typename> class Queue>
3434
class ThreadPoolImpl;
35+
3536
using ThreadPool = ThreadPoolImpl<FixedFunction<void(), 128>,
3637
MPMCBoundedQueue>;
3738

@@ -45,6 +46,9 @@ using ThreadPool = ThreadPoolImpl<FixedFunction<void(), 128>,
4546
*/
4647
template <typename Task, template<typename> class Queue>
4748
class ThreadPoolImpl {
49+
50+
using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
51+
4852
public:
4953
/**
5054
* @brief ThreadPool Construct and start new thread pool.
@@ -90,7 +94,7 @@ class ThreadPoolImpl {
9094
private:
9195
Worker<Task, Queue>& getWorker();
9296

93-
std::vector<std::unique_ptr<Worker<Task, Queue>>> m_workers;
97+
WorkerVector m_workers;
9498
std::atomic<std::size_t> m_next_worker;
9599

96100
#if defined __sun__ || defined __linux__ || defined __FreeBSD__
@@ -129,9 +133,6 @@ inline ThreadPoolImpl<Task, Queue>::ThreadPoolImpl(
129133

130134
for(std::size_t i = 0; i < m_workers.size(); ++i)
131135
{
132-
Worker<Task, Queue>* steal_donor =
133-
m_workers[(i + 1) % m_workers.size()].get();
134-
135136
#if defined __sun__ || defined __linux__ || defined __FreeBSD__
136137
if (v_affinity) {
137138
if (v_cpu > v_cpu_max)
@@ -160,7 +161,7 @@ inline ThreadPoolImpl<Task, Queue>::ThreadPoolImpl(
160161
}
161162
#endif
162163

163-
m_workers[i]->start(i, steal_donor);
164+
m_workers[i]->start(i, &m_workers);
164165
}
165166
}
166167

@@ -195,7 +196,7 @@ template <typename Task, template<typename> class Queue>
195196
template <typename Handler>
196197
inline bool ThreadPoolImpl<Task, Queue>::tryPost(Handler&& handler)
197198
{
198-
return getWorker().post(std::forward<Handler>(handler));
199+
return getWorker().tryPost(std::forward<Handler>(handler));
199200
}
200201

201202
template <typename Task, template<typename> class Queue>

include/thread_pool/worker.hpp

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <thread>
55
#include <condition_variable>
66
#include <mutex>
7+
#include <limits>
78

89
namespace tp
910
{
@@ -17,6 +18,8 @@ namespace tp
1718
template <typename Task, template<typename> class Queue>
1819
class Worker
1920
{
21+
using WorkerVector = std::vector<std::unique_ptr<Worker<Task, Queue>>>;
22+
2023
public:
2124
/**
2225
* @brief Worker Constructor.
@@ -37,9 +40,9 @@ class Worker
3740
/**
3841
* @brief start Create the executing thread and start tasks execution.
3942
* @param id Worker ID.
40-
* @param steal_donor Sibling worker to steal task from it.
43+
* @param workers Sibling workers for performing round robin work stealing.
4144
*/
42-
void start(std::size_t id, Worker* steal_donor);
45+
void start(std::size_t id, WorkerVector* workers);
4346

4447
/**
4548
* @brief stop Stop all worker's thread and stealing activity.
@@ -48,19 +51,19 @@ class Worker
4851
void stop();
4952

5053
/**
51-
* @brief post Post task to queue.
54+
* @brief tryPost Post task to queue.
5255
* @param handler Handler to be executed in executing thread.
5356
* @return true on success.
5457
*/
5558
template <typename Handler>
56-
bool post(Handler&& handler);
59+
bool tryPost(Handler&& handler);
5760

5861
/**
59-
* @brief steal Steal one task from this worker queue.
60-
* @param task Place for stealed task to be stored.
62+
* @brief tryGetLocalTask Get one task from this worker queue.
63+
* @param task Place for the obtained task to be stored.
6164
* @return true on success.
6265
*/
63-
bool steal(Task& task);
66+
bool tryGetLocalTask(Task& task);
6467

6568
/**
6669
* @brief getWorkerIdForCurrentThread Return worker ID associated with
@@ -70,16 +73,24 @@ class Worker
7073
static std::size_t getWorkerIdForCurrentThread();
7174

7275
private:
76+
/**
77+
* @brief tryRoundRobinSteal Try stealing a thread from sibling workers in a round-robin fashion.
78+
* @param task Place for the obtained task to be stored.
79+
* @param workers Sibling workers for performing round robin work stealing.
80+
*/
81+
bool tryRoundRobinSteal(Task& task, WorkerVector* workers);
82+
7383
/**
7484
* @brief threadFunc Executing thread function.
7585
* @param id Worker ID to be associated with this thread.
76-
* @param steal_donor Sibling worker to steal task from it.
86+
* @param workers Sibling workers for performing round robin work stealing.
7787
*/
78-
void threadFunc(std::size_t id, Worker* steal_donor);
88+
void threadFunc(size_t id, WorkerVector* workers);
7989

8090
Queue<Task> m_queue;
8191
std::atomic<bool> m_running_flag;
8292
std::thread m_thread;
93+
std::size_t m_next_donor;
8394
std::mutex m_conditional_mutex;
8495
std::condition_variable m_conditional_lock;
8596
};
@@ -91,7 +102,7 @@ namespace detail
91102
{
92103
inline std::size_t* thread_id()
93104
{
94-
static thread_local std::size_t tss_id = -1u;
105+
static thread_local std::size_t tss_id = std::numeric_limits<std::size_t>::max();
95106
return &tss_id;
96107
}
97108
}
@@ -100,6 +111,7 @@ template <typename Task, template<typename> class Queue>
100111
inline Worker<Task, Queue>::Worker(std::size_t queue_size)
101112
: m_queue(queue_size)
102113
, m_running_flag(true)
114+
, m_next_donor(0) // Initialized in threadFunc.
103115
{
104116
}
105117

@@ -132,9 +144,9 @@ inline void Worker<Task, Queue>::stop()
132144
}
133145

134146
template <typename Task, template<typename> class Queue>
135-
inline void Worker<Task, Queue>::start(std::size_t id, Worker* steal_donor)
147+
inline void Worker<Task, Queue>::start(std::size_t id, WorkerVector* workers)
136148
{
137-
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, steal_donor);
149+
m_thread = std::thread(&Worker<Task, Queue>::threadFunc, this, id, workers);
138150
}
139151

140152
template <typename Task, template<typename> class Queue>
@@ -145,36 +157,58 @@ inline std::size_t Worker<Task, Queue>::getWorkerIdForCurrentThread()
145157

146158
template <typename Task, template<typename> class Queue>
147159
template <typename Handler>
148-
inline bool Worker<Task, Queue>::post(Handler&& handler)
160+
inline bool Worker<Task, Queue>::tryPost(Handler&& handler)
149161
{
150162
m_conditional_lock.notify_all();
151163
return m_queue.push(std::forward<Handler>(handler));
152164
}
153165

154166
template <typename Task, template<typename> class Queue>
155-
inline bool Worker<Task, Queue>::steal(Task& task)
167+
inline bool Worker<Task, Queue>::tryGetLocalTask(Task& task)
156168
{
157169
return m_queue.pop(task);
158170
}
159171

160172
template <typename Task, template<typename> class Queue>
161-
inline void Worker<Task, Queue>::threadFunc(std::size_t id, Worker* steal_donor)
173+
inline bool Worker<Task, Queue>::tryRoundRobinSteal(Task& task, WorkerVector* workers)
174+
{
175+
auto starting_index = m_next_donor;
176+
// Iterate once through the worker ring, checking for queued work items on each thread.
177+
do
178+
{
179+
// Don't steal from local queue.
180+
if (m_next_donor != *detail::thread_id() && workers->at(m_next_donor)->tryGetLocalTask(task))
181+
{
182+
// Increment before returning so that m_next_donor always points to the worker that has gone the longest
183+
// without a steal attempt. This helps enforce fairness in the stealing.
184+
++m_next_donor %= workers->size();
185+
return true;
186+
}
187+
++m_next_donor %= workers->size();
188+
} while (m_next_donor != starting_index);
189+
return false;
190+
}
191+
192+
template <typename Task, template<typename> class Queue>
193+
inline void Worker<Task, Queue>::threadFunc(size_t id, WorkerVector* workers)
162194
{
163195
*detail::thread_id() = id;
196+
m_next_donor = ++id % workers->size();
164197

165198
Task handler;
166199

167200
while (m_running_flag.load(std::memory_order_relaxed))
168201
{
169-
if (m_queue.pop(handler) || steal_donor->steal(handler))
202+
// Prioritize local queue, then try stealing from sibling workers.
203+
if (tryGetLocalTask(handler) || tryRoundRobinSteal(handler, workers))
170204
{
171205
try
172206
{
173207
handler();
174208
}
175209
catch(...)
176210
{
177-
// suppress all exceptions
211+
// Suppress all exceptions.
178212
}
179213
}
180214
else

0 commit comments

Comments
 (0)