From b42c7fa172470a4793e15b26e6288d882633dd04 Mon Sep 17 00:00:00 2001 From: Mike Ash Date: Wed, 28 May 2025 15:19:47 -0400 Subject: [PATCH] [6.2][Concurrency] Fix races/overflows in TaskGroup implementation. statusCompletePendingReadyWaiting(), offer(), and poll() did a one-off compare_exchange_strong which could fail if the group was concurrently cancelled. Put these into loops so that they are retried when needed. DiscardingTaskGroup creation passed the group result type as the task result type. waitAll() would then use the group result type when collecting task results. Since the task result type is always Void in this case, this would overflow the result buffer if the group result type was larger. This often works as it writes into the free space of the task allocator, but can crash if it happens to be at the end of a page or the group result type is particularly large. rdar://151663730 (cherry picked from commit 5be22fa7cc395ad644057bd6ba69f9029a05ae27) --- .../Concurrency/DiscardingTaskGroup.swift | 8 +- stdlib/public/Concurrency/TaskGroup.cpp | 220 ++++++++++-------- .../async_taskgroup_cancellation_race.swift | 59 +++++ ...group_discarding_neverConsumingTasks.swift | 15 ++ 4 files changed, 195 insertions(+), 107 deletions(-) create mode 100644 test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift diff --git a/stdlib/public/Concurrency/DiscardingTaskGroup.swift b/stdlib/public/Concurrency/DiscardingTaskGroup.swift index 1de2e5311abcf..0fcec55446e16 100644 --- a/stdlib/public/Concurrency/DiscardingTaskGroup.swift +++ b/stdlib/public/Concurrency/DiscardingTaskGroup.swift @@ -80,7 +80,7 @@ public func withDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = DiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -108,7 +108,7 @@ public func _unsafeInheritExecutor_withDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = DiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -347,7 +347,7 @@ public func withThrowingDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = ThrowingDiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } @@ -378,7 +378,7 @@ public func _unsafeInheritExecutor_withThrowingDiscardingTaskGroup( discardResults: true ) - let _group = Builtin.createTaskGroupWithFlags(flags, GroupResult.self) + let _group = Builtin.createTaskGroupWithFlags(flags, Void.self) var group = ThrowingDiscardingTaskGroup(group: _group) defer { Builtin.destroyTaskGroup(_group) } diff --git a/stdlib/public/Concurrency/TaskGroup.cpp b/stdlib/public/Concurrency/TaskGroup.cpp index 9cc8f4ea0712b..0d1d2be13d156 100644 --- a/stdlib/public/Concurrency/TaskGroup.cpp +++ b/stdlib/public/Concurrency/TaskGroup.cpp @@ -439,12 +439,12 @@ class TaskGroupBase : public TaskGroupTaskStatusRecord { /// by simultaneously decrementing one Pending and one Waiting tasks. /// /// This is used to atomically perform a waiting task completion. - /// The change is made 'relaxed' and may have to be retried. + /// The change is made with relaxed memory ordering. /// /// This can be safely used in a discarding task group as well, /// where the "ready" change will simply be ignored, since there /// are no ready bits to change. - bool statusCompletePendingReadyWaiting(TaskGroupStatus &old); + void statusCompletePendingReadyWaiting(TaskGroupStatus &old); /// Cancel the task group and all tasks within it. /// @@ -568,7 +568,11 @@ struct TaskGroupStatus { // so if we're in "discard results" mode, we must not decrement the ready count, // as there is no ready count in the status. change += group->isAccumulatingResults() ? oneReadyTask : 0; - return TaskGroupStatus{status - change}; + + TaskGroupStatus newStatus{status - change}; + SWIFT_TASK_GROUP_DEBUG_LOG(group, "completingPendingReadyWaiting %s", + newStatus.to_string(group).c_str()); + return newStatus; } TaskGroupStatus completingPendingReady(const TaskGroupBase* _Nonnull group) { @@ -669,11 +673,12 @@ struct TaskGroupStatus { }; }; -bool TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) { - return status.compare_exchange_strong( +void TaskGroupBase::statusCompletePendingReadyWaiting(TaskGroupStatus &old) { + while (!status.compare_exchange_weak( old.status, old.completingPendingReadyWaiting(this).status, /*success*/ std::memory_order_relaxed, - /*failure*/ std::memory_order_relaxed); + /*failure*/ std::memory_order_relaxed)) { + } // Loop until the compare_exchange succeeds } AsyncTask *TaskGroupBase::claimWaitingTask() { @@ -681,12 +686,11 @@ AsyncTask *TaskGroupBase::claimWaitingTask() { "attempted to claim waiting task but status indicates no waiting " "task is present!"); - auto waitingTask = waitQueue.load(std::memory_order_acquire); - if (!waitQueue.compare_exchange_strong(waitingTask, nullptr, - std::memory_order_release, - std::memory_order_relaxed)) { - swift_Concurrency_fatalError(0, "Failed to claim waitingTask!"); - } + auto waitingTask = waitQueue.exchange(nullptr, std::memory_order_acquire); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "claimed waiting task %p", waitingTask); + if (!waitingTask) + swift_Concurrency_fatalError(0, "Claimed NULL waitingTask!"); + return waitingTask; } void TaskGroupBase::runWaitingTask(PreparedWaitingTask prepared) { @@ -737,13 +741,19 @@ uint64_t TaskGroupBase::pendingTasks() const { TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeAcquire() { auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_acquire); - return TaskGroupStatus{old | TaskGroupStatus::waiting}; + TaskGroupStatus newStatus{old | TaskGroupStatus::waiting}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusMarkWaitingAssumeAcquire %s", + newStatus.to_string(this).c_str()); + return newStatus; } TaskGroupStatus TaskGroupBase::statusMarkWaitingAssumeRelease() { auto old = status.fetch_or(TaskGroupStatus::waiting, std::memory_order_release); - return TaskGroupStatus{old | TaskGroupStatus::waiting}; + TaskGroupStatus newStatus{old | TaskGroupStatus::waiting}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusMarkWaitingAssumeRelease %s", + newStatus.to_string(this).c_str()); + return newStatus; } /// Add a single pending task to the status counter. @@ -786,6 +796,8 @@ TaskGroupStatus TaskGroupBase::statusAddPendingTaskAssumeRelaxed(bool unconditio TaskGroupStatus TaskGroupBase::statusRemoveWaitingRelease() { auto old = status.fetch_and(~TaskGroupStatus::waiting, std::memory_order_release); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusRemoveWaitingRelease %s", + old.to_string(this).c_str()); return TaskGroupStatus{old}; } @@ -793,6 +805,9 @@ bool TaskGroupBase::statusCancel() { /// The cancelled bit is always the same, the first one, between all task group implementations: const uint64_t cancelled = TaskGroupStatus::cancelled; auto old = status.fetch_or(cancelled, std::memory_order_relaxed); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "statusCancel %s", + TaskGroupStatus{old | cancelled}.to_string(this).c_str()); // return if the status was already cancelled before we flipped it or not return old & cancelled; @@ -827,6 +842,8 @@ class AccumulatingTaskGroup: public TaskGroupBase { auto old = status.fetch_add(TaskGroupStatus::oneReadyTask, std::memory_order_acquire); auto s = TaskGroupStatus{old + TaskGroupStatus::oneReadyTask}; + SWIFT_TASK_GROUP_DEBUG_LOG(this, "statusMarkWaitingAssumeRelease %s", + s.to_string(this).c_str()); assert(s.readyTasks(this) <= s.pendingTasks(this)); return s; } @@ -880,23 +897,17 @@ class DiscardingTaskGroup: public TaskGroupBase { return TaskGroupStatus{status.load(std::memory_order_acquire)}; } - /// Compare-and-set old status to a status derived from the old one, - /// by simultaneously decrementing one Pending and one Waiting tasks. - /// - /// This is used to atomically perform a waiting task completion. - bool statusCompletePendingReadyWaiting(TaskGroupStatus &old) { - return status.compare_exchange_strong( - old.status, old.completingPendingReadyWaiting(this).status, - /*success*/ std::memory_order_relaxed, - /*failure*/ std::memory_order_relaxed); - } - /// Decrement the pending status count. /// Returns the *assumed* new status, including the just performed -1. TaskGroupStatus statusCompletePendingAssumeRelease() { auto old = status.fetch_sub(TaskGroupStatus::onePendingTask, std::memory_order_release); assert(TaskGroupStatus{old}.pendingTasks(this) > 0 && "attempted to decrement pending count when it was 0 already"); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "statusComplete = %s", + TaskGroupStatus{status.load(std::memory_order_relaxed)} + .to_string(this) + .c_str()); return TaskGroupStatus{old - TaskGroupStatus::onePendingTask}; } @@ -1323,6 +1334,8 @@ void AccumulatingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *contex // ==== a) has waiting task, so let us complete it right away if (assumed.hasWaitingTask()) { auto waitingTask = claimWaitingTask(); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, waitingTask = %p", waitingTask); + assert(waitingTask); auto prepared = prepareWaitingTaskWithTask( /*complete=*/waitingTask, /*with=*/completedTask, assumed, hadErrorResult); @@ -1480,14 +1493,7 @@ void DiscardingTaskGroup::offer(AsyncTask *completedTask, AsyncContext *context) // We grab the waiting task while holding the group lock, because this // allows a single task to get the waiting task and attempt to complete it. // As another offer gets to run, it will have either a different waiting task, or no waiting task at all. - auto waitingTask = waitQueue.load(std::memory_order_acquire); - if (!waitQueue.compare_exchange_strong(waitingTask, nullptr, - std::memory_order_release, - std::memory_order_relaxed)) { - swift_Concurrency_fatalError(0, "Failed to claim waitingTask!"); - } - assert(waitingTask && "status claimed to have waitingTask but waitQueue was empty!"); - + auto waitingTask = claimWaitingTask(); SWIFT_TASK_GROUP_DEBUG_LOG(this, "offer, last pending task completed successfully, resume waitingTask:%p with completedTask:%p", waitingTask, completedTask); @@ -1558,8 +1564,11 @@ TaskGroupBase::PreparedWaitingTask TaskGroupBase::prepareWaitingTaskWithTask( bool hadErrorResult, bool alreadyDecremented, bool taskWasRetained) { - SWIFT_TASK_GROUP_DEBUG_LOG(this, "resume, waitingTask = %p, completedTask = %p, alreadyDecremented:%d, error:%d", - waitingTask, alreadyDecremented, hadErrorResult, completedTask); + SWIFT_TASK_GROUP_DEBUG_LOG(this, + "resume, waitingTask = %p, completedTask = %p, " + "alreadyDecremented:%d, error:%d", + waitingTask, completedTask, alreadyDecremented, + hadErrorResult); assert(waitingTask && "waitingTask must not be null when attempting to resume it"); assert(assumed.hasWaitingTask()); #if SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL @@ -1579,9 +1588,8 @@ TaskGroupBase::PreparedWaitingTask TaskGroupBase::prepareWaitingTaskWithTask( enqueueCompletedTask(completedTask, hadErrorResult); return {nullptr}; #else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */ - if (!alreadyDecremented) { - (void) statusCompletePendingReadyWaiting(assumed); - } + if (!alreadyDecremented) + statusCompletePendingReadyWaiting(assumed); // Populate the waiting task with value from completedTask. auto result = PollResult::get(completedTask, hadErrorResult); @@ -1643,9 +1651,8 @@ DiscardingTaskGroup::prepareWaitingTaskWithError(AsyncTask *waitingTask, _enqueueRawError(this, &readyQueue, error); return {nullptr}; #else /* SWIFT_CONCURRENCY_TASK_TO_THREAD_MODEL */ - if (!alreadyDecremented) { + if (!alreadyDecremented) statusCompletePendingReadyWaiting(assumed); - } // Run the task. auto result = PollResult::getError(error); @@ -1796,77 +1803,81 @@ reevaluate_if_taskgroup_has_results:; auto waitHead = waitQueue.load(std::memory_order_acquire); // ==== 2) Ready task was polled, return with it immediately ----------------- - if (assumed.readyTasks(this)) { + while (assumed.readyTasks(this)) { + // We loop when the compare_exchange fails. SWIFT_TASK_DEBUG_LOG("poll group = %p, tasks .ready = %d, .pending = %llu", this, assumed.readyTasks(this), assumed.pendingTasks(this)); auto assumedStatus = assumed.status; auto newStatus = TaskGroupStatus{assumedStatus}; - if (status.compare_exchange_strong( - assumedStatus, newStatus.completingPendingReadyWaiting(this).status, - /*success*/ std::memory_order_release, - /*failure*/ std::memory_order_acquire)) { - - // We're going back to running the task, so if we suspended before, - // we need to flag it as running again. - if (hasSuspended) { - waitingTask->flagAsRunning(); - } + if (!status.compare_exchange_weak( + assumedStatus, newStatus.completingPendingReadyWaiting(this).status, + /*success*/ std::memory_order_release, + /*failure*/ std::memory_order_acquire)) { + assumed = TaskGroupStatus{assumedStatus}; + continue; // We raced with something, try again. + } + SWIFT_TASK_DEBUG_LOG("poll, after CAS: %s", status.to_string().c_str()); - // Success! We are allowed to poll. - ReadyQueueItem item; - bool taskDequeued = readyQueue.dequeue(item); - assert(taskDequeued); (void) taskDequeued; - - auto futureFragment = - item.getStatus() == ReadyStatus::RawError ? - nullptr : - item.getTask()->futureFragment(); - - // Store the task in the result, so after we're done processing it may - // be swift_release'd; we kept it alive while it was in the readyQueue by - // an additional retain issued as we enqueued it there. - - // Note that the task was detached from the task group when it - // completed, so we don't need to do that bit of record-keeping here. - - switch (item.getStatus()) { - case ReadyStatus::Success: - // Immediately return the polled value - result.status = PollStatus::Success; - result.storage = futureFragment->getStoragePtr(); - result.successType = futureFragment->getResultType(); - result.retainedTask = item.getTask(); - assert(result.retainedTask && "polled a task, it must be not null"); - _swift_tsan_acquire(static_cast(result.retainedTask)); - unlock(); - return result; + // We're going back to running the task, so if we suspended before, + // we need to flag it as running again. + if (hasSuspended) { + waitingTask->flagAsRunning(); + } - case ReadyStatus::Error: - // Immediately return the polled value - result.status = PollStatus::Error; - result.storage = - reinterpret_cast(futureFragment->getError()); - result.successType = ResultTypeInfo(); - result.retainedTask = item.getTask(); - assert(result.retainedTask && "polled a task, it must be not null"); - _swift_tsan_acquire(static_cast(result.retainedTask)); - unlock(); - return result; + // Success! We are allowed to poll. + ReadyQueueItem item; + bool taskDequeued = readyQueue.dequeue(item); + assert(taskDequeued); (void) taskDequeued; + + auto futureFragment = + item.getStatus() == ReadyStatus::RawError ? + nullptr : + item.getTask()->futureFragment(); + + // Store the task in the result, so after we're done processing it may + // be swift_release'd; we kept it alive while it was in the readyQueue by + // an additional retain issued as we enqueued it there. + + // Note that the task was detached from the task group when it + // completed, so we don't need to do that bit of record-keeping here. + + switch (item.getStatus()) { + case ReadyStatus::Success: + // Immediately return the polled value + result.status = PollStatus::Success; + result.storage = futureFragment->getStoragePtr(); + result.successType = futureFragment->getResultType(); + result.retainedTask = item.getTask(); + assert(result.retainedTask && "polled a task, it must be not null"); + _swift_tsan_acquire(static_cast(result.retainedTask)); + unlock(); + return result; + + case ReadyStatus::Error: + // Immediately return the polled value + result.status = PollStatus::Error; + result.storage = + reinterpret_cast(futureFragment->getError()); + result.successType = ResultTypeInfo(); + result.retainedTask = item.getTask(); + assert(result.retainedTask && "polled a task, it must be not null"); + _swift_tsan_acquire(static_cast(result.retainedTask)); + unlock(); + return result; - case ReadyStatus::Empty: - result.status = PollStatus::Empty; - result.storage = nullptr; - result.retainedTask = nullptr; - result.successType = this->successType; - unlock(); - return result; + case ReadyStatus::Empty: + result.status = PollStatus::Empty; + result.storage = nullptr; + result.retainedTask = nullptr; + result.successType = this->successType; + unlock(); + return result; - case ReadyStatus::RawError: - swift_Concurrency_fatalError(0, "accumulating task group should never use raw-errors!"); - } - swift_Concurrency_fatalError(0, "must return result when status compare-and-swap was successful"); - } // else, we failed status-cas (some other waiter claimed a ready pending task, try again) + case ReadyStatus::RawError: + swift_Concurrency_fatalError(0, "accumulating task group should never use raw-errors!"); + } + swift_Concurrency_fatalError(0, "must return result when status compare-and-swap was successful"); } // ==== 3) Add to wait queue ------------------------------------------------- @@ -1878,7 +1889,9 @@ reevaluate_if_taskgroup_has_results:; } while (true) { // Put the waiting task at the beginning of the wait queue. - SWIFT_TASK_GROUP_DEBUG_LOG(this, "WATCH OUT, SET WAITER ONTO waitQueue.head = %p", waitQueue.load(std::memory_order_relaxed)); + SWIFT_TASK_GROUP_DEBUG_LOG( + this, "WATCH OUT, SET WAITER %p ONTO waitQueue.head = %p", waitingTask, + waitQueue.load(std::memory_order_relaxed)); if (waitQueue.compare_exchange_weak( waitHead, waitingTask, /*success*/ std::memory_order_release, @@ -2029,6 +2042,7 @@ void TaskGroupBase::waitAll(SwiftError* bodyError, AsyncTask *waitingTask, if (bodyError && isDiscardingResults() && readyQueue.isEmpty()) { auto discardingGroup = asDiscardingImpl(this); auto readyItem = ReadyQueueItem::getRawError(discardingGroup, bodyError); + SWIFT_TASK_GROUP_DEBUG_LOG(this, "enqueue %#" PRIxPTR, readyItem.storage); readyQueue.enqueue(readyItem); } diff --git a/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift b/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift new file mode 100644 index 0000000000000..96268d37c0914 --- /dev/null +++ b/test/Concurrency/Runtime/async_taskgroup_cancellation_race.swift @@ -0,0 +1,59 @@ +// RUN: %target-run-simple-swift + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: libdispatch +// REQUIRES: concurrency_runtime +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime +// UNSUPPORTED: back_deploy_concurrency +// UNSUPPORTED: freestanding + +func unorderedResults( + _ fns: [@Sendable () async -> R]) -> (Task<(), Never>, AsyncStream) { + var capturedContinuation: AsyncStream.Continuation? = nil + let stream = AsyncStream { continuation in + capturedContinuation = continuation + } + + guard let capturedContinuation = capturedContinuation else { + fatalError("failed to capture continuation") + } + + let task = Task.detached { + await withTaskGroup(of: Void.self) { group in + for fn in fns { + group.addTask { + let _ = capturedContinuation.yield(await fn()) + } + } + await group.waitForAll() + } + capturedContinuation.finish() + } + + let result = (task, stream) + + return result + } + +var fns: [@Sendable () async -> String] = [ + { + try? await Task.sleep(nanoseconds: .random(in: 0..<50000)) + return "hello" + } +] + +fns.append(fns[0]) +fns.append(fns[0]) + +// This is a race that will crash or trigger an assertion failure if there's an +// issue. If we get to the end then we pass. +for _ in 0..<1000 { + let (t, s) = unorderedResults(fns) + + for try await x in s { + _ = x + if Bool.random() { t.cancel() } + } +} diff --git a/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift b/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift index cacde94631888..bd5f54e9e7ff0 100644 --- a/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift +++ b/test/Concurrency/Runtime/async_taskgroup_discarding_neverConsumingTasks.swift @@ -85,9 +85,24 @@ func test_discardingTaskGroup_neverConsume(sleepBeforeGroupWaitAll: Duration) as print("all tasks: \(allTasks)") } +func test_discardingTaskGroup_bigReturn() async { + print(">>> \(#function)") + + // Test returning a very large value to ensure we don't overflow memory. + let array = await withDiscardingTaskGroup { group in + group.addTask {} + try? await Task.sleep(until: .now + .milliseconds(100), clock: .continuous) + return InlineArray<32768, Int>(repeating: 12345) + } + + // CHECK: Huge return value produced: 12345 12345 + print("Huge return value produced:", array[0], array[32767]) +} + @main struct Main { static func main() async { await test_discardingTaskGroup_neverConsume() await test_discardingTaskGroup_neverConsume(sleepBeforeGroupWaitAll: .milliseconds(500)) + await test_discardingTaskGroup_bigReturn() } }