diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift index 1092bbd9c..3450cfcc8 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncBidirectionalStreamingCall.swift @@ -40,6 +40,10 @@ public struct GRPCAsyncBidirectionalStreamingCall { // MARK: - Response Parts /// The initial metadata returned from the server. + /// + /// - Important: The initial metadata will only be available when the first response has been + /// received. However, it is not necessary for the response to have been consumed before reading + /// this property. public var initialMetadata: HPACKHeaders { // swiftformat:disable:next redundantGet get async throws { diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift index c755cf557..b6b33b74e 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncClientStreamingCall.swift @@ -36,6 +36,8 @@ public struct GRPCAsyncClientStreamingCall { // MARK: - Response Parts /// The initial metadata returned from the server. + /// + /// - Important: The initial metadata will only be available when the response has been received. public var initialMetadata: HPACKHeaders { // swiftformat:disable:next redundantGet get async throws { diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift index 2fe2aa2cc..048b3eb16 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerCallContext.swift @@ -34,8 +34,8 @@ import NIOHPACK public final class GRPCAsyncServerCallContext { private let lock = Lock() - /// Request headers for this request. - public let headers: HPACKHeaders + /// Metadata for this request. + public let requestMetadata: HPACKHeaders /// The logger used for this call. public var logger: Logger { @@ -83,18 +83,34 @@ public final class GRPCAsyncServerCallContext { @usableFromInline internal let userInfoRef: Ref - /// Metadata to return at the end of the RPC. If this is required it should be updated before - /// the `responsePromise` or `statusPromise` is fulfilled. - public var trailers: HPACKHeaders { + /// Metadata to return at the start of the RPC. + /// + /// - Important: If this is required it should be updated _before_ the first response is sent via + /// the response stream writer. Any updates made after the first response will be ignored. + public var initialResponseMetadata: HPACKHeaders { + get { self.lock.withLock { + return self._initialResponseMetadata + } } + set { self.lock.withLock { + self._initialResponseMetadata = newValue + } } + } + + private var _initialResponseMetadata: HPACKHeaders = [:] + + /// Metadata to return at the end of the RPC. + /// + /// If this is required it should be updated before returning from the handler. + public var trailingResponseMetadata: HPACKHeaders { get { self.lock.withLock { - return self._trailers + return self._trailingResponseMetadata } } set { self.lock.withLock { - self._trailers = newValue + self._trailingResponseMetadata = newValue } } } - private var _trailers: HPACKHeaders = [:] + private var _trailingResponseMetadata: HPACKHeaders = [:] @inlinable internal init( @@ -102,7 +118,7 @@ public final class GRPCAsyncServerCallContext { logger: Logger, userInfoRef: Ref ) { - self.headers = headers + self.requestMetadata = headers self.userInfoRef = userInfoRef self._logger = logger } diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift index a23dc64d1..2de59b807 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift @@ -204,38 +204,56 @@ internal final class AsyncServerHandler< /// No headers have been received. case idle - /// Headers have been received, and an async `Task` has been created to execute the user - /// handler. - /// - /// The inputs to the user handler are held in the associated data of this enum value: - /// - /// - The `PassthroughMessageSource` is the source backing the request stream that is being - /// consumed by the user handler. - /// - /// - The `GRPCAsyncServerContext` is a reference to the context that was passed to the user - /// handler. - /// - /// - The `GRPCAsyncResponseStreamWriter` is the response stream writer that is being written to - /// by the user handler. Because this is pausable, it may contain responses after the user - /// handler has completed that have yet to be written. However we will remain in the `.active` - /// state until the response stream writer has completed. - /// - /// - The `EventLoopPromise` bridges the NIO and async-await worlds. It is the mechanism that we - /// use to run a callback when the user handler has completed. The promise is not passed to the - /// user handler directly. Instead it is fulfilled with the result of the async `Task` executing - /// the user handler using `completeWithTask(_:)`. - /// - /// - TODO: It shouldn't really be necessary to stash the `GRPCAsyncResponseStreamWriter` or the - /// `EventLoopPromise` in this enum value. Specifically they are never used anywhere when this - /// enum value is accessed. However, if we do not store them here then the tests periodically - /// segfault. This appears to be an bug in Swift and/or NIO since these should both have been - /// captured by `completeWithTask(_:)`. - case active( - PassthroughMessageSource, - GRPCAsyncServerCallContext, - GRPCAsyncResponseStreamWriter, - EventLoopPromise - ) + @usableFromInline + internal struct ActiveState { + /// The source backing the request stream that is being consumed by the user handler. + @usableFromInline + let requestStreamSource: PassthroughMessageSource + + /// The call context that was passed to the user handler. + @usableFromInline + let context: GRPCAsyncServerCallContext + + /// The response stream writer that is being used by the user handler. + /// + /// Because this is pausable, it may contain responses after the user handler has completed + /// that have yet to be written. However we will remain in the `.active` state until the + /// response stream writer has completed. + @usableFromInline + let responseStreamWriter: GRPCAsyncResponseStreamWriter + + /// The response headers have been sent back to the client via the interceptors. + @usableFromInline + var haveSentResponseHeaders: Bool = false + + /// The promise we are using to bridge the NIO and async-await worlds. + /// + /// It is the mechanism that we use to run a callback when the user handler has completed. + /// The promise is not passed to the user handler directly. Instead it is fulfilled with the + /// result of the async `Task` executing the user handler using `completeWithTask(_:)`. + /// + /// - TODO: It shouldn't really be necessary to stash this promise here. Specifically it is + /// never used anywhere when the `.active` enum value is accessed. However, if we do not store + /// it here then the tests periodically segfault. This appears to be a reference counting bug + /// in Swift and/or NIO since it should have been captured by `completeWithTask(_:)`. + let _userHandlerPromise: EventLoopPromise + + @usableFromInline + internal init( + requestStreamSource: PassthroughMessageSource, + context: GRPCAsyncServerCallContext, + responseStreamWriter: GRPCAsyncResponseStreamWriter, + userHandlerPromise: EventLoopPromise + ) { + self.requestStreamSource = requestStreamSource + self.context = context + self.responseStreamWriter = responseStreamWriter + self._userHandlerPromise = userHandlerPromise + } + } + + /// Headers have been received and an async `Task` has been created to execute the user handler. + case active(ActiveState) /// The handler has completed. case completed @@ -363,15 +381,16 @@ internal final class AsyncServerHandler< ) // Set the state to active and bundle in all the associated data. - self.state = .active(requestStreamSource, context, responseStreamWriter, userHandlerPromise) + self.state = .active(.init( + requestStreamSource: requestStreamSource, + context: context, + responseStreamWriter: responseStreamWriter, + userHandlerPromise: userHandlerPromise + )) // Register callback for the completion of the user handler. userHandlerPromise.futureResult.whenComplete(self.userHandlerCompleted(_:)) - // Send response headers back via the interceptors. - // TODO: In future we may want to defer this until the first response is available from the user handler which will allow the user to set the response headers via the context. - self.interceptors.send(.metadata([:]), promise: nil) - // Spin up a task to call the async user handler. self.userHandlerTask = userHandlerPromise.completeWithTask { return try await withTaskCancellationHandler { @@ -443,8 +462,8 @@ internal final class AsyncServerHandler< switch self.state { case .idle: self.handleError(GRPCError.ProtocolViolation("Message received before headers")) - case let .active(requestStreamSource, _, _, _): - switch requestStreamSource.yield(request) { + case let .active(activeState): + switch activeState.requestStreamSource.yield(request) { case .accepted(queueDepth: _): // TODO: In future we will potentially issue a read request to the channel based on the value of `queueDepth`. break @@ -467,8 +486,8 @@ internal final class AsyncServerHandler< switch self.state { case .idle: self.handleError(GRPCError.ProtocolViolation("End of stream received before headers")) - case let .active(requestStreamSource, _, _, _): - switch requestStreamSource.finish() { + case let .active(activeState): + switch activeState.requestStreamSource.finish() { case .accepted(queueDepth: _): break case .dropped: @@ -495,7 +514,14 @@ internal final class AsyncServerHandler< // The user handler cannot send responses before it has been invoked. preconditionFailure() - case .active: + case var .active(activeState): + if !activeState.haveSentResponseHeaders { + activeState.haveSentResponseHeaders = true + self.state = .active(activeState) + // Send response headers back via the interceptors. + self.interceptors.send(.metadata(activeState.context.initialResponseMetadata), promise: nil) + } + // Send the response back via the interceptors. self.interceptors.send(.message(response, metadata), promise: nil) case .completed: @@ -547,10 +573,13 @@ internal final class AsyncServerHandler< case .idle: preconditionFailure() - case let .active(_, context, _, _): + case let .active(activeState): // Now we have drained the response stream writer from the user handler we can send end. self.state = .completed - self.interceptors.send(.end(status, context.trailers), promise: nil) + self.interceptors.send( + .end(status, activeState.context.trailingResponseMetadata), + promise: nil + ) case .completed: () @@ -580,7 +609,7 @@ internal final class AsyncServerHandler< ) self.interceptors.send(.end(status, trailers), promise: nil) - case let .active(_, context, _, _): + case let .active(activeState): self.state = .completed // If we have an async task, then cancel it, which will terminate the request stream from @@ -593,8 +622,8 @@ internal final class AsyncServerHandler< if isHandlerError { (status, trailers) = ServerErrorProcessor.processObserverError( error, - headers: context.headers, - trailers: context.trailers, + headers: activeState.context.requestMetadata, + trailers: activeState.context.trailingResponseMetadata, delegate: self.context.errorDelegate ) } else { diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift index 93e8a2a9e..671410689 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerStreamingCall.swift @@ -39,6 +39,10 @@ public struct GRPCAsyncServerStreamingCall { // MARK: - Response Parts /// The initial metadata returned from the server. + /// + /// - Important: The initial metadata will only be available when the first response has been + /// received. However, it is not necessary for the response to have been consumed before reading + /// this property. public var initialMetadata: HPACKHeaders { // swiftformat:disable:next redundantGet get async throws { diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift index 92141f59c..71efaabf1 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncUnaryCall.swift @@ -39,6 +39,8 @@ public struct GRPCAsyncUnaryCall { // MARK: - Response Parts /// The initial metadata returned from the server. + /// + /// - Important: The initial metadata will only be available when the response has been received. public var initialMetadata: HPACKHeaders { // swiftformat:disable:next redundantGet get async throws { diff --git a/Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift b/Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift index 9eab2dc6c..52ce2bc4d 100644 --- a/Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift +++ b/Tests/GRPCTests/GRPCAsyncServerHandlerTests.swift @@ -74,8 +74,6 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { ) handler.receiveMetadata([:]) - await assertThat(self.recorder.metadata, .is([:])) - handler.receiveMessage(ByteBuffer(string: "1")) handler.receiveMessage(ByteBuffer(string: "2")) handler.receiveMessage(ByteBuffer(string: "3")) @@ -86,6 +84,7 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { handler.finish() + await assertThat(self.recorder.metadata, .is([:])) await assertThat( self.recorder.messages, .is([ByteBuffer(string: "1"), ByteBuffer(string: "2"), ByteBuffer(string: "3")]) @@ -145,14 +144,46 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { await assertThat(self.recorder.messageMetadata.map { $0.compress }, .is([false, false, false])) } } + func testResponseHeadersAndTrailersSentFromContext() { XCTAsyncTest { + let handler = self.makeHandler { _, responseStreamWriter, context in + context.initialResponseMetadata = ["pontiac": "bandit"] + try await responseStreamWriter.send("1") + context.trailingResponseMetadata = ["disco": "strangler"] + } + handler.receiveMetadata([:]) + handler.receiveEnd() + + // Wait for tasks to finish. + await handler.userHandlerTask?.value + + await assertThat(self.recorder.metadata, .is(["pontiac": "bandit"])) + await assertThat(self.recorder.trailers, .is(["disco": "strangler"])) + } } + + func testResponseHeadersDroppedIfSetAfterFirstResponse() { XCTAsyncTest { + let handler = self.makeHandler { _, responseStreamWriter, context in + try await responseStreamWriter.send("1") + context.initialResponseMetadata = ["pontiac": "bandit"] + context.trailingResponseMetadata = ["disco": "strangler"] + } + handler.receiveMetadata([:]) + handler.receiveEnd() + + // Wait for tasks to finish. + await handler.userHandlerTask?.value + + await assertThat(self.recorder.metadata, .is([:])) + await assertThat(self.recorder.trailers, .is(["disco": "strangler"])) + } } + func testTaskOnlyCreatedAfterHeaders() { XCTAsyncTest { let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:)) - await assertThat(handler.userHandlerTask, .is(.nil())) + await assertThat(handler.userHandlerTask, .nil()) handler.receiveMetadata([:]) - await assertThat(handler.userHandlerTask, .is(.notNil())) + await assertThat(handler.userHandlerTask, .notNil()) } } func testThrowingDeserializer() { XCTAsyncTest { @@ -165,18 +196,12 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { ) handler.receiveMetadata([:]) - - // Wait for the async user function to have processed the metadata. - try self.recorder.recordedMetadataPromise.futureResult.wait() - - await assertThat(self.recorder.metadata, .is([:])) - - let buffer = ByteBuffer(string: "hello") - handler.receiveMessage(buffer) + handler.receiveMessage(ByteBuffer(string: "hello")) // Wait for tasks to finish. await handler.userHandlerTask?.value + await assertThat(self.recorder.metadata, .nil()) await assertThat(self.recorder.messages, .isEmpty()) await assertThat(self.recorder.status, .notNil(.hasCode(.internalError))) } } @@ -191,15 +216,13 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { ) handler.receiveMetadata([:]) - await assertThat(self.recorder.metadata, .is([:])) - - let buffer = ByteBuffer(string: "hello") - handler.receiveMessage(buffer) + handler.receiveMessage(ByteBuffer(string: "hello")) handler.receiveEnd() // Wait for tasks to finish. await handler.userHandlerTask?.value + await assertThat(self.recorder.metadata, .is([:])) await assertThat(self.recorder.messages, .isEmpty()) await assertThat(self.recorder.status, .notNil(.hasCode(.internalError))) } } @@ -213,28 +236,22 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { // Wait for tasks to finish. await handler.userHandlerTask?.value - await assertThat(self.recorder.metadata, .is(.nil())) + await assertThat(self.recorder.metadata, .nil()) await assertThat(self.recorder.messages, .isEmpty()) await assertThat(self.recorder.status, .notNil(.hasCode(.internalError))) } } - // TODO: Running this 1000 times shows up a segfault in NIO event loop group. func testReceiveMultipleHeaders() { XCTAsyncTest { let handler = self .makeHandler(observer: self.neverReceivesMessage(requests:responseStreamWriter:context:)) handler.receiveMetadata([:]) - - // Wait for the async user function to have processed the metadata. - try self.recorder.recordedMetadataPromise.futureResult.wait() - - await assertThat(self.recorder.metadata, .is([:])) - handler.receiveMetadata([:]) // Wait for tasks to finish. await handler.userHandlerTask?.value + await assertThat(self.recorder.metadata, .nil()) await assertThat(self.recorder.messages, .isEmpty()) await assertThat(self.recorder.status, .notNil(.hasCode(.internalError))) } } @@ -244,26 +261,22 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { .makeHandler(observer: self.neverCalled(requests:responseStreamWriter:context:)) handler.finish() - await assertThat(self.recorder.metadata, .is(.nil())) + await assertThat(self.recorder.metadata, .nil()) await assertThat(self.recorder.messages, .isEmpty()) - await assertThat(self.recorder.status, .is(.nil())) - await assertThat(self.recorder.trailers, .is(.nil())) + await assertThat(self.recorder.status, .nil()) + await assertThat(self.recorder.trailers, .nil()) } } func testFinishAfterHeaders() { XCTAsyncTest { let handler = self.makeHandler(observer: self.echo(requests:responseStreamWriter:context:)) - handler.receiveMetadata([:]) - - // Wait for the async user function to have processed the metadata. - try self.recorder.recordedMetadataPromise.futureResult.wait() - - await assertThat(self.recorder.metadata, .is([:])) + handler.receiveMetadata([:]) handler.finish() // Wait for tasks to finish. await handler.userHandlerTask?.value + await assertThat(self.recorder.metadata, .nil()) await assertThat(self.recorder.messages, .isEmpty()) await assertThat(self.recorder.status, .notNil(.hasCode(.unavailable))) await assertThat(self.recorder.trailers, .is([:])) @@ -304,8 +317,6 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { await assertThat(self.recorder.status, .notNil(.hasCode(.unknown))) } } - // TODO: We should be consistent about where we put the tasks... maybe even use a task group to simplify cancellation (unless they both go in the enum state which might be better). - func testResponseStreamDrain() { XCTAsyncTest { // Set up echo handler. let handler = self.makeHandler( @@ -317,14 +328,14 @@ class AsyncServerHandlerTests: ServerHandlerTestCaseBase { // Send two requests and end, pausing the writer in the middle. switch handler.state { - case let .active(_, _, responseStreamWriter, promise): + case let .active(activeState): handler.receiveMessage(ByteBuffer(string: "diaz")) - await responseStreamWriter.asyncWriter.toggleWritability() + await activeState.responseStreamWriter.asyncWriter.toggleWritability() handler.receiveMessage(ByteBuffer(string: "santiago")) handler.receiveEnd() - await responseStreamWriter.asyncWriter.toggleWritability() + await activeState.responseStreamWriter.asyncWriter.toggleWritability() await handler.userHandlerTask?.value - _ = try await promise.futureResult.get() + _ = try await activeState._userHandlerPromise.futureResult.get() default: XCTFail("Unexpected handler state: \(handler.state)") }