Skip to content

Commit 028546f

Browse files
fixup: Move tasks out of enum assoc data to avoid race
Signed-off-by: Si Beaumont <[email protected]>
1 parent 42e0703 commit 028546f

File tree

2 files changed

+52
-41
lines changed

2 files changed

+52
-41
lines changed

Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift

+17-19
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,15 @@ internal final class AsyncServerHandler<
195195

196196
/// The task used to run the async user function.
197197
///
198-
/// - TODO: Should this be part of the associated metadata in the `State` enum? Doing so would
199-
/// make testing a bit cumbersome since a bunch of tests await this task finishing. Shoving it in
200-
/// the enum.
198+
/// - TODO: I'd like it if this was part of the assoc data for the .active state but doing so may introduce a race condition.
201199
@usableFromInline
202-
internal var task: Task<Void, Never>? = nil
200+
internal var userHandlerTask: Task<Void, Never>? = nil
201+
202+
/// The task used to drain the response stream writer after the user function has completed.
203+
///
204+
/// - TODO: I'd like it if this was part of the assoc data for the .finishingSuccessfully state but doing so may introduce a race condition.
205+
@usableFromInline
206+
internal var responseStreamDrainTask: Task<Void, Never>? = nil
203207

204208
@usableFromInline
205209
internal enum State {
@@ -250,9 +254,6 @@ internal final class AsyncServerHandler<
250254
/// handler has completed that have yet to be written. While in this state we will await the
251255
/// writer flushing its responses before sending `.end` to the stream.
252256
///
253-
/// - The `Task` is used to drain the response stream writer. It is stashed here so that we can
254-
/// cancel it in the event of an error.
255-
///
256257
/// - The `EventLoopPromise` bridges the NIO and async-await worlds. It is the mechanism that we
257258
/// use to run a callback when the response stream has been flushed. It is fulfilled with the
258259
/// result of the async `Task` executing the user handler using `completeWithTask(_:)`.
@@ -264,7 +265,6 @@ internal final class AsyncServerHandler<
264265
case finishingSuccessfully(
265266
GRPCAsyncServerCallContext,
266267
GRPCAsyncResponseStreamWriter<Response>,
267-
Task<Void, Never>,
268268
EventLoopPromise<GRPCStatus>
269269
)
270270

@@ -341,10 +341,10 @@ internal final class AsyncServerHandler<
341341
self.state = .completed
342342

343343
case .active:
344-
self.task?.cancel()
344+
self.userHandlerTask?.cancel()
345345

346-
case let .finishingSuccessfully(_, _, task, _):
347-
task.cancel()
346+
case .finishingSuccessfully:
347+
self.responseStreamDrainTask?.cancel()
348348

349349
case .completed:
350350
self.interceptors = nil
@@ -404,7 +404,7 @@ internal final class AsyncServerHandler<
404404
self.interceptors.send(.metadata([:]), promise: nil)
405405

406406
// Spin up a task to call the async user handler.
407-
self.task = userHandlerPromise.completeWithTask {
407+
self.userHandlerTask = userHandlerPromise.completeWithTask {
408408
try await withTaskCancellationHandler {
409409
do {
410410
// Call the user function.
@@ -545,7 +545,7 @@ internal final class AsyncServerHandler<
545545
// Register callback for the response stream being drained.
546546
responseStreamDrainedPromise.futureResult.whenComplete(self.responseStreamDrained(_:))
547547

548-
let responseStreamDrainTask = responseStreamDrainedPromise.completeWithTask {
548+
self.responseStreamDrainTask = responseStreamDrainedPromise.completeWithTask {
549549
try await withTaskCancellationHandler {
550550
// Await the writer finish.
551551
try await responseStreamWriter._asyncWriter.finish(())
@@ -561,7 +561,6 @@ internal final class AsyncServerHandler<
561561
self.state = .finishingSuccessfully(
562562
context,
563563
responseStreamWriter,
564-
responseStreamDrainTask,
565564
responseStreamDrainedPromise
566565
)
567566

@@ -587,7 +586,7 @@ internal final class AsyncServerHandler<
587586
case .active:
588587
preconditionFailure()
589588

590-
case let .finishingSuccessfully(context, _, _, _):
589+
case let .finishingSuccessfully(context, _, _):
591590
switch result {
592591
case let .success(status):
593592
// Now we have drained the response stream writer from the user handler we can send end.
@@ -643,11 +642,10 @@ internal final class AsyncServerHandler<
643642
// which it is reading and give the user handler an opportunity to cleanup.
644643
//
645644
// NOTE: This line used to be before we explicitly fail the status promise but it was exaserbating a race condition and causing crashes. See https://bugs.swift.org/browse/SR-15108.
646-
self.task?.cancel()
645+
self.userHandlerTask?.cancel()
647646

648-
case let .finishingSuccessfully(_, _, responseStreamWriterDrainTask, _):
649-
self.task?.cancel()
650-
responseStreamWriterDrainTask.cancel()
647+
case .finishingSuccessfully(_, _, _):
648+
self.responseStreamDrainTask?.cancel()
651649

652650
case .completed:
653651
()

Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift

+35-22
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
8181
handler.receiveMessage(ByteBuffer(string: "3"))
8282
handler.receiveEnd()
8383

84-
// Wait for user handler to finish.
85-
await handler.task?.value
84+
// Wait for tasks to finish.
85+
await handler.userHandlerTask?.value
86+
await handler.responseStreamDrainTask?.value
8687

8788
handler.finish()
8889

@@ -107,8 +108,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
107108
handler.receiveMessage(ByteBuffer(string: "3"))
108109
handler.receiveEnd()
109110

110-
// Wait for user handler to finish.
111-
await handler.task?.value
111+
// Wait for tasks to finish.
112+
await handler.userHandlerTask?.value
113+
await handler.responseStreamDrainTask?.value
112114

113115
await assertThat(
114116
self.recorder.messages,
@@ -135,8 +137,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
135137
handler.receiveMessage(ByteBuffer(string: "3"))
136138
handler.receiveEnd()
137139

138-
// Wait for user handler to finish.
139-
await handler.task?.value
140+
// Wait for tasks to finish.
141+
await handler.userHandlerTask?.value
142+
await handler.responseStreamDrainTask?.value
140143

141144
await assertThat(
142145
self.recorder.messages,
@@ -148,11 +151,11 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
148151
func testTaskOnlyCreatedAfterHeaders() { XCTAsyncTest {
149152
let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:))
150153

151-
await assertThat(handler.task, .is(.nil()))
154+
await assertThat(handler.userHandlerTask, .is(.nil()))
152155

153156
handler.receiveMetadata([:])
154157

155-
await assertThat(handler.task, .is(.notNil()))
158+
await assertThat(handler.userHandlerTask, .is(.notNil()))
156159
} }
157160

158161
func testThrowingDeserializer() { XCTAsyncTest {
@@ -174,8 +177,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
174177
let buffer = ByteBuffer(string: "hello")
175178
handler.receiveMessage(buffer)
176179

177-
// Wait for user handler to finish.
178-
await handler.task?.value
180+
// Wait for tasks to finish.
181+
await handler.userHandlerTask?.value
182+
await handler.responseStreamDrainTask?.value
179183

180184
await assertThat(self.recorder.messages, .isEmpty())
181185
await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
@@ -197,8 +201,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
197201
handler.receiveMessage(buffer)
198202
handler.receiveEnd()
199203

200-
// Wait for user handler to finish.
201-
await handler.task?.value
204+
// Wait for tasks to finish.
205+
await handler.userHandlerTask?.value
206+
await handler.responseStreamDrainTask?.value
202207

203208
await assertThat(self.recorder.messages, .isEmpty())
204209
await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
@@ -210,7 +215,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
210215

211216
handler.receiveMessage(ByteBuffer(string: "foo"))
212217

213-
await handler.task?.value
218+
// Wait for tasks to finish.
219+
await handler.userHandlerTask?.value
220+
await handler.responseStreamDrainTask?.value
214221

215222
await assertThat(self.recorder.metadata, .is(.nil()))
216223
await assertThat(self.recorder.messages, .isEmpty())
@@ -231,8 +238,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
231238

232239
handler.receiveMetadata([:])
233240

234-
// Wait for user handler to finish.
235-
await handler.task?.value
241+
// Wait for tasks to finish.
242+
await handler.userHandlerTask?.value
243+
await handler.responseStreamDrainTask?.value
236244

237245
await assertThat(self.recorder.messages, .isEmpty())
238246
await assertThat(self.recorder.status, .notNil(.hasCode(.internalError)))
@@ -260,8 +268,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
260268

261269
handler.finish()
262270

263-
// Wait for user handler to finish.
264-
await handler.task?.value
271+
// Wait for tasks to finish.
272+
await handler.userHandlerTask?.value
273+
await handler.responseStreamDrainTask?.value
265274

266275
await assertThat(self.recorder.messages, .isEmpty())
267276
await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
@@ -279,8 +288,9 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
279288

280289
handler.finish()
281290

282-
// Wait for user handler to finish.
283-
await handler.task?.value
291+
// Wait for tasks to finish.
292+
await handler.userHandlerTask?.value
293+
await handler.responseStreamDrainTask?.value
284294

285295
await assertThat(self.recorder.messages.first, .is(ByteBuffer(string: "hello")))
286296
await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable)))
@@ -297,7 +307,7 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
297307
handler.receiveMetadata([:])
298308

299309
// Wait for user handler to finish (it's gonna throw immediately).
300-
await assertThat(await handler.task?.value, .notNil())
310+
await assertThat(await handler.userHandlerTask?.value, .notNil())
301311

302312
// Check the status is `.unknown`.
303313
await assertThat(self.recorder.status, .notNil(.hasCode(.unknown)))
@@ -316,19 +326,22 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase {
316326

317327
// Send two requests and end, pausing the writer in the middle.
318328
switch handler.state {
319-
case let .active(_, _, responseStreamWriter, _):
329+
case let .active(_, _, responseStreamWriter, promise):
320330
handler.receiveMessage(ByteBuffer(string: "diaz"))
321331
await responseStreamWriter._asyncWriter.toggleWritability()
322332
handler.receiveMessage(ByteBuffer(string: "santiago"))
323333
handler.receiveEnd()
324334
await responseStreamWriter._asyncWriter.toggleWritability()
325-
await handler.task?.value
335+
await handler.userHandlerTask?.value
336+
_ = try await promise.futureResult.get()
326337
default:
327338
XCTFail("Unexpected handler state: \(handler.state)")
328339
}
329340

330341
handler.finish()
331342

343+
await assertThat(handler.responseStreamDrainTask, .notNil())
344+
332345
await assertThat(self.recorder.messages, .is([
333346
ByteBuffer(string: "diaz"),
334347
ByteBuffer(string: "santiago"),

0 commit comments

Comments
 (0)