@@ -34,14 +34,17 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
34
34
guard active else {
35
35
return nil
36
36
}
37
+
37
38
let generation = channel. establish ( )
38
- let value : Element ? = await withTaskCancellationHandler { [ channel] in
39
- channel. cancel ( generation)
39
+ let nextTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
40
+
41
+ let value = await withTaskCancellationHandler { [ channel] in
42
+ channel. cancelNext ( nextTokenStatus, generation)
40
43
} operation: {
41
- await channel. next ( generation)
44
+ await channel. next ( nextTokenStatus , generation)
42
45
}
43
-
44
- if let value = value {
46
+
47
+ if let value {
45
48
return value
46
49
} else {
47
50
active = false
@@ -56,24 +59,15 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
56
59
struct ChannelToken < Continuation> : Hashable {
57
60
var generation : Int
58
61
var continuation : Continuation ?
59
- let cancelled : Bool
60
62
61
63
init ( generation: Int , continuation: Continuation ) {
62
64
self . generation = generation
63
65
self . continuation = continuation
64
- cancelled = false
65
66
}
66
67
67
68
init ( placeholder generation: Int ) {
68
69
self . generation = generation
69
70
self . continuation = nil
70
- cancelled = false
71
- }
72
-
73
- init ( cancelled generation: Int ) {
74
- self . generation = generation
75
- self . continuation = nil
76
- cancelled = true
77
71
}
78
72
79
73
func hash( into hasher: inout Hasher ) {
@@ -84,37 +78,24 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
84
78
return lhs. generation == rhs. generation
85
79
}
86
80
}
81
+
82
+ enum ChannelTokenStatus : Equatable {
83
+ case new
84
+ case cancelled
85
+ }
87
86
88
87
enum Emission {
89
88
case idle
90
- case pending( [ Pending ] )
89
+ case pending( Set < Pending > )
91
90
case awaiting( Set < Awaiting > )
92
-
93
- mutating func cancel( _ generation: Int ) -> UnsafeContinuation < Element ? , Never > ? {
94
- switch self {
95
- case . awaiting( var awaiting) :
96
- let continuation = awaiting. remove ( Awaiting ( placeholder: generation) ) ? . continuation
97
- if awaiting. isEmpty {
98
- self = . idle
99
- } else {
100
- self = . awaiting( awaiting)
101
- }
102
- return continuation
103
- case . idle:
104
- self = . awaiting( [ Awaiting ( cancelled: generation) ] )
105
- return nil
106
- default :
107
- return nil
108
- }
109
- }
110
91
}
111
92
112
93
struct State {
113
94
var emission : Emission = . idle
114
95
var generation = 0
115
96
var terminal = false
116
97
}
117
-
98
+
118
99
let state = ManagedCriticalState ( State ( ) )
119
100
120
101
/// Create a new `AsyncChannel` given an element type.
@@ -126,18 +107,44 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
126
107
return state. generation
127
108
}
128
109
}
129
-
130
- func cancel( _ generation: Int ) {
131
- state. withCriticalRegion { state in
132
- state. emission. cancel ( generation)
110
+
111
+ func cancelNext( _ nextTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
112
+ state. withCriticalRegion { state -> UnsafeContinuation < Element ? , Never > ? in
113
+ let continuation : UnsafeContinuation < Element ? , Never > ?
114
+
115
+ switch state. emission {
116
+ case . awaiting( var nexts) :
117
+ continuation = nexts. remove ( Awaiting ( placeholder: generation) ) ? . continuation
118
+ if nexts. isEmpty {
119
+ state. emission = . idle
120
+ } else {
121
+ state. emission = . awaiting( nexts)
122
+ }
123
+ default :
124
+ continuation = nil
125
+ }
126
+
127
+ nextTokenStatus. withCriticalRegion { status in
128
+ if status == . new {
129
+ status = . cancelled
130
+ }
131
+ }
132
+
133
+ return continuation
133
134
} ? . resume ( returning: nil )
134
135
}
135
-
136
- func next( _ generation: Int ) async -> Element ? {
136
+
137
+ func next( _ nextTokenStatus : ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) async -> Element ? {
137
138
return await withUnsafeContinuation { ( continuation: UnsafeContinuation < Element ? , Never > ) in
138
139
var cancelled = false
139
140
var terminal = false
140
141
state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Never > ? , Never > ? in
142
+
143
+ if nextTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled {
144
+ cancelled = true
145
+ return nil
146
+ }
147
+
141
148
if state. terminal {
142
149
terminal = true
143
150
return nil
@@ -155,26 +162,93 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
155
162
}
156
163
return UnsafeResumption ( continuation: send. continuation, success: continuation)
157
164
case . awaiting( var nexts) :
158
- if nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) ) != nil {
159
- nexts. remove ( Awaiting ( placeholder: generation) )
160
- cancelled = true
161
- }
162
- if nexts. isEmpty {
163
- state. emission = . idle
164
- } else {
165
- state. emission = . awaiting( nexts)
166
- }
165
+ nexts. update ( with: Awaiting ( generation: generation, continuation: continuation) )
166
+ state. emission = . awaiting( nexts)
167
167
return nil
168
168
}
169
169
} ? . resume ( )
170
+
170
171
if cancelled || terminal {
171
172
continuation. resume ( returning: nil )
172
173
}
173
174
}
174
175
}
176
+
177
+ func cancelSend( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int ) {
178
+ state. withCriticalRegion { state -> UnsafeContinuation < UnsafeContinuation < Element ? , Never > ? , Never > ? in
179
+ let continuation : UnsafeContinuation < UnsafeContinuation < Element ? , Never > ? , Never > ?
180
+
181
+ switch state. emission {
182
+ case . pending( var sends) :
183
+ let send = sends. remove ( Pending ( placeholder: generation) )
184
+ if sends. isEmpty {
185
+ state. emission = . idle
186
+ } else {
187
+ state. emission = . pending( sends)
188
+ }
189
+ continuation = send? . continuation
190
+ default :
191
+ continuation = nil
192
+ }
193
+
194
+ sendTokenStatus. withCriticalRegion { status in
195
+ if status == . new {
196
+ status = . cancelled
197
+ }
198
+ }
199
+
200
+ return continuation
201
+ } ? . resume ( returning: nil )
202
+ }
203
+
204
+ func send( _ sendTokenStatus: ManagedCriticalState < ChannelTokenStatus > , _ generation: Int , _ element: Element ) async {
205
+ let continuation = await withUnsafeContinuation { continuation in
206
+ state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Never > ? , Never > ? in
207
+
208
+ if sendTokenStatus. withCriticalRegion ( { $0 } ) == . cancelled || state. terminal {
209
+ return UnsafeResumption ( continuation: continuation, success: nil )
210
+ }
211
+
212
+ switch state. emission {
213
+ case . idle:
214
+ state. emission = . pending( [ Pending ( generation: generation, continuation: continuation) ] )
215
+ return nil
216
+ case . pending( var sends) :
217
+ sends. update ( with: Pending ( generation: generation, continuation: continuation) )
218
+ state. emission = . pending( sends)
219
+ return nil
220
+ case . awaiting( var nexts) :
221
+ let next = nexts. removeFirst ( ) . continuation
222
+ if nexts. count == 0 {
223
+ state. emission = . idle
224
+ } else {
225
+ state. emission = . awaiting( nexts)
226
+ }
227
+ return UnsafeResumption ( continuation: continuation, success: next)
228
+ }
229
+ } ? . resume ( )
230
+ }
231
+ continuation? . resume ( returning: element)
232
+ }
233
+
234
+ /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
235
+ /// or when a call to `finish()` is made from another Task.
236
+ /// If the channel is already finished then this returns immediately
237
+ public func send( _ element: Element ) async {
238
+ let generation = establish ( )
239
+ let sendTokenStatus = ManagedCriticalState < ChannelTokenStatus > ( . new)
240
+
241
+ await withTaskCancellationHandler { [ weak self] in
242
+ self ? . cancelSend ( sendTokenStatus, generation)
243
+ } operation: {
244
+ await send ( sendTokenStatus, generation, element)
245
+ }
246
+ }
175
247
176
- func terminateAll( ) {
177
- let ( sends, nexts) = state. withCriticalRegion { state -> ( [ Pending ] , Set < Awaiting > ) in
248
+ /// Send a finish to all awaiting iterations.
249
+ /// All subsequent calls to `next(_:)` will resume immediately.
250
+ public func finish( ) {
251
+ let ( sends, nexts) = state. withCriticalRegion { state -> ( Set < Pending > , Set < Awaiting > ) in
178
252
if state. terminal {
179
253
return ( [ ] , [ ] )
180
254
}
@@ -198,53 +272,6 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
198
272
}
199
273
}
200
274
201
- func _send( _ element: Element ) async {
202
- let generation = establish ( )
203
-
204
- await withTaskCancellationHandler {
205
- terminateAll ( )
206
- } operation: {
207
- let continuation : UnsafeContinuation < Element ? , Never > ? = await withUnsafeContinuation { continuation in
208
- state. withCriticalRegion { state -> UnsafeResumption < UnsafeContinuation < Element ? , Never > ? , Never > ? in
209
- if state. terminal {
210
- return UnsafeResumption ( continuation: continuation, success: nil )
211
- }
212
- switch state. emission {
213
- case . idle:
214
- state. emission = . pending( [ Pending ( generation: generation, continuation: continuation) ] )
215
- return nil
216
- case . pending( var sends) :
217
- sends. append ( Pending ( generation: generation, continuation: continuation) )
218
- state. emission = . pending( sends)
219
- return nil
220
- case . awaiting( var nexts) :
221
- let next = nexts. removeFirst ( ) . continuation
222
- if nexts. count == 0 {
223
- state. emission = . idle
224
- } else {
225
- state. emission = . awaiting( nexts)
226
- }
227
- return UnsafeResumption ( continuation: continuation, success: next)
228
- }
229
- } ? . resume ( )
230
- }
231
- continuation? . resume ( returning: element)
232
- }
233
- }
234
-
235
- /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
236
- /// or when a call to `finish()` is made from another Task.
237
- /// If the channel is already finished then this returns immediately
238
- public func send( _ element: Element ) async {
239
- await _send ( element)
240
- }
241
-
242
- /// Send a finish to all awaiting iterations.
243
- /// All subsequent calls to `next(_:)` will resume immediately.
244
- public func finish( ) {
245
- terminateAll ( )
246
- }
247
-
248
275
/// Create an `Iterator` for iteration of an `AsyncChannel`
249
276
public func makeAsyncIterator( ) -> Iterator {
250
277
return Iterator ( self )
0 commit comments