Skip to content

Commit a83a0f0

Browse files
pitroukszucs
authored andcommitted
ARROW-6983: [C++] Fix ThreadedTaskGroup lifetime issue
Together with a test. Closes #5724 from pitrou/ARROW-6983-threaded-task-group-lifetime and squashes the following commits: 451c687 <Antoine Pitrou> ARROW-6983: Fix ThreadedTaskGroup lifetime issue Authored-by: Antoine Pitrou <[email protected]> Signed-off-by: Benjamin Kietzman <[email protected]>
1 parent 4142ed5 commit a83a0f0

File tree

3 files changed

+92
-5
lines changed

3 files changed

+92
-5
lines changed

cpp/src/arrow/util/task_group.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <mutex>
2424
#include <utility>
2525

26+
#include "arrow/util/checked_cast.h"
2627
#include "arrow/util/logging.h"
2728
#include "arrow/util/thread_pool.h"
2829

@@ -88,13 +89,15 @@ class ThreadedTaskGroup : public TaskGroup {
8889
// Only if an error occurs is the lock taken
8990
if (ok_.load(std::memory_order_acquire)) {
9091
nremaining_.fetch_add(1, std::memory_order_acquire);
91-
Status st = thread_pool_->Spawn([this, task]() {
92-
if (ok_.load(std::memory_order_acquire)) {
92+
93+
auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
94+
Status st = thread_pool_->Spawn([self, task]() {
95+
if (self->ok_.load(std::memory_order_acquire)) {
9396
// XXX what about exceptions?
9497
Status st = task();
95-
UpdateStatus(std::move(st));
98+
self->UpdateStatus(std::move(st));
9699
}
97-
OneTaskDone();
100+
self->OneTaskDone();
98101
});
99102
UpdateStatus(std::move(st));
100103
}

cpp/src/arrow/util/task_group.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ThreadPool;
4040
/// implementation. When Finish() returns, it is guaranteed that all
4141
/// tasks have finished, or at least one has errored.
4242
///
43-
class ARROW_EXPORT TaskGroup {
43+
class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this<TaskGroup> {
4444
public:
4545
/// Add a Status-returning function to execute. Execution order is
4646
/// undefined. The function may be executed immediately or later.

cpp/src/arrow/util/task_group_test.cc

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,75 @@ void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
212212
ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
213213
}
214214

215+
// A task that keeps recursing until a barrier is set.
216+
// Using a lambda for this doesn't play well with Thread Sanitizer.
217+
struct BarrierTask {
218+
std::atomic<bool>* barrier_;
219+
std::weak_ptr<TaskGroup> weak_group_ptr_;
220+
Status final_status_;
221+
222+
Status operator()() {
223+
if (!barrier_->load()) {
224+
sleep_for(1e-5);
225+
// Note the TaskGroup should be kept alive by the fact this task
226+
// is still running...
227+
weak_group_ptr_.lock()->Append(*this);
228+
}
229+
return final_status_;
230+
}
231+
};
232+
233+
// Try to replicate subtle lifetime issues when destroying a TaskGroup
234+
// where all tasks may not have finished running.
235+
void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
236+
const int NTASKS = 100;
237+
auto task_group = factory();
238+
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
239+
240+
std::atomic<bool> barrier(false);
241+
242+
BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
243+
244+
for (int i = 0; i < NTASKS; ++i) {
245+
task_group->Append(task);
246+
}
247+
248+
// Lose strong reference
249+
barrier.store(true);
250+
task_group.reset();
251+
252+
// Wait for finish
253+
while (!weak_group_ptr.expired()) {
254+
sleep_for(1e-5);
255+
}
256+
}
257+
258+
// Same, but with also a failing task
259+
void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
260+
const int NTASKS = 100;
261+
auto task_group = factory();
262+
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
263+
264+
std::atomic<bool> barrier(false);
265+
266+
BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
267+
BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};
268+
269+
for (int i = 0; i < NTASKS; ++i) {
270+
task_group->Append(task);
271+
}
272+
task_group->Append(failing_task);
273+
274+
// Lose strong reference
275+
barrier.store(true);
276+
task_group.reset();
277+
278+
// Wait for finish
279+
while (!weak_group_ptr.expired()) {
280+
sleep_for(1e-5);
281+
}
282+
}
283+
215284
TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }
216285

217286
TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
@@ -259,5 +328,20 @@ TEST(ThreadedTaskGroup, SubGroupsErrors) {
259328
TestTaskSubGroupsErrors(TaskGroup::MakeThreaded(thread_pool.get()));
260329
}
261330

331+
TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
332+
std::shared_ptr<ThreadPool> thread_pool;
333+
ASSERT_OK(ThreadPool::Make(16, &thread_pool));
334+
335+
StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
336+
}
337+
338+
TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
339+
std::shared_ptr<ThreadPool> thread_pool;
340+
ASSERT_OK(ThreadPool::Make(16, &thread_pool));
341+
342+
StressFailingTaskGroupLifetime(
343+
[&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
344+
}
345+
262346
} // namespace internal
263347
} // namespace arrow

0 commit comments

Comments
 (0)