diff --git a/Package.swift b/Package.swift index c43747c0..4f716a85 100644 --- a/Package.swift +++ b/Package.swift @@ -16,9 +16,12 @@ let package = Package( .library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]), .library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]), ], - dependencies: [], + dependencies: [.package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3"))], targets: [ - .target(name: "AsyncAlgorithms"), + .target( + name: "AsyncAlgorithms", + dependencies: [.product(name: "Collections", package: "swift-collections")] + ), .target( name: "AsyncSequenceValidation", dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]), diff --git a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md index 5121c769..19eed41c 100644 --- a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md +++ b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md @@ -51,7 +51,7 @@ public final class AsyncThrowingChannel: Asyn } ``` -Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. +Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. The calls to `send(:)` and `next()` will immediately resume when their supporting task is cancelled, other operations from other tasks will remain active. ```swift let channel = AsyncChannel() diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index facdaadf..d53a7bac 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -9,6 +9,8 @@ // //===----------------------------------------------------------------------===// +import OrderedCollections + /// A channel for sending elements from one task to another with back pressure. /// /// The `AsyncChannel` class is intended to be used as a communication type between tasks, @@ -34,14 +36,17 @@ public final class AsyncChannel: AsyncSequence, Sendable { guard active else { return nil } + let generation = channel.establish() - let value: Element? = await withTaskCancellationHandler { [channel] in - channel.cancel(generation) + let nextTokenStatus = ManagedCriticalState(.new) + + let value = await withTaskCancellationHandler { [channel] in + channel.cancelNext(nextTokenStatus, generation) } operation: { - await channel.next(generation) + await channel.next(nextTokenStatus, generation) } - - if let value = value { + + if let value { return value } else { active = false @@ -50,68 +55,49 @@ public final class AsyncChannel: AsyncSequence, Sendable { } } - struct Awaiting: Hashable { + typealias Pending = ChannelToken?, Never>> + typealias Awaiting = ChannelToken> + + struct ChannelToken: Hashable { var generation: Int - var continuation: UnsafeContinuation? - let cancelled: Bool - - init(generation: Int, continuation: UnsafeContinuation) { + var continuation: Continuation? + + init(generation: Int, continuation: Continuation) { self.generation = generation self.continuation = continuation - cancelled = false } - + init(placeholder generation: Int) { self.generation = generation self.continuation = nil - cancelled = false } - - init(cancelled generation: Int) { - self.generation = generation - self.continuation = nil - cancelled = true - } - + func hash(into hasher: inout Hasher) { hasher.combine(generation) } - - static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool { + + static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { return lhs.generation == rhs.generation } } + + enum ChannelTokenStatus: Equatable { + case new + case cancelled + } enum Emission { case idle - case pending([UnsafeContinuation?, Never>]) - case awaiting(Set) - - mutating func cancel(_ generation: Int) -> UnsafeContinuation? { - switch self { - case .awaiting(var awaiting): - let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation - if awaiting.isEmpty { - self = .idle - } else { - self = .awaiting(awaiting) - } - return continuation - case .idle: - self = .awaiting([Awaiting(cancelled: generation)]) - return nil - default: - return nil - } - } + case pending(OrderedSet) + case awaiting(OrderedSet) + case finished } struct State { var emission: Emission = .idle var generation = 0 - var terminal = false } - + let state = ManagedCriticalState(State()) /// Create a new `AsyncChannel` given an element type. @@ -123,22 +109,44 @@ public final class AsyncChannel: AsyncSequence, Sendable { return state.generation } } - - func cancel(_ generation: Int) { - state.withCriticalRegion { state in - state.emission.cancel(generation) + + func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation? in + let continuation: UnsafeContinuation? + + switch state.emission { + case .awaiting(var nexts): + continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation + if nexts.isEmpty { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + default: + continuation = nil + } + + nextTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation }?.resume(returning: nil) } - - func next(_ generation: Int) async -> Element? { - return await withUnsafeContinuation { continuation in + + func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async -> Element? { + return await withUnsafeContinuation { (continuation: UnsafeContinuation) in var cancelled = false var terminal = false state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if state.terminal { - terminal = true + + if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + cancelled = true return nil } + switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -150,94 +158,124 @@ public final class AsyncChannel: AsyncSequence, Sendable { } else { state.emission = .pending(sends) } - return UnsafeResumption(continuation: send, success: continuation) + return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { - nexts.remove(Awaiting(placeholder: generation)) - cancelled = true - } - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } + nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) + state.emission = .awaiting(nexts) + return nil + case .finished: + terminal = true return nil } }?.resume() + if cancelled || terminal { continuation.resume(returning: nil) } } } - - func terminateAll() { - let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in - if state.terminal { - return ([], []) - } - state.terminal = true + + func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation?, Never>? in + let continuation: UnsafeContinuation?, Never>? + switch state.emission { - case .idle: - return ([], []) - case .pending(let nexts): - state.emission = .idle - return (nexts, []) - case .awaiting(let nexts): - state.emission = .idle - return ([], nexts) + case .pending(var sends): + let send = sends.remove(Pending(placeholder: generation)) + if sends.isEmpty { + state.emission = .idle + } else { + state.emission = .pending(sends) + } + continuation = send?.continuation + default: + continuation = nil } - } - for send in sends { - send.resume(returning: nil) - } - for next in nexts { - next.continuation?.resume(returning: nil) - } + + sendTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation + }?.resume(returning: nil) } - - func _send(_ element: Element) async { - await withTaskCancellationHandler { - terminateAll() - } operation: { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if state.terminal { - return UnsafeResumption(continuation: continuation, success: nil) - } - switch state.emission { - case .idle: - state.emission = .pending([continuation]) - return nil - case .pending(var sends): - sends.append(continuation) - state.emission = .pending(sends) - return nil - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - return UnsafeResumption(continuation: continuation, success: next) + + func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { + let continuation = await withUnsafeContinuation { continuation in + state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + return UnsafeResumption(continuation: continuation, success: nil) + } + + switch state.emission { + case .idle: + state.emission = .pending([Pending(generation: generation, continuation: continuation)]) + return nil + case .pending(var sends): + sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) + state.emission = .pending(sends) + return nil + case .awaiting(var nexts): + let next = nexts.removeFirst().continuation + if nexts.count == 0 { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) } - }?.resume() - } - continuation?.resume(returning: element) + return UnsafeResumption(continuation: continuation, success: next) + case .finished: + return UnsafeResumption(continuation: continuation, success: nil) + } + }?.resume() } + continuation?.resume(returning: element) } - + /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made /// or when a call to `finish()` is made from another Task. /// If the channel is already finished then this returns immediately + /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. public func send(_ element: Element) async { - await _send(element) + let generation = establish() + let sendTokenStatus = ManagedCriticalState(.new) + + await withTaskCancellationHandler { [weak self] in + self?.cancelSend(sendTokenStatus, generation) + } operation: { + await send(sendTokenStatus, generation, element) + } } /// Send a finish to all awaiting iterations. /// All subsequent calls to `next(_:)` will resume immediately. public func finish() { - terminateAll() + let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet, OrderedSet) in + let result: (OrderedSet, OrderedSet) + + switch state.emission { + case .idle: + result = ([], []) + case .pending(let nexts): + result = (nexts, []) + case .awaiting(let nexts): + result = ([], nexts) + case .finished: + result = ([], []) + } + + state.emission = .finished + + return result + } + for send in sends { + send.continuation?.resume(returning: nil) + } + for next in nexts { + next.continuation?.resume(returning: nil) + } } /// Create an `Iterator` for iteration of an `AsyncChannel` diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift index 5ce68961..473482ea 100644 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift @@ -9,6 +9,8 @@ // //===----------------------------------------------------------------------===// +import OrderedCollections + /// An error-throwing channel for sending elements from on task to another with back pressure. /// /// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks, @@ -32,12 +34,15 @@ public final class AsyncThrowingChannel: Asyn guard active else { return nil } + let generation = channel.establish() + let nextTokenStatus = ManagedCriticalState(.new) + do { - let value: Element? = try await withTaskCancellationHandler { [channel] in - channel.cancel(generation) + let value = try await withTaskCancellationHandler { [channel] in + channel.cancelNext(nextTokenStatus, generation) } operation: { - try await channel.next(generation) + try await channel.next(nextTokenStatus, generation) } if let value = value { @@ -52,39 +57,39 @@ public final class AsyncThrowingChannel: Asyn } } } - - struct Awaiting: Hashable { + + typealias Pending = ChannelToken?, Never>> + typealias Awaiting = ChannelToken> + + struct ChannelToken: Hashable { var generation: Int - var continuation: UnsafeContinuation? - let cancelled: Bool - - init(generation: Int, continuation: UnsafeContinuation) { + var continuation: Continuation? + + init(generation: Int, continuation: Continuation) { self.generation = generation self.continuation = continuation - cancelled = false } - + init(placeholder generation: Int) { self.generation = generation self.continuation = nil - cancelled = false } - - init(cancelled generation: Int) { - self.generation = generation - self.continuation = nil - cancelled = true - } - + func hash(into hasher: inout Hasher) { hasher.combine(generation) } - - static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool { + + static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { return lhs.generation == rhs.generation } } + + enum ChannelTokenStatus: Equatable { + case new + case cancelled + } + enum Termination { case finished case failed(Error) @@ -92,32 +97,9 @@ public final class AsyncThrowingChannel: Asyn enum Emission { case idle - case pending([UnsafeContinuation?, Never>]) - case awaiting(Set) + case pending(OrderedSet) + case awaiting(OrderedSet) case terminated(Termination) - - var isTerminated: Bool { - guard case .terminated = self else { return false } - return true - } - - mutating func cancel(_ generation: Int) -> UnsafeContinuation? { - switch self { - case .awaiting(var awaiting): - let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation - if awaiting.isEmpty { - self = .idle - } else { - self = .awaiting(awaiting) - } - return continuation - case .idle: - self = .awaiting([Awaiting(cancelled: generation)]) - return nil - default: - return nil - } - } } struct State { @@ -135,19 +117,45 @@ public final class AsyncThrowingChannel: Asyn return state.generation } } - - func cancel(_ generation: Int) { - state.withCriticalRegion { state in - state.emission.cancel(generation) + + func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation? in + let continuation: UnsafeContinuation? + + switch state.emission { + case .awaiting(var nexts): + continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation + if nexts.isEmpty { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + default: + continuation = nil + } + + nextTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation }?.resume(returning: nil) } - func next(_ generation: Int) async throws -> Element? { - return try await withUnsafeThrowingContinuation { continuation in + func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async throws -> Element? { + return try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in var cancelled = false var potentialTermination: Termination? state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + cancelled = true + return nil + } + switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -159,17 +167,10 @@ public final class AsyncThrowingChannel: Asyn } else { state.emission = .pending(sends) } - return UnsafeResumption(continuation: send, success: continuation) + return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { - nexts.remove(Awaiting(placeholder: generation)) - cancelled = true - } - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } + nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) + state.emission = .awaiting(nexts) return nil case .terminated(let termination): potentialTermination = termination @@ -196,8 +197,67 @@ public final class AsyncThrowingChannel: Asyn } } + func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation?, Never>? in + let continuation: UnsafeContinuation?, Never>? + + switch state.emission { + case .pending(var sends): + let send = sends.remove(Pending(placeholder: generation)) + if sends.isEmpty { + state.emission = .idle + } else { + state.emission = .pending(sends) + } + continuation = send?.continuation + default: + continuation = nil + } + + sendTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation + }?.resume(returning: nil) + } + + func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { + let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in + state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + return UnsafeResumption(continuation: continuation, success: nil) + } + + switch state.emission { + case .idle: + state.emission = .pending([Pending(generation: generation, continuation: continuation)]) + return nil + case .pending(var sends): + sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) + state.emission = .pending(sends) + return nil + case .awaiting(var nexts): + let next = nexts.removeFirst().continuation + if nexts.count == 0 { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + return UnsafeResumption(continuation: continuation, success: next) + case .terminated: + return UnsafeResumption(continuation: continuation, success: nil) + } + }?.resume() + } + continuation?.resume(returning: element) + } + func terminateAll(error: Failure? = nil) { - let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in + let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet, OrderedSet) in let nextState: Emission if let error = error { @@ -222,7 +282,7 @@ public final class AsyncThrowingChannel: Asyn } for send in sends { - send.resume(returning: nil) + send.continuation?.resume(returning: nil) } if let error = error { @@ -234,45 +294,21 @@ public final class AsyncThrowingChannel: Asyn next.continuation?.resume(returning: nil) } } - - } - - func _send(_ element: Element) async { - await withTaskCancellationHandler { - terminateAll() - } operation: { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - switch state.emission { - case .idle: - state.emission = .pending([continuation]) - return nil - case .pending(var sends): - sends.append(continuation) - state.emission = .pending(sends) - return nil - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - return UnsafeResumption(continuation: continuation, success: next) - case .terminated: - return UnsafeResumption(continuation: continuation, success: nil) - } - }?.resume() - } - continuation?.resume(returning: element) - } } /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made /// or when a call to `finish()`/`fail(_:)` is made from another Task. /// If the channel is already finished then this returns immediately + /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. public func send(_ element: Element) async { - await _send(element) + let generation = establish() + let sendTokenStatus = ManagedCriticalState(.new) + + await withTaskCancellationHandler { [weak self] in + self?.cancelSend(sendTokenStatus, generation) + } operation: { + await send(sendTokenStatus, generation, element) + } } /// Send an error to all awaiting iterations. diff --git a/Sources/AsyncAlgorithms/UnsafeResumption.swift b/Sources/AsyncAlgorithms/UnsafeResumption.swift index 9eb28c5f..d87987ac 100644 --- a/Sources/AsyncAlgorithms/UnsafeResumption.swift +++ b/Sources/AsyncAlgorithms/UnsafeResumption.swift @@ -1,22 +1,22 @@ struct UnsafeResumption { - let continuation: UnsafeContinuation + let continuation: UnsafeContinuation? let result: Result - init(continuation: UnsafeContinuation, result: Result) { + init(continuation: UnsafeContinuation?, result: Result) { self.continuation = continuation self.result = result } - init(continuation: UnsafeContinuation, success: Success) { + init(continuation: UnsafeContinuation?, success: Success) { self.init(continuation: continuation, result: .success(success)) } - init(continuation: UnsafeContinuation, failure: Failure) { + init(continuation: UnsafeContinuation?, failure: Failure) { self.init(continuation: continuation, result: .failure(failure)) } func resume() { - continuation.resume(with: result) + continuation?.resume(with: result) } } diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index 66fd1e1d..2d8797fa 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -227,7 +227,7 @@ final class TestChannel: XCTestCase { XCTAssertNil(value) } - func test_asyncChannel_resumes_send_when_task_is_cancelled() async { + func test_asyncChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async { let channel = AsyncChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true @@ -237,12 +237,21 @@ final class TestChannel: XCTestCase { notYetDone.fulfill() done.fulfill() } + + Task { + await channel.send(2) + } + wait(for: [notYetDone], timeout: 0.1) task.cancel() wait(for: [done], timeout: 1.0) + + var iterator = channel.makeAsyncIterator() + let received = await iterator.next() + XCTAssertEqual(received, 2) } - func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async { + func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async throws { let channel = AsyncThrowingChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true @@ -252,8 +261,17 @@ final class TestChannel: XCTestCase { notYetDone.fulfill() done.fulfill() } + + Task { + await channel.send(2) + } + wait(for: [notYetDone], timeout: 0.1) task.cancel() wait(for: [done], timeout: 1.0) + + var iterator = channel.makeAsyncIterator() + let received = try await iterator.next() + XCTAssertEqual(received, 2) } }