Skip to content

Commit d578c06

Browse files
committed
asyncChannel: harmonize send and next cancellation
1 parent e2920bb commit d578c06

File tree

2 files changed

+137
-101
lines changed

2 files changed

+137
-101
lines changed

Sources/AsyncAlgorithms/AsyncChannel.swift

+125-98
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,17 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
3434
guard active else {
3535
return nil
3636
}
37+
3738
let generation = channel.establish()
38-
let value: Element? = await withTaskCancellationHandler { [channel] in
39-
channel.cancel(generation)
39+
let nextTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)
40+
41+
let value = await withTaskCancellationHandler { [channel] in
42+
channel.cancelNext(nextTokenStatus, generation)
4043
} operation: {
41-
await channel.next(generation)
44+
await channel.next(nextTokenStatus, generation)
4245
}
43-
44-
if let value = value {
46+
47+
if let value {
4548
return value
4649
} else {
4750
active = false
@@ -56,24 +59,15 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
5659
struct ChannelToken<Continuation>: Hashable {
5760
var generation: Int
5861
var continuation: Continuation?
59-
let cancelled: Bool
6062

6163
init(generation: Int, continuation: Continuation) {
6264
self.generation = generation
6365
self.continuation = continuation
64-
cancelled = false
6566
}
6667

6768
init(placeholder generation: Int) {
6869
self.generation = generation
6970
self.continuation = nil
70-
cancelled = false
71-
}
72-
73-
init(cancelled generation: Int) {
74-
self.generation = generation
75-
self.continuation = nil
76-
cancelled = true
7771
}
7872

7973
func hash(into hasher: inout Hasher) {
@@ -84,37 +78,24 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
8478
return lhs.generation == rhs.generation
8579
}
8680
}
81+
82+
enum ChannelTokenStatus: Equatable {
83+
case new
84+
case cancelled
85+
}
8786

8887
enum Emission {
8988
case idle
90-
case pending([Pending])
89+
case pending(Set<Pending>)
9190
case awaiting(Set<Awaiting>)
92-
93-
mutating func cancel(_ generation: Int) -> UnsafeContinuation<Element?, Never>? {
94-
switch self {
95-
case .awaiting(var awaiting):
96-
let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation
97-
if awaiting.isEmpty {
98-
self = .idle
99-
} else {
100-
self = .awaiting(awaiting)
101-
}
102-
return continuation
103-
case .idle:
104-
self = .awaiting([Awaiting(cancelled: generation)])
105-
return nil
106-
default:
107-
return nil
108-
}
109-
}
11091
}
11192

11293
struct State {
11394
var emission: Emission = .idle
11495
var generation = 0
11596
var terminal = false
11697
}
117-
98+
11899
let state = ManagedCriticalState(State())
119100

120101
/// Create a new `AsyncChannel` given an element type.
@@ -126,18 +107,44 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
126107
return state.generation
127108
}
128109
}
129-
130-
func cancel(_ generation: Int) {
131-
state.withCriticalRegion { state in
132-
state.emission.cancel(generation)
110+
111+
func cancelNext(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
112+
state.withCriticalRegion { state -> UnsafeContinuation<Element?, Never>? in
113+
let continuation: UnsafeContinuation<Element?, Never>?
114+
115+
switch state.emission {
116+
case .awaiting(var nexts):
117+
continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation
118+
if nexts.isEmpty {
119+
state.emission = .idle
120+
} else {
121+
state.emission = .awaiting(nexts)
122+
}
123+
default:
124+
continuation = nil
125+
}
126+
127+
nextTokenStatus.withCriticalRegion { status in
128+
if status == .new {
129+
status = .cancelled
130+
}
131+
}
132+
133+
return continuation
133134
}?.resume(returning: nil)
134135
}
135-
136-
func next(_ generation: Int) async -> Element? {
136+
137+
func next(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) async -> Element? {
137138
return await withUnsafeContinuation { (continuation: UnsafeContinuation<Element?, Never>) in
138139
var cancelled = false
139140
var terminal = false
140141
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
142+
143+
if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
144+
cancelled = true
145+
return nil
146+
}
147+
141148
if state.terminal {
142149
terminal = true
143150
return nil
@@ -155,26 +162,93 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
155162
}
156163
return UnsafeResumption(continuation: send.continuation, success: continuation)
157164
case .awaiting(var nexts):
158-
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
159-
nexts.remove(Awaiting(placeholder: generation))
160-
cancelled = true
161-
}
162-
if nexts.isEmpty {
163-
state.emission = .idle
164-
} else {
165-
state.emission = .awaiting(nexts)
166-
}
165+
nexts.update(with: Awaiting(generation: generation, continuation: continuation))
166+
state.emission = .awaiting(nexts)
167167
return nil
168168
}
169169
}?.resume()
170+
170171
if cancelled || terminal {
171172
continuation.resume(returning: nil)
172173
}
173174
}
174175
}
176+
177+
func cancelSend(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
178+
state.withCriticalRegion { state -> UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>? in
179+
let continuation: UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>?
180+
181+
switch state.emission {
182+
case .pending(var sends):
183+
let send = sends.remove(Pending(placeholder: generation))
184+
if sends.isEmpty {
185+
state.emission = .idle
186+
} else {
187+
state.emission = .pending(sends)
188+
}
189+
continuation = send?.continuation
190+
default:
191+
continuation = nil
192+
}
193+
194+
sendTokenStatus.withCriticalRegion { status in
195+
if status == .new {
196+
status = .cancelled
197+
}
198+
}
199+
200+
return continuation
201+
}?.resume(returning: nil)
202+
}
203+
204+
func send(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int, _ element: Element) async {
205+
let continuation = await withUnsafeContinuation { continuation in
206+
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
207+
208+
if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled || state.terminal {
209+
return UnsafeResumption(continuation: continuation, success: nil)
210+
}
211+
212+
switch state.emission {
213+
case .idle:
214+
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
215+
return nil
216+
case .pending(var sends):
217+
sends.update(with: Pending(generation: generation, continuation: continuation))
218+
state.emission = .pending(sends)
219+
return nil
220+
case .awaiting(var nexts):
221+
let next = nexts.removeFirst().continuation
222+
if nexts.count == 0 {
223+
state.emission = .idle
224+
} else {
225+
state.emission = .awaiting(nexts)
226+
}
227+
return UnsafeResumption(continuation: continuation, success: next)
228+
}
229+
}?.resume()
230+
}
231+
continuation?.resume(returning: element)
232+
}
233+
234+
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
235+
/// or when a call to `finish()` is made from another Task.
236+
/// If the channel is already finished then this returns immediately
237+
public func send(_ element: Element) async {
238+
let generation = establish()
239+
let sendTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)
240+
241+
await withTaskCancellationHandler { [weak self] in
242+
self?.cancelSend(sendTokenStatus, generation)
243+
} operation: {
244+
await send(sendTokenStatus, generation, element)
245+
}
246+
}
175247

176-
func terminateAll() {
177-
let (sends, nexts) = state.withCriticalRegion { state -> ([Pending], Set<Awaiting>) in
248+
/// Send a finish to all awaiting iterations.
249+
/// All subsequent calls to `next(_:)` will resume immediately.
250+
public func finish() {
251+
let (sends, nexts) = state.withCriticalRegion { state -> (Set<Pending>, Set<Awaiting>) in
178252
if state.terminal {
179253
return ([], [])
180254
}
@@ -198,53 +272,6 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
198272
}
199273
}
200274

201-
func _send(_ element: Element) async {
202-
let generation = establish()
203-
204-
await withTaskCancellationHandler {
205-
terminateAll()
206-
} operation: {
207-
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
208-
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
209-
if state.terminal {
210-
return UnsafeResumption(continuation: continuation, success: nil)
211-
}
212-
switch state.emission {
213-
case .idle:
214-
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
215-
return nil
216-
case .pending(var sends):
217-
sends.append(Pending(generation: generation, continuation: continuation))
218-
state.emission = .pending(sends)
219-
return nil
220-
case .awaiting(var nexts):
221-
let next = nexts.removeFirst().continuation
222-
if nexts.count == 0 {
223-
state.emission = .idle
224-
} else {
225-
state.emission = .awaiting(nexts)
226-
}
227-
return UnsafeResumption(continuation: continuation, success: next)
228-
}
229-
}?.resume()
230-
}
231-
continuation?.resume(returning: element)
232-
}
233-
}
234-
235-
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
236-
/// or when a call to `finish()` is made from another Task.
237-
/// If the channel is already finished then this returns immediately
238-
public func send(_ element: Element) async {
239-
await _send(element)
240-
}
241-
242-
/// Send a finish to all awaiting iterations.
243-
/// All subsequent calls to `next(_:)` will resume immediately.
244-
public func finish() {
245-
terminateAll()
246-
}
247-
248275
/// Create an `Iterator` for iteration of an `AsyncChannel`
249276
public func makeAsyncIterator() -> Iterator {
250277
return Iterator(self)

Tests/AsyncAlgorithmsTests/TestChannel.swift

+12-3
Original file line numberDiff line numberDiff line change
@@ -227,19 +227,28 @@ final class TestChannel: XCTestCase {
227227
XCTAssertNil(value)
228228
}
229229

230-
func test_asyncChannel_resumes_send_when_task_is_cancelled() async {
230+
func test_asyncChannel_resumes_send_when_task_is_cancelled_and_continue_remaining_send_tasks() async {
231231
let channel = AsyncChannel<Int>()
232232
let notYetDone = expectation(description: "not yet done")
233233
notYetDone.isInverted = true
234234
let done = expectation(description: "done")
235-
let task = Task {
235+
let task1 = Task {
236236
await channel.send(1)
237237
notYetDone.fulfill()
238238
done.fulfill()
239239
}
240+
241+
Task {
242+
await channel.send(2)
243+
}
244+
240245
wait(for: [notYetDone], timeout: 0.1)
241-
task.cancel()
246+
task1.cancel()
242247
wait(for: [done], timeout: 1.0)
248+
249+
var iterator = channel.makeAsyncIterator()
250+
let received = await iterator.next()
251+
XCTAssertEqual(received, 2)
243252
}
244253

245254
func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async {

0 commit comments

Comments
 (0)