@@ -204,38 +204,56 @@ internal final class AsyncServerHandler<
204204 /// No headers have been received.
205205 case idle
206206
207- /// Headers have been received, and an async `Task` has been created to execute the user
208- /// handler.
209- ///
210- /// The inputs to the user handler are held in the associated data of this enum value:
211- ///
212- /// - The `PassthroughMessageSource` is the source backing the request stream that is being
213- /// consumed by the user handler.
214- ///
215- /// - The `GRPCAsyncServerContext` is a reference to the context that was passed to the user
216- /// handler.
217- ///
218- /// - The `GRPCAsyncResponseStreamWriter` is the response stream writer that is being written to
219- /// by the user handler. Because this is pausable, it may contain responses after the user
220- /// handler has completed that have yet to be written. However we will remain in the `.active`
221- /// state until the response stream writer has completed.
222- ///
223- /// - The `EventLoopPromise` bridges the NIO and async-await worlds. It is the mechanism that we
224- /// use to run a callback when the user handler has completed. The promise is not passed to the
225- /// user handler directly. Instead it is fulfilled with the result of the async `Task` executing
226- /// the user handler using `completeWithTask(_:)`.
227- ///
228- /// - TODO: It shouldn't really be necessary to stash the `GRPCAsyncResponseStreamWriter` or the
229- /// `EventLoopPromise` in this enum value. Specifically they are never used anywhere when this
230- /// enum value is accessed. However, if we do not store them here then the tests periodically
231- /// segfault. This appears to be an bug in Swift and/or NIO since these should both have been
232- /// captured by `completeWithTask(_:)`.
233- case active(
234- PassthroughMessageSource < Request , Error > ,
235- GRPCAsyncServerCallContext ,
236- GRPCAsyncResponseStreamWriter < Response > ,
237- EventLoopPromise < Void >
238- )
207+ @usableFromInline
208+ internal struct ActiveState {
209+ /// The source backing the request stream that is being consumed by the user handler.
210+ @usableFromInline
211+ let requestStreamSource : PassthroughMessageSource < Request , Error >
212+
213+ /// The call context that was passed to the user handler.
214+ @usableFromInline
215+ let context : GRPCAsyncServerCallContext
216+
217+ /// The response stream writer that is being used by the user handler.
218+ ///
219+ /// Because this is pausable, it may contain responses after the user handler has completed
220+ /// that have yet to be written. However we will remain in the `.active` state until the
221+ /// response stream writer has completed.
222+ @usableFromInline
223+ let responseStreamWriter : GRPCAsyncResponseStreamWriter < Response >
224+
225+ /// The response headers have been sent back to the client via the interceptors.
226+ @usableFromInline
227+ var haveSentResponseHeaders : Bool = false
228+
229+ /// The promise we are using to bridge the NIO and async-await worlds.
230+ ///
231+ /// It is the mechanism that we use to run a callback when the user handler has completed.
232+ /// The promise is not passed to the user handler directly. Instead it is fulfilled with the
233+ /// result of the async `Task` executing the user handler using `completeWithTask(_:)`.
234+ ///
235+ /// - TODO: It shouldn't really be necessary to stash this promise here. Specifically it is
236+ /// never used anywhere when the `.active` enum value is accessed. However, if we do not store
237+ /// it here then the tests periodically segfault. This appears to be a reference counting bug
238+ /// in Swift and/or NIO since it should have been captured by `completeWithTask(_:)`.
239+ let _userHandlerPromise : EventLoopPromise < Void >
240+
241+ @usableFromInline
242+ internal init (
243+ requestStreamSource: PassthroughMessageSource < Request , Error > ,
244+ context: GRPCAsyncServerCallContext ,
245+ responseStreamWriter: GRPCAsyncResponseStreamWriter < Response > ,
246+ userHandlerPromise: EventLoopPromise < Void >
247+ ) {
248+ self . requestStreamSource = requestStreamSource
249+ self . context = context
250+ self . responseStreamWriter = responseStreamWriter
251+ self . _userHandlerPromise = userHandlerPromise
252+ }
253+ }
254+
255+ /// Headers have been received and an async `Task` has been created to execute the user handler.
256+ case active( ActiveState )
239257
240258 /// The handler has completed.
241259 case completed
@@ -363,15 +381,16 @@ internal final class AsyncServerHandler<
363381 )
364382
365383 // Set the state to active and bundle in all the associated data.
366- self . state = . active( requestStreamSource, context, responseStreamWriter, userHandlerPromise)
384+ self . state = . active( . init(
385+ requestStreamSource: requestStreamSource,
386+ context: context,
387+ responseStreamWriter: responseStreamWriter,
388+ userHandlerPromise: userHandlerPromise
389+ ) )
367390
368391 // Register callback for the completion of the user handler.
369392 userHandlerPromise. futureResult. whenComplete ( self . userHandlerCompleted ( _: ) )
370393
371- // Send response headers back via the interceptors.
372- // TODO: In future we may want to defer this until the first response is available from the user handler which will allow the user to set the response headers via the context.
373- self . interceptors. send ( . metadata( [ : ] ) , promise: nil )
374-
375394 // Spin up a task to call the async user handler.
376395 self . userHandlerTask = userHandlerPromise. completeWithTask {
377396 return try await withTaskCancellationHandler {
@@ -443,8 +462,8 @@ internal final class AsyncServerHandler<
443462 switch self . state {
444463 case . idle:
445464 self . handleError ( GRPCError . ProtocolViolation ( " Message received before headers " ) )
446- case let . active( requestStreamSource , _ , _ , _ ) :
447- switch requestStreamSource. yield ( request) {
465+ case let . active( activeState ) :
466+ switch activeState . requestStreamSource. yield ( request) {
448467 case . accepted( queueDepth: _) :
449468 // TODO: In future we will potentially issue a read request to the channel based on the value of `queueDepth`.
450469 break
@@ -467,8 +486,8 @@ internal final class AsyncServerHandler<
467486 switch self . state {
468487 case . idle:
469488 self . handleError ( GRPCError . ProtocolViolation ( " End of stream received before headers " ) )
470- case let . active( requestStreamSource , _ , _ , _ ) :
471- switch requestStreamSource. finish ( ) {
489+ case let . active( activeState ) :
490+ switch activeState . requestStreamSource. finish ( ) {
472491 case . accepted( queueDepth: _) :
473492 break
474493 case . dropped:
@@ -495,7 +514,14 @@ internal final class AsyncServerHandler<
495514 // The user handler cannot send responses before it has been invoked.
496515 preconditionFailure ( )
497516
498- case . active:
517+ case var . active( activeState) :
518+ if !activeState. haveSentResponseHeaders {
519+ activeState. haveSentResponseHeaders = true
520+ self . state = . active( activeState)
521+ // Send response headers back via the interceptors.
522+ self . interceptors. send ( . metadata( activeState. context. initialResponseMetadata) , promise: nil )
523+ }
524+ // Send the response back via the interceptors.
499525 self . interceptors. send ( . message( response, metadata) , promise: nil )
500526
501527 case . completed:
@@ -547,10 +573,13 @@ internal final class AsyncServerHandler<
547573 case . idle:
548574 preconditionFailure ( )
549575
550- case let . active( _ , context , _ , _ ) :
576+ case let . active( activeState ) :
551577 // Now we have drained the response stream writer from the user handler we can send end.
552578 self . state = . completed
553- self . interceptors. send ( . end( status, context. trailers) , promise: nil )
579+ self . interceptors. send (
580+ . end( status, activeState. context. trailingResponseMetadata) ,
581+ promise: nil
582+ )
554583
555584 case . completed:
556585 ( )
@@ -580,7 +609,7 @@ internal final class AsyncServerHandler<
580609 )
581610 self . interceptors. send ( . end( status, trailers) , promise: nil )
582611
583- case let . active( _ , context , _ , _ ) :
612+ case let . active( activeState ) :
584613 self . state = . completed
585614
586615 // If we have an async task, then cancel it, which will terminate the request stream from
@@ -593,8 +622,8 @@ internal final class AsyncServerHandler<
593622 if isHandlerError {
594623 ( status, trailers) = ServerErrorProcessor . processObserverError (
595624 error,
596- headers: context. headers ,
597- trailers: context. trailers ,
625+ headers: activeState . context. requestMetadata ,
626+ trailers: activeState . context. trailingResponseMetadata ,
598627 delegate: self . context. errorDelegate
599628 )
600629 } else {
0 commit comments