From e2920bb18d2a25373ac260aac82b84e1fe35c6f8 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Thu, 28 Jul 2022 13:10:44 +0200 Subject: [PATCH 1/5] asyncChannel: introduce ChannelToken to model Pending and Awaiting --- Sources/AsyncAlgorithms/AsyncChannel.swift | 37 +++++++++++-------- .../AsyncAlgorithms/UnsafeResumption.swift | 10 ++--- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index facdaadf..57c299bd 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -50,41 +50,44 @@ public final class AsyncChannel: AsyncSequence, Sendable { } } - struct Awaiting: Hashable { + typealias Pending = ChannelToken?, Never>> + typealias Awaiting = ChannelToken> + + struct ChannelToken: Hashable { var generation: Int - var continuation: UnsafeContinuation? + var continuation: Continuation? let cancelled: Bool - - init(generation: Int, continuation: UnsafeContinuation) { + + init(generation: Int, continuation: Continuation) { self.generation = generation self.continuation = continuation cancelled = false } - + init(placeholder generation: Int) { self.generation = generation self.continuation = nil cancelled = false } - + init(cancelled generation: Int) { self.generation = generation self.continuation = nil cancelled = true } - + func hash(into hasher: inout Hasher) { hasher.combine(generation) } - - static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool { + + static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { return lhs.generation == rhs.generation } } enum Emission { case idle - case pending([UnsafeContinuation?, Never>]) + case pending([Pending]) case awaiting(Set) mutating func cancel(_ generation: Int) -> UnsafeContinuation? { @@ -131,7 +134,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } func next(_ generation: Int) async -> Element? { - return await withUnsafeContinuation { continuation in + return await withUnsafeContinuation { (continuation: UnsafeContinuation) in var cancelled = false var terminal = false state.withCriticalRegion { state -> UnsafeResumption?, Never>? in @@ -150,7 +153,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } else { state.emission = .pending(sends) } - return UnsafeResumption(continuation: send, success: continuation) + return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { nexts.remove(Awaiting(placeholder: generation)) @@ -171,7 +174,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } func terminateAll() { - let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in + let (sends, nexts) = state.withCriticalRegion { state -> ([Pending], Set) in if state.terminal { return ([], []) } @@ -188,7 +191,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } } for send in sends { - send.resume(returning: nil) + send.continuation?.resume(returning: nil) } for next in nexts { next.continuation?.resume(returning: nil) @@ -196,6 +199,8 @@ public final class AsyncChannel: AsyncSequence, Sendable { } func _send(_ element: Element) async { + let generation = establish() + await withTaskCancellationHandler { terminateAll() } operation: { @@ -206,10 +211,10 @@ public final class AsyncChannel: AsyncSequence, Sendable { } switch state.emission { case .idle: - state.emission = .pending([continuation]) + state.emission = .pending([Pending(generation: generation, continuation: continuation)]) return nil case .pending(var sends): - sends.append(continuation) + sends.append(Pending(generation: generation, continuation: continuation)) state.emission = .pending(sends) return nil case .awaiting(var nexts): diff --git a/Sources/AsyncAlgorithms/UnsafeResumption.swift b/Sources/AsyncAlgorithms/UnsafeResumption.swift index 9eb28c5f..d87987ac 100644 --- a/Sources/AsyncAlgorithms/UnsafeResumption.swift +++ b/Sources/AsyncAlgorithms/UnsafeResumption.swift @@ -1,22 +1,22 @@ struct UnsafeResumption { - let continuation: UnsafeContinuation + let continuation: UnsafeContinuation? let result: Result - init(continuation: UnsafeContinuation, result: Result) { + init(continuation: UnsafeContinuation?, result: Result) { self.continuation = continuation self.result = result } - init(continuation: UnsafeContinuation, success: Success) { + init(continuation: UnsafeContinuation?, success: Success) { self.init(continuation: continuation, result: .success(success)) } - init(continuation: UnsafeContinuation, failure: Failure) { + init(continuation: UnsafeContinuation?, failure: Failure) { self.init(continuation: continuation, result: .failure(failure)) } func resume() { - continuation.resume(with: result) + continuation?.resume(with: result) } } From d578c0681b015bb0f922551f5aa98f475bb4cd57 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Thu, 28 Jul 2022 17:16:26 +0200 Subject: [PATCH 2/5] asyncChannel: harmonize send and next cancellation --- Sources/AsyncAlgorithms/AsyncChannel.swift | 223 +++++++++++-------- Tests/AsyncAlgorithmsTests/TestChannel.swift | 15 +- 2 files changed, 137 insertions(+), 101 deletions(-) diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index 57c299bd..29a0bce3 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -34,14 +34,17 @@ public final class AsyncChannel: AsyncSequence, Sendable { guard active else { return nil } + let generation = channel.establish() - let value: Element? = await withTaskCancellationHandler { [channel] in - channel.cancel(generation) + let nextTokenStatus = ManagedCriticalState(.new) + + let value = await withTaskCancellationHandler { [channel] in + channel.cancelNext(nextTokenStatus, generation) } operation: { - await channel.next(generation) + await channel.next(nextTokenStatus, generation) } - - if let value = value { + + if let value { return value } else { active = false @@ -56,24 +59,15 @@ public final class AsyncChannel: AsyncSequence, Sendable { struct ChannelToken: Hashable { var generation: Int var continuation: Continuation? - let cancelled: Bool init(generation: Int, continuation: Continuation) { self.generation = generation self.continuation = continuation - cancelled = false } init(placeholder generation: Int) { self.generation = generation self.continuation = nil - cancelled = false - } - - init(cancelled generation: Int) { - self.generation = generation - self.continuation = nil - cancelled = true } func hash(into hasher: inout Hasher) { @@ -84,29 +78,16 @@ public final class AsyncChannel: AsyncSequence, Sendable { return lhs.generation == rhs.generation } } + + enum ChannelTokenStatus: Equatable { + case new + case cancelled + } enum Emission { case idle - case pending([Pending]) + case pending(Set) case awaiting(Set) - - mutating func cancel(_ generation: Int) -> UnsafeContinuation? { - switch self { - case .awaiting(var awaiting): - let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation - if awaiting.isEmpty { - self = .idle - } else { - self = .awaiting(awaiting) - } - return continuation - case .idle: - self = .awaiting([Awaiting(cancelled: generation)]) - return nil - default: - return nil - } - } } struct State { @@ -114,7 +95,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { var generation = 0 var terminal = false } - + let state = ManagedCriticalState(State()) /// Create a new `AsyncChannel` given an element type. @@ -126,18 +107,44 @@ public final class AsyncChannel: AsyncSequence, Sendable { return state.generation } } - - func cancel(_ generation: Int) { - state.withCriticalRegion { state in - state.emission.cancel(generation) + + func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation? in + let continuation: UnsafeContinuation? + + switch state.emission { + case .awaiting(var nexts): + continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation + if nexts.isEmpty { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + default: + continuation = nil + } + + nextTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation }?.resume(returning: nil) } - - func next(_ generation: Int) async -> Element? { + + func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async -> Element? { return await withUnsafeContinuation { (continuation: UnsafeContinuation) in var cancelled = false var terminal = false state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + cancelled = true + return nil + } + if state.terminal { terminal = true return nil @@ -155,26 +162,93 @@ public final class AsyncChannel: AsyncSequence, Sendable { } return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { - nexts.remove(Awaiting(placeholder: generation)) - cancelled = true - } - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } + nexts.update(with: Awaiting(generation: generation, continuation: continuation)) + state.emission = .awaiting(nexts) return nil } }?.resume() + if cancelled || terminal { continuation.resume(returning: nil) } } } + + func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation?, Never>? in + let continuation: UnsafeContinuation?, Never>? + + switch state.emission { + case .pending(var sends): + let send = sends.remove(Pending(placeholder: generation)) + if sends.isEmpty { + state.emission = .idle + } else { + state.emission = .pending(sends) + } + continuation = send?.continuation + default: + continuation = nil + } + + sendTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation + }?.resume(returning: nil) + } + + func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { + let continuation = await withUnsafeContinuation { continuation in + state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled || state.terminal { + return UnsafeResumption(continuation: continuation, success: nil) + } + + switch state.emission { + case .idle: + state.emission = .pending([Pending(generation: generation, continuation: continuation)]) + return nil + case .pending(var sends): + sends.update(with: Pending(generation: generation, continuation: continuation)) + state.emission = .pending(sends) + return nil + case .awaiting(var nexts): + let next = nexts.removeFirst().continuation + if nexts.count == 0 { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + return UnsafeResumption(continuation: continuation, success: next) + } + }?.resume() + } + continuation?.resume(returning: element) + } + + /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made + /// or when a call to `finish()` is made from another Task. + /// If the channel is already finished then this returns immediately + public func send(_ element: Element) async { + let generation = establish() + let sendTokenStatus = ManagedCriticalState(.new) + + await withTaskCancellationHandler { [weak self] in + self?.cancelSend(sendTokenStatus, generation) + } operation: { + await send(sendTokenStatus, generation, element) + } + } - func terminateAll() { - let (sends, nexts) = state.withCriticalRegion { state -> ([Pending], Set) in + /// Send a finish to all awaiting iterations. + /// All subsequent calls to `next(_:)` will resume immediately. + public func finish() { + let (sends, nexts) = state.withCriticalRegion { state -> (Set, Set) in if state.terminal { return ([], []) } @@ -198,53 +272,6 @@ public final class AsyncChannel: AsyncSequence, Sendable { } } - func _send(_ element: Element) async { - let generation = establish() - - await withTaskCancellationHandler { - terminateAll() - } operation: { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if state.terminal { - return UnsafeResumption(continuation: continuation, success: nil) - } - switch state.emission { - case .idle: - state.emission = .pending([Pending(generation: generation, continuation: continuation)]) - return nil - case .pending(var sends): - sends.append(Pending(generation: generation, continuation: continuation)) - state.emission = .pending(sends) - return nil - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - return UnsafeResumption(continuation: continuation, success: next) - } - }?.resume() - } - continuation?.resume(returning: element) - } - } - - /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made - /// or when a call to `finish()` is made from another Task. - /// If the channel is already finished then this returns immediately - public func send(_ element: Element) async { - await _send(element) - } - - /// Send a finish to all awaiting iterations. - /// All subsequent calls to `next(_:)` will resume immediately. - public func finish() { - terminateAll() - } - /// Create an `Iterator` for iteration of an `AsyncChannel` public func makeAsyncIterator() -> Iterator { return Iterator(self) diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index 66fd1e1d..b6c34df0 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -227,19 +227,28 @@ final class TestChannel: XCTestCase { XCTAssertNil(value) } - func test_asyncChannel_resumes_send_when_task_is_cancelled() async { + func test_asyncChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async { let channel = AsyncChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true let done = expectation(description: "done") - let task = Task { + let task1 = Task { await channel.send(1) notYetDone.fulfill() done.fulfill() } + + Task { + await channel.send(2) + } + wait(for: [notYetDone], timeout: 0.1) - task.cancel() + task1.cancel() wait(for: [done], timeout: 1.0) + + var iterator = channel.makeAsyncIterator() + let received = await iterator.next() + XCTAssertEqual(received, 2) } func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async { From 0aaa8131cdb19b4a45cea337885eeb3b9ea113d7 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Thu, 28 Jul 2022 17:22:29 +0200 Subject: [PATCH 3/5] asyncChannel: add .finished as an emission state --- Sources/AsyncAlgorithms/AsyncChannel.swift | 33 ++++++++++++---------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index 29a0bce3..1719a788 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -88,12 +88,12 @@ public final class AsyncChannel: AsyncSequence, Sendable { case idle case pending(Set) case awaiting(Set) + case finished } struct State { var emission: Emission = .idle var generation = 0 - var terminal = false } let state = ManagedCriticalState(State()) @@ -145,10 +145,6 @@ public final class AsyncChannel: AsyncSequence, Sendable { return nil } - if state.terminal { - terminal = true - return nil - } switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -165,6 +161,9 @@ public final class AsyncChannel: AsyncSequence, Sendable { nexts.update(with: Awaiting(generation: generation, continuation: continuation)) state.emission = .awaiting(nexts) return nil + case .finished: + terminal = true + return nil } }?.resume() @@ -205,7 +204,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { let continuation = await withUnsafeContinuation { continuation in state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled || state.terminal { + if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { return UnsafeResumption(continuation: continuation, success: nil) } @@ -225,6 +224,8 @@ public final class AsyncChannel: AsyncSequence, Sendable { state.emission = .awaiting(nexts) } return UnsafeResumption(continuation: continuation, success: next) + case .finished: + return UnsafeResumption(continuation: continuation, success: nil) } }?.resume() } @@ -249,20 +250,22 @@ public final class AsyncChannel: AsyncSequence, Sendable { /// All subsequent calls to `next(_:)` will resume immediately. public func finish() { let (sends, nexts) = state.withCriticalRegion { state -> (Set, Set) in - if state.terminal { - return ([], []) - } - state.terminal = true + let result: (Set, Set) + switch state.emission { case .idle: - return ([], []) + result = ([], []) case .pending(let nexts): - state.emission = .idle - return (nexts, []) + result = (nexts, []) case .awaiting(let nexts): - state.emission = .idle - return ([], nexts) + result = ([], nexts) + case .finished: + result = ([], []) } + + state.emission = .finished + + return result } for send in sends { send.continuation?.resume(returning: nil) From bdd7bc7273a3b79bc088566505cbb63c8856d112 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Thu, 28 Jul 2022 17:43:50 +0200 Subject: [PATCH 4/5] asyncThrowingChannel: harmonize send and next cancellation --- .../AsyncThrowingChannel.swift | 225 ++++++++++-------- Tests/AsyncAlgorithmsTests/TestChannel.swift | 15 +- 2 files changed, 141 insertions(+), 99 deletions(-) diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift index 5ce68961..a4a963bc 100644 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift @@ -32,12 +32,15 @@ public final class AsyncThrowingChannel: Asyn guard active else { return nil } + let generation = channel.establish() + let nextTokenStatus = ManagedCriticalState(.new) + do { - let value: Element? = try await withTaskCancellationHandler { [channel] in - channel.cancel(generation) + let value = try await withTaskCancellationHandler { [channel] in + channel.cancelNext(nextTokenStatus, generation) } operation: { - try await channel.next(generation) + try await channel.next(nextTokenStatus, generation) } if let value = value { @@ -52,39 +55,39 @@ public final class AsyncThrowingChannel: Asyn } } } - - struct Awaiting: Hashable { + + typealias Pending = ChannelToken?, Never>> + typealias Awaiting = ChannelToken> + + struct ChannelToken: Hashable { var generation: Int - var continuation: UnsafeContinuation? - let cancelled: Bool - - init(generation: Int, continuation: UnsafeContinuation) { + var continuation: Continuation? + + init(generation: Int, continuation: Continuation) { self.generation = generation self.continuation = continuation - cancelled = false } - + init(placeholder generation: Int) { self.generation = generation self.continuation = nil - cancelled = false } - - init(cancelled generation: Int) { - self.generation = generation - self.continuation = nil - cancelled = true - } - + func hash(into hasher: inout Hasher) { hasher.combine(generation) } - - static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool { + + static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool { return lhs.generation == rhs.generation } } + + enum ChannelTokenStatus: Equatable { + case new + case cancelled + } + enum Termination { case finished case failed(Error) @@ -92,32 +95,9 @@ public final class AsyncThrowingChannel: Asyn enum Emission { case idle - case pending([UnsafeContinuation?, Never>]) + case pending(Set) case awaiting(Set) case terminated(Termination) - - var isTerminated: Bool { - guard case .terminated = self else { return false } - return true - } - - mutating func cancel(_ generation: Int) -> UnsafeContinuation? { - switch self { - case .awaiting(var awaiting): - let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation - if awaiting.isEmpty { - self = .idle - } else { - self = .awaiting(awaiting) - } - return continuation - case .idle: - self = .awaiting([Awaiting(cancelled: generation)]) - return nil - default: - return nil - } - } } struct State { @@ -135,19 +115,45 @@ public final class AsyncThrowingChannel: Asyn return state.generation } } - - func cancel(_ generation: Int) { - state.withCriticalRegion { state in - state.emission.cancel(generation) + + func cancelNext(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation? in + let continuation: UnsafeContinuation? + + switch state.emission { + case .awaiting(var nexts): + continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation + if nexts.isEmpty { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + default: + continuation = nil + } + + nextTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation }?.resume(returning: nil) } - func next(_ generation: Int) async throws -> Element? { - return try await withUnsafeThrowingContinuation { continuation in + func next(_ nextTokenStatus: ManagedCriticalState, _ generation: Int) async throws -> Element? { + return try await withUnsafeThrowingContinuation { (continuation: UnsafeContinuation) in var cancelled = false var potentialTermination: Termination? state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + cancelled = true + return nil + } + switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -159,17 +165,10 @@ public final class AsyncThrowingChannel: Asyn } else { state.emission = .pending(sends) } - return UnsafeResumption(continuation: send, success: continuation) + return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { - nexts.remove(Awaiting(placeholder: generation)) - cancelled = true - } - if nexts.isEmpty { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } + nexts.update(with: Awaiting(generation: generation, continuation: continuation)) + state.emission = .awaiting(nexts) return nil case .terminated(let termination): potentialTermination = termination @@ -196,8 +195,67 @@ public final class AsyncThrowingChannel: Asyn } } + func cancelSend(_ sendTokenStatus: ManagedCriticalState, _ generation: Int) { + state.withCriticalRegion { state -> UnsafeContinuation?, Never>? in + let continuation: UnsafeContinuation?, Never>? + + switch state.emission { + case .pending(var sends): + let send = sends.remove(Pending(placeholder: generation)) + if sends.isEmpty { + state.emission = .idle + } else { + state.emission = .pending(sends) + } + continuation = send?.continuation + default: + continuation = nil + } + + sendTokenStatus.withCriticalRegion { status in + if status == .new { + status = .cancelled + } + } + + return continuation + }?.resume(returning: nil) + } + + func send(_ sendTokenStatus: ManagedCriticalState, _ generation: Int, _ element: Element) async { + let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in + state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + + if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled { + return UnsafeResumption(continuation: continuation, success: nil) + } + + switch state.emission { + case .idle: + state.emission = .pending([Pending(generation: generation, continuation: continuation)]) + return nil + case .pending(var sends): + sends.update(with: Pending(generation: generation, continuation: continuation)) + state.emission = .pending(sends) + return nil + case .awaiting(var nexts): + let next = nexts.removeFirst().continuation + if nexts.count == 0 { + state.emission = .idle + } else { + state.emission = .awaiting(nexts) + } + return UnsafeResumption(continuation: continuation, success: next) + case .terminated: + return UnsafeResumption(continuation: continuation, success: nil) + } + }?.resume() + } + continuation?.resume(returning: element) + } + func terminateAll(error: Failure? = nil) { - let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in + let (sends, nexts) = state.withCriticalRegion { state -> (Set, Set) in let nextState: Emission if let error = error { @@ -222,7 +280,7 @@ public final class AsyncThrowingChannel: Asyn } for send in sends { - send.resume(returning: nil) + send.continuation?.resume(returning: nil) } if let error = error { @@ -234,45 +292,20 @@ public final class AsyncThrowingChannel: Asyn next.continuation?.resume(returning: nil) } } - - } - - func _send(_ element: Element) async { - await withTaskCancellationHandler { - terminateAll() - } operation: { - let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in - state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - switch state.emission { - case .idle: - state.emission = .pending([continuation]) - return nil - case .pending(var sends): - sends.append(continuation) - state.emission = .pending(sends) - return nil - case .awaiting(var nexts): - let next = nexts.removeFirst().continuation - if nexts.count == 0 { - state.emission = .idle - } else { - state.emission = .awaiting(nexts) - } - return UnsafeResumption(continuation: continuation, success: next) - case .terminated: - return UnsafeResumption(continuation: continuation, success: nil) - } - }?.resume() - } - continuation?.resume(returning: element) - } } /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made /// or when a call to `finish()`/`fail(_:)` is made from another Task. /// If the channel is already finished then this returns immediately public func send(_ element: Element) async { - await _send(element) + let generation = establish() + let sendTokenStatus = ManagedCriticalState(.new) + + await withTaskCancellationHandler { [weak self] in + self?.cancelSend(sendTokenStatus, generation) + } operation: { + await send(sendTokenStatus, generation, element) + } } /// Send an error to all awaiting iterations. diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index b6c34df0..2d8797fa 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -232,7 +232,7 @@ final class TestChannel: XCTestCase { let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true let done = expectation(description: "done") - let task1 = Task { + let task = Task { await channel.send(1) notYetDone.fulfill() done.fulfill() @@ -243,7 +243,7 @@ final class TestChannel: XCTestCase { } wait(for: [notYetDone], timeout: 0.1) - task1.cancel() + task.cancel() wait(for: [done], timeout: 1.0) var iterator = channel.makeAsyncIterator() @@ -251,7 +251,7 @@ final class TestChannel: XCTestCase { XCTAssertEqual(received, 2) } - func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async { + func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async throws { let channel = AsyncThrowingChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true @@ -261,8 +261,17 @@ final class TestChannel: XCTestCase { notYetDone.fulfill() done.fulfill() } + + Task { + await channel.send(2) + } + wait(for: [notYetDone], timeout: 0.1) task.cancel() wait(for: [done], timeout: 1.0) + + var iterator = channel.makeAsyncIterator() + let received = try await iterator.next() + XCTAssertEqual(received, 2) } } From be52a2d522cb0e175f90964d8025fca988909528 Mon Sep 17 00:00:00 2001 From: Thibault Wittemberg Date: Thu, 28 Jul 2022 17:49:54 +0200 Subject: [PATCH 5/5] channel: update documentation for cancellation --- Package.swift | 7 +++++-- .../AsyncAlgorithms.docc/Guides/Channel.md | 2 +- Sources/AsyncAlgorithms/AsyncChannel.swift | 15 +++++++++------ .../AsyncAlgorithms/AsyncThrowingChannel.swift | 13 ++++++++----- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/Package.swift b/Package.swift index c43747c0..4f716a85 100644 --- a/Package.swift +++ b/Package.swift @@ -16,9 +16,12 @@ let package = Package( .library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]), .library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]), ], - dependencies: [], + dependencies: [.package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3"))], targets: [ - .target(name: "AsyncAlgorithms"), + .target( + name: "AsyncAlgorithms", + dependencies: [.product(name: "Collections", package: "swift-collections")] + ), .target( name: "AsyncSequenceValidation", dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]), diff --git a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md index 5121c769..19eed41c 100644 --- a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md +++ b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md @@ -51,7 +51,7 @@ public final class AsyncThrowingChannel: Asyn } ``` -Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. +Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. The calls to `send(:)` and `next()` will immediately resume when their supporting task is cancelled, other operations from other tasks will remain active. ```swift let channel = AsyncChannel() diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index 1719a788..d53a7bac 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -9,6 +9,8 @@ // //===----------------------------------------------------------------------===// +import OrderedCollections + /// A channel for sending elements from one task to another with back pressure. /// /// The `AsyncChannel` class is intended to be used as a communication type between tasks, @@ -86,8 +88,8 @@ public final class AsyncChannel: AsyncSequence, Sendable { enum Emission { case idle - case pending(Set) - case awaiting(Set) + case pending(OrderedSet) + case awaiting(OrderedSet) case finished } @@ -158,7 +160,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - nexts.update(with: Awaiting(generation: generation, continuation: continuation)) + nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) state.emission = .awaiting(nexts) return nil case .finished: @@ -213,7 +215,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { state.emission = .pending([Pending(generation: generation, continuation: continuation)]) return nil case .pending(var sends): - sends.update(with: Pending(generation: generation, continuation: continuation)) + sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) state.emission = .pending(sends) return nil case .awaiting(var nexts): @@ -235,6 +237,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made /// or when a call to `finish()` is made from another Task. /// If the channel is already finished then this returns immediately + /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. public func send(_ element: Element) async { let generation = establish() let sendTokenStatus = ManagedCriticalState(.new) @@ -249,8 +252,8 @@ public final class AsyncChannel: AsyncSequence, Sendable { /// Send a finish to all awaiting iterations. /// All subsequent calls to `next(_:)` will resume immediately. public func finish() { - let (sends, nexts) = state.withCriticalRegion { state -> (Set, Set) in - let result: (Set, Set) + let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet, OrderedSet) in + let result: (OrderedSet, OrderedSet) switch state.emission { case .idle: diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift index a4a963bc..473482ea 100644 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift @@ -9,6 +9,8 @@ // //===----------------------------------------------------------------------===// +import OrderedCollections + /// An error-throwing channel for sending elements from on task to another with back pressure. /// /// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks, @@ -95,8 +97,8 @@ public final class AsyncThrowingChannel: Asyn enum Emission { case idle - case pending(Set) - case awaiting(Set) + case pending(OrderedSet) + case awaiting(OrderedSet) case terminated(Termination) } @@ -167,7 +169,7 @@ public final class AsyncThrowingChannel: Asyn } return UnsafeResumption(continuation: send.continuation, success: continuation) case .awaiting(var nexts): - nexts.update(with: Awaiting(generation: generation, continuation: continuation)) + nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation)) state.emission = .awaiting(nexts) return nil case .terminated(let termination): @@ -235,7 +237,7 @@ public final class AsyncThrowingChannel: Asyn state.emission = .pending([Pending(generation: generation, continuation: continuation)]) return nil case .pending(var sends): - sends.update(with: Pending(generation: generation, continuation: continuation)) + sends.updateOrAppend(Pending(generation: generation, continuation: continuation)) state.emission = .pending(sends) return nil case .awaiting(var nexts): @@ -255,7 +257,7 @@ public final class AsyncThrowingChannel: Asyn } func terminateAll(error: Failure? = nil) { - let (sends, nexts) = state.withCriticalRegion { state -> (Set, Set) in + let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet, OrderedSet) in let nextState: Emission if let error = error { @@ -297,6 +299,7 @@ public final class AsyncThrowingChannel: Asyn /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made /// or when a call to `finish()`/`fail(_:)` is made from another Task. /// If the channel is already finished then this returns immediately + /// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active. public func send(_ element: Element) async { let generation = establish() let sendTokenStatus = ManagedCriticalState(.new)