@@ -212,6 +212,75 @@ void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
212212 ASSERT_EQ (count.load (), (1 << (N + 1 )) - 1 );
213213}
214214
215+ // A task that keeps recursing until a barrier is set.
216+ // Using a lambda for this doesn't play well with Thread Sanitizer.
217+ struct BarrierTask {
218+ std::atomic<bool >* barrier_;
219+ std::weak_ptr<TaskGroup> weak_group_ptr_;
220+ Status final_status_;
221+
222+ Status operator ()() {
223+ if (!barrier_->load ()) {
224+ sleep_for (1e-5 );
225+ // Note the TaskGroup should be kept alive by the fact this task
226+ // is still running...
227+ weak_group_ptr_.lock ()->Append (*this );
228+ }
229+ return final_status_;
230+ }
231+ };
232+
233+ // Try to replicate subtle lifetime issues when destroying a TaskGroup
234+ // where all tasks may not have finished running.
235+ void StressTaskGroupLifetime (std::function<std::shared_ptr<TaskGroup>()> factory) {
236+ const int NTASKS = 100 ;
237+ auto task_group = factory ();
238+ auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
239+
240+ std::atomic<bool > barrier (false );
241+
242+ BarrierTask task{&barrier, weak_group_ptr, Status::OK ()};
243+
244+ for (int i = 0 ; i < NTASKS; ++i) {
245+ task_group->Append (task);
246+ }
247+
248+ // Lose strong reference
249+ barrier.store (true );
250+ task_group.reset ();
251+
252+ // Wait for finish
253+ while (!weak_group_ptr.expired ()) {
254+ sleep_for (1e-5 );
255+ }
256+ }
257+
258+ // Same, but with also a failing task
259+ void StressFailingTaskGroupLifetime (std::function<std::shared_ptr<TaskGroup>()> factory) {
260+ const int NTASKS = 100 ;
261+ auto task_group = factory ();
262+ auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);
263+
264+ std::atomic<bool > barrier (false );
265+
266+ BarrierTask task{&barrier, weak_group_ptr, Status::OK ()};
267+ BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid (" XXX" )};
268+
269+ for (int i = 0 ; i < NTASKS; ++i) {
270+ task_group->Append (task);
271+ }
272+ task_group->Append (failing_task);
273+
274+ // Lose strong reference
275+ barrier.store (true );
276+ task_group.reset ();
277+
278+ // Wait for finish
279+ while (!weak_group_ptr.expired ()) {
280+ sleep_for (1e-5 );
281+ }
282+ }
283+
215284TEST (SerialTaskGroup, Success) { TestTaskGroupSuccess (TaskGroup::MakeSerial ()); }
216285
217286TEST (SerialTaskGroup, Errors) { TestTaskGroupErrors (TaskGroup::MakeSerial ()); }
@@ -259,5 +328,20 @@ TEST(ThreadedTaskGroup, SubGroupsErrors) {
259328 TestTaskSubGroupsErrors (TaskGroup::MakeThreaded (thread_pool.get ()));
260329}
261330
331+ TEST (ThreadedTaskGroup, StressTaskGroupLifetime) {
332+ std::shared_ptr<ThreadPool> thread_pool;
333+ ASSERT_OK (ThreadPool::Make (16 , &thread_pool));
334+
335+ StressTaskGroupLifetime ([&] { return TaskGroup::MakeThreaded (thread_pool.get ()); });
336+ }
337+
338+ TEST (ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
339+ std::shared_ptr<ThreadPool> thread_pool;
340+ ASSERT_OK (ThreadPool::Make (16 , &thread_pool));
341+
342+ StressFailingTaskGroupLifetime (
343+ [&] { return TaskGroup::MakeThreaded (thread_pool.get ()); });
344+ }
345+
262346} // namespace internal
263347} // namespace arrow
0 commit comments