Skip to content

Correctly handle Connection: close with streaming #598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,22 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
case .close:
context.close(promise: nil)
oldRequest.succeedRequest(buffer)
case .sendRequestEnd(let writePromise):
case .sendRequestEnd(let writePromise, let shouldClose):
let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self)
// We need to defer succeeding the old request to avoid ordering issues
writePromise.futureResult.whenComplete { result in
writePromise.futureResult.hop(to: context.eventLoop).whenComplete { result in
switch result {
case .success:
// If our final action was `sendRequestEnd`, that means we've already received
// the complete response. As a result, once we've uploaded all the body parts
// we need to tell the pool that the connection is idle.
self.connection.taskCompleted()
// we need to tell the pool that the connection is idle or, if we were asked to
// close when we're done, send the close. Either way, we then succeed the request
if shouldClose {
context.close(promise: nil)
} else {
self.connection.taskCompleted()
}

oldRequest.succeedRequest(buffer)
case .failure(let error):
oldRequest.fail(error)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ struct HTTP1ConnectionStateMachine {
/// as soon as we wrote the request end onto the wire.
///
/// The promise is an optional write promise.
case sendRequestEnd(EventLoopPromise<Void>?)
///
/// `shouldClose` records whether we have attached a Connection: close header to this request, and so the connection should
/// be terminated
case sendRequestEnd(EventLoopPromise<Void>?, shouldClose: Bool)
/// Inform an observer that the connection has become idle
case informConnectionIsIdle
}
Expand Down Expand Up @@ -413,7 +416,7 @@ extension HTTP1ConnectionStateMachine.State {
newFinalAction = .close
case .sendRequestEnd(let writePromise):
self = .idle
newFinalAction = .sendRequestEnd(writePromise)
newFinalAction = .sendRequestEnd(writePromise, shouldClose: close)
case .none:
self = .idle
newFinalAction = close ? .close : .informConnectionIsIdle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equata
switch (lhs, rhs) {
case (.close, .close):
return true
case (sendRequestEnd(let lhsPromise), sendRequestEnd(let rhsPromise)):
return lhsPromise?.futureResult == rhsPromise?.futureResult
case (sendRequestEnd(let lhsPromise, let lhsShouldClose), sendRequestEnd(let rhsPromise, let rhsShouldClose)):
return lhsPromise?.futureResult == rhsPromise?.futureResult && lhsShouldClose == rhsShouldClose
case (informConnectionIsIdle, informConnectionIsIdle):
return true
default:
Expand Down
26 changes: 26 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,32 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler {
}
}

final class ExpectClosureServerHandler: ChannelInboundHandler {
typealias InboundIn = HTTPServerRequestPart
typealias OutboundOut = HTTPServerResponsePart

private let onClosePromise: EventLoopPromise<Void>

init(onClosePromise: EventLoopPromise<Void>) {
self.onClosePromise = onClosePromise
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
switch self.unwrapInboundIn(data) {
case .head:
let head = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "0"])
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
case .body, .end:
()
}
}

func channelInactive(context: ChannelHandlerContext) {
self.onClosePromise.succeed(())
}
}

struct EventLoopFutureTimeoutError: Error {}

extension EventLoopFuture {
Expand Down
1 change: 1 addition & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ extension HTTPClientTests {
("testBiDirectionalStreaming", testBiDirectionalStreaming),
("testBiDirectionalStreamingEarly200", testBiDirectionalStreamingEarly200),
("testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests", testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests),
("testCloseConnectionAfterEarly2XXWhenStreaming", testCloseConnectionAfterEarly2XXWhenStreaming),
("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting),
("testFileDownloadChunked", testFileDownloadChunked),
("testCloseWhileBackpressureIsExertedIsFine", testCloseWhileBackpressureIsExertedIsFine),
Expand Down
54 changes: 54 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,60 @@ class HTTPClientTests: XCTestCase {
XCTAssertNoThrow(try future2.wait())
}

// This test validates that we correctly close the connection after our body completes when we've streamed a
// body and received the 2XX response _before_ we finished our stream.
func testCloseConnectionAfterEarly2XXWhenStreaming() {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }

let onClosePromise = eventLoopGroup.next().makePromise(of: Void.self)
let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in ExpectClosureServerHandler(onClosePromise: onClosePromise) }
defer { XCTAssertNoThrow(try httpBin.shutdown()) }

let writeEL = eventLoopGroup.next()

let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
defer { XCTAssertNoThrow(try httpClient.syncShutdown()) }

let body: HTTPClient.Body = .stream { writer in
let finalPromise = writeEL.makePromise(of: Void.self)

func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) {
// always invoke from the wrong el to test thread safety
writeEL.preconditionInEventLoop()

if index >= 30 {
return finalPromise.succeed(())
}

let sent = ByteBuffer(integer: index)
writer.write(.byteBuffer(sent)).whenComplete { result in
switch result {
case .success:
writeEL.execute {
writeLoop(writer, index: index + 1)
}

case .failure(let error):
finalPromise.fail(error)
}
}
}

writeEL.execute {
writeLoop(writer, index: 0)
}

return finalPromise.futureResult
}

let headers = HTTPHeaders([("Connection", "close")])
let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)", headers: headers, body: body)
let future = httpClient.execute(request: request)
XCTAssertNoThrow(try future.wait())
XCTAssertNoThrow(try onClosePromise.futureResult.wait())
}

func testSynchronousHandshakeErrorReporting() throws {
// This only affects cases where we use NIOSSL.
guard !isTestingNIOTS() else { return }
Expand Down