Skip to content

Commit f368bed

Browse files
Thibault Wittembergtwittemb
authored andcommitted
asyncThrowingChannel: enforce termination for all when finished
1 parent ee69df1 commit f368bed

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

Sources/AsyncAlgorithms/AsyncThrowingChannel.swift

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,12 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
129129
func next(_ generation: Int) async throws -> Element? {
130130
return try await withUnsafeThrowingContinuation { continuation in
131131
var cancelled = false
132+
var isTerminal = false
132133
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
134+
if state.terminal {
135+
isTerminal = true
136+
return nil
137+
}
133138
switch state.emission {
134139
case .idle:
135140
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
@@ -155,13 +160,13 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
155160
return nil
156161
}
157162
}?.resume()
158-
if cancelled {
163+
if cancelled || isTerminal {
159164
continuation.resume(returning: nil)
160165
}
161166
}
162167
}
163168

164-
func cancelSend() {
169+
func finishAll() {
165170
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>], Set<Awaiting>) in
166171
if state.terminal {
167172
return ([], [])
@@ -186,23 +191,20 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
186191
}
187192
}
188193

189-
func _send(_ result: Result<Element?, Error>) async {
194+
func _send(_ result: Result<Element, Error>) async {
190195
await withTaskCancellationHandler {
191-
cancelSend()
196+
finishAll()
192197
} operation: {
193198
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
194199
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
195200
if state.terminal {
196201
return UnsafeResumption(continuation: continuation, success: nil)
197202
}
198-
switch result {
199-
case .success(let value):
200-
if value == nil {
201-
state.terminal = true
202-
}
203-
case .failure:
203+
204+
if case .failure = result {
204205
state.terminal = true
205206
}
207+
206208
switch state.emission {
207209
case .idle:
208210
state.emission = .pending([continuation])
@@ -222,7 +224,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
222224
}
223225
}?.resume()
224226
}
225-
continuation?.resume(with: result)
227+
continuation?.resume(with: result.map { $0 as Element? })
226228
}
227229
}
228230

@@ -238,10 +240,9 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
238240
await _send(.failure(error))
239241
}
240242

241-
/// Send a finish to an awaiting iteration. This function will resume when the next call to `next()` is made.
242-
/// If the channel is already finished then this returns immediately
243-
public func finish() async {
244-
await _send(.success(nil))
243+
/// Send a finish to all awaiting iterations.
244+
public func finish() {
245+
finishAll()
245246
}
246247

247248
public func makeAsyncIterator() -> Iterator {

0 commit comments

Comments
 (0)