Skip to content

[Channel] improve send cancellation #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ let package = Package(
.library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]),
.library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]),
],
dependencies: [],
dependencies: [.package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3"))],
targets: [
.target(name: "AsyncAlgorithms"),
.target(
name: "AsyncAlgorithms",
dependencies: [.product(name: "Collections", package: "swift-collections")]
),
.target(
name: "AsyncSequenceValidation",
dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
}
```

Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume.
Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. The calls to `send(:)` and `next()` will immediately resume when their supporting task is cancelled, other operations from other tasks will remain active.

```swift
let channel = AsyncChannel<String>()
Expand Down
272 changes: 155 additions & 117 deletions Sources/AsyncAlgorithms/AsyncChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
//
//===----------------------------------------------------------------------===//

import OrderedCollections

/// A channel for sending elements from one task to another with back pressure.
///
/// The `AsyncChannel` class is intended to be used as a communication type between tasks,
Expand All @@ -34,14 +36,17 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
guard active else {
return nil
}

let generation = channel.establish()
let value: Element? = await withTaskCancellationHandler { [channel] in
channel.cancel(generation)
let nextTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)

let value = await withTaskCancellationHandler { [channel] in
channel.cancelNext(nextTokenStatus, generation)
} operation: {
await channel.next(generation)
await channel.next(nextTokenStatus, generation)
}
if let value = value {

if let value {
return value
} else {
active = false
Expand All @@ -50,68 +55,49 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
}
}

struct Awaiting: Hashable {
typealias Pending = ChannelToken<UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>>
typealias Awaiting = ChannelToken<UnsafeContinuation<Element?, Never>>

struct ChannelToken<Continuation>: Hashable {
var generation: Int
var continuation: UnsafeContinuation<Element?, Never>?
let cancelled: Bool

init(generation: Int, continuation: UnsafeContinuation<Element?, Never>) {
var continuation: Continuation?

init(generation: Int, continuation: Continuation) {
self.generation = generation
self.continuation = continuation
cancelled = false
}

init(placeholder generation: Int) {
self.generation = generation
self.continuation = nil
cancelled = false
}

init(cancelled generation: Int) {
self.generation = generation
self.continuation = nil
cancelled = true
}


func hash(into hasher: inout Hasher) {
hasher.combine(generation)
}
static func == (_ lhs: Awaiting, _ rhs: Awaiting) -> Bool {

static func == (_ lhs: ChannelToken, _ rhs: ChannelToken) -> Bool {
return lhs.generation == rhs.generation
}
}

enum ChannelTokenStatus: Equatable {
case new
case cancelled
}

enum Emission {
case idle
case pending([UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>])
case awaiting(Set<Awaiting>)

mutating func cancel(_ generation: Int) -> UnsafeContinuation<Element?, Never>? {
switch self {
case .awaiting(var awaiting):
let continuation = awaiting.remove(Awaiting(placeholder: generation))?.continuation
if awaiting.isEmpty {
self = .idle
} else {
self = .awaiting(awaiting)
}
return continuation
case .idle:
self = .awaiting([Awaiting(cancelled: generation)])
return nil
default:
return nil
}
}
case pending(OrderedSet<Pending>)
case awaiting(OrderedSet<Awaiting>)
case finished
}

struct State {
var emission: Emission = .idle
var generation = 0
var terminal = false
}

let state = ManagedCriticalState(State())

/// Create a new `AsyncChannel` given an element type.
Expand All @@ -123,22 +109,44 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
return state.generation
}
}

func cancel(_ generation: Int) {
state.withCriticalRegion { state in
state.emission.cancel(generation)

func cancelNext(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
state.withCriticalRegion { state -> UnsafeContinuation<Element?, Never>? in
let continuation: UnsafeContinuation<Element?, Never>?

switch state.emission {
case .awaiting(var nexts):
continuation = nexts.remove(Awaiting(placeholder: generation))?.continuation
if nexts.isEmpty {
state.emission = .idle
} else {
state.emission = .awaiting(nexts)
}
default:
continuation = nil
}

nextTokenStatus.withCriticalRegion { status in
if status == .new {
status = .cancelled
}
}

return continuation
}?.resume(returning: nil)
}
func next(_ generation: Int) async -> Element? {
return await withUnsafeContinuation { continuation in

func next(_ nextTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) async -> Element? {
return await withUnsafeContinuation { (continuation: UnsafeContinuation<Element?, Never>) in
var cancelled = false
var terminal = false
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
if state.terminal {
terminal = true

if nextTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
cancelled = true
return nil
}

switch state.emission {
case .idle:
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
Expand All @@ -150,94 +158,124 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
} else {
state.emission = .pending(sends)
}
return UnsafeResumption(continuation: send, success: continuation)
return UnsafeResumption(continuation: send.continuation, success: continuation)
case .awaiting(var nexts):
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
nexts.remove(Awaiting(placeholder: generation))
cancelled = true
}
if nexts.isEmpty {
state.emission = .idle
} else {
state.emission = .awaiting(nexts)
}
nexts.updateOrAppend(Awaiting(generation: generation, continuation: continuation))
state.emission = .awaiting(nexts)
return nil
case .finished:
terminal = true
return nil
}
}?.resume()

if cancelled || terminal {
continuation.resume(returning: nil)
}
}
}

func terminateAll() {
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>], Set<Awaiting>) in
if state.terminal {
return ([], [])
}
state.terminal = true

func cancelSend(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int) {
state.withCriticalRegion { state -> UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>? in
let continuation: UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>?

switch state.emission {
case .idle:
return ([], [])
case .pending(let nexts):
state.emission = .idle
return (nexts, [])
case .awaiting(let nexts):
state.emission = .idle
return ([], nexts)
case .pending(var sends):
let send = sends.remove(Pending(placeholder: generation))
if sends.isEmpty {
state.emission = .idle
} else {
state.emission = .pending(sends)
}
continuation = send?.continuation
default:
continuation = nil
}
}
for send in sends {
send.resume(returning: nil)
}
for next in nexts {
next.continuation?.resume(returning: nil)
}

sendTokenStatus.withCriticalRegion { status in
if status == .new {
status = .cancelled
}
}

return continuation
}?.resume(returning: nil)
}

func _send(_ element: Element) async {
await withTaskCancellationHandler {
terminateAll()
} operation: {
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
if state.terminal {
return UnsafeResumption(continuation: continuation, success: nil)
}
switch state.emission {
case .idle:
state.emission = .pending([continuation])
return nil
case .pending(var sends):
sends.append(continuation)
state.emission = .pending(sends)
return nil
case .awaiting(var nexts):
let next = nexts.removeFirst().continuation
if nexts.count == 0 {
state.emission = .idle
} else {
state.emission = .awaiting(nexts)
}
return UnsafeResumption(continuation: continuation, success: next)

func send(_ sendTokenStatus: ManagedCriticalState<ChannelTokenStatus>, _ generation: Int, _ element: Element) async {
let continuation = await withUnsafeContinuation { continuation in
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in

if sendTokenStatus.withCriticalRegion({ $0 }) == .cancelled {
return UnsafeResumption(continuation: continuation, success: nil)
}

switch state.emission {
case .idle:
state.emission = .pending([Pending(generation: generation, continuation: continuation)])
return nil
case .pending(var sends):
sends.updateOrAppend(Pending(generation: generation, continuation: continuation))
state.emission = .pending(sends)
return nil
case .awaiting(var nexts):
let next = nexts.removeFirst().continuation
if nexts.count == 0 {
state.emission = .idle
} else {
state.emission = .awaiting(nexts)
}
}?.resume()
}
continuation?.resume(returning: element)
return UnsafeResumption(continuation: continuation, success: next)
case .finished:
return UnsafeResumption(continuation: continuation, success: nil)
}
}?.resume()
}
continuation?.resume(returning: element)
}

/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
/// or when a call to `finish()` is made from another Task.
/// If the channel is already finished then this returns immediately
/// If the task is cancelled, this function will resume. Other sending operations from other tasks will remain active.
public func send(_ element: Element) async {
await _send(element)
let generation = establish()
let sendTokenStatus = ManagedCriticalState<ChannelTokenStatus>(.new)

await withTaskCancellationHandler { [weak self] in
self?.cancelSend(sendTokenStatus, generation)
} operation: {
await send(sendTokenStatus, generation, element)
}
}

/// Send a finish to all awaiting iterations.
/// All subsequent calls to `next(_:)` will resume immediately.
public func finish() {
terminateAll()
let (sends, nexts) = state.withCriticalRegion { state -> (OrderedSet<Pending>, OrderedSet<Awaiting>) in
let result: (OrderedSet<Pending>, OrderedSet<Awaiting>)

switch state.emission {
case .idle:
result = ([], [])
case .pending(let nexts):
result = (nexts, [])
case .awaiting(let nexts):
result = ([], nexts)
case .finished:
result = ([], [])
}

state.emission = .finished

return result
}
for send in sends {
send.continuation?.resume(returning: nil)
}
for next in nexts {
next.continuation?.resume(returning: nil)
}
}

/// Create an `Iterator` for iteration of an `AsyncChannel`
Expand Down
Loading