diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index 59beed926..557af2af1 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -16,6 +16,16 @@ import struct Foundation.URL import NIOCore import NIOHTTP1 +extension HTTPClient { + /// The maximum body size allowed, before a redirect response is cancelled. 3KB. + /// + /// Why 3KB? We feel like this is a good compromise between potentially reusing the + /// connection in HTTP/1.1 mode (if we load all data from the redirect response we can + /// reuse the connection) and not being to wasteful in the amount of data that is thrown + /// away being transferred. + fileprivate static let maxBodySizeRedirectResponse = 1024 * 3 +} + extension RequestBag { struct StateMachine { fileprivate enum State { @@ -23,7 +33,7 @@ extension RequestBag { case queued(HTTPRequestScheduler) case executing(HTTPRequestExecutor, RequestStreamState, ResponseStreamState) case finished(error: Error?) - case redirected(HTTPResponseHead, URL) + case redirected(HTTPRequestExecutor, Int, HTTPResponseHead, URL) case modifying } @@ -259,11 +269,18 @@ extension RequestBag.StateMachine { } } + enum ReceiveResponseHeadAction { + case none + case forwardResponseHead(HTTPResponseHead) + case signalBodyDemand(HTTPRequestExecutor) + case redirect(HTTPRequestExecutor, RedirectHandler, HTTPResponseHead, URL) + } + /// The response head has been received. /// /// - Parameter head: The response' head /// - Returns: Whether the response should be forwarded to the delegate. Will be `false` if the request follows a redirect. - mutating func receiveResponseHead(_ head: HTTPResponseHead) -> Bool { + mutating func receiveResponseHead(_ head: HTTPResponseHead) -> ReceiveResponseHeadAction { switch self.state { case .initialized, .queued: preconditionFailure("How can we receive a response, if the request hasn't started yet.") @@ -276,16 +293,25 @@ extension RequestBag.StateMachine { status: head.status, responseHeaders: head.headers ) { - self.state = .redirected(head, redirectURL) - return false + // If we will redirect, we need to consume the response's body ASAP, to be able to + // reuse the existing connection. We will consume a response body, if the body is + // smaller than 3kb. + switch head.contentLength { + case .some(0...(HTTPClient.maxBodySizeRedirectResponse)), .none: + self.state = .redirected(executor, 0, head, redirectURL) + return .signalBodyDemand(executor) + case .some: + self.state = .finished(error: HTTPClientError.cancelled) + return .redirect(executor, self.redirectHandler!, head, redirectURL) + } } else { self.state = .executing(executor, requestState, .buffering(.init(), next: .askExecutorForMore)) - return true + return .forwardResponseHead(head) } case .redirected: preconditionFailure("This state can only be reached after we have received a HTTP head") case .finished(error: .some): - return false + return .none case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") case .modifying: @@ -293,7 +319,14 @@ extension RequestBag.StateMachine { } } - mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ByteBuffer? { + enum ReceiveResponseBodyAction { + case none + case forwardResponsePart(ByteBuffer) + case signalBodyDemand(HTTPRequestExecutor) + case redirect(HTTPRequestExecutor, RedirectHandler, HTTPResponseHead, URL) + } + + mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponseBodyAction { switch self.state { case .initialized, .queued: preconditionFailure("How can we receive a response body part, if the request hasn't started yet.") @@ -312,17 +345,26 @@ extension RequestBag.StateMachine { currentBuffer.append(contentsOf: buffer) } self.state = .executing(executor, requestState, .buffering(currentBuffer, next: next)) - return nil + return .none case .executing(let executor, let requestState, .waitingForRemote): var buffer = buffer let first = buffer.removeFirst() self.state = .executing(executor, requestState, .buffering(buffer, next: .askExecutorForMore)) - return first - case .redirected: - // ignore body - return nil + return .forwardResponsePart(first) + case .redirected(let executor, var receivedBytes, let head, let redirectURL): + let partsLength = buffer.reduce(into: 0) { $0 += $1.readableBytes } + receivedBytes += partsLength + + if receivedBytes > HTTPClient.maxBodySizeRedirectResponse { + self.state = .finished(error: HTTPClientError.cancelled) + return .redirect(executor, self.redirectHandler!, head, redirectURL) + } else { + self.state = .redirected(executor, receivedBytes, head, redirectURL) + return .signalBodyDemand(executor) + } + case .finished(error: .some): - return nil + return .none case .finished(error: .none): preconditionFailure("How can the request be finished without error, before receiving response head?") case .modifying: @@ -368,7 +410,7 @@ extension RequestBag.StateMachine { self.state = .executing(executor, requestState, .buffering(newChunks, next: .eof)) return .consume(first) - case .redirected(let head, let redirectURL): + case .redirected(_, _, let head, let redirectURL): self.state = .finished(error: nil) return .redirect(self.redirectHandler!, head, redirectURL) @@ -529,3 +571,12 @@ extension RequestBag.StateMachine { } } } + +extension HTTPResponseHead { + var contentLength: Int? { + guard let header = self.headers.first(name: "content-length") else { + return nil + } + return Int(header) + } +} diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index 9a40e9ff5..b4aeef0e7 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -196,33 +196,49 @@ final class RequestBag { self.task.eventLoop.assertInEventLoop() // runs most likely on channel eventLoop - let forwardToDelegate = self.state.receiveResponseHead(head) + switch self.state.receiveResponseHead(head) { + case .none: + break - guard forwardToDelegate else { return } + case .signalBodyDemand(let executor): + executor.demandResponseBodyStream(self) - self.delegate.didReceiveHead(task: self.task, head) - .hop(to: self.task.eventLoop) - .whenComplete { result in - // After the head received, let's start to consume body data - self.consumeMoreBodyData0(resultOfPreviousConsume: result) - } + case .redirect(let executor, let handler, let head, let newURL): + handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + executor.cancelRequest(self) + + case .forwardResponseHead(let head): + self.delegate.didReceiveHead(task: self.task, head) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // After the head received, let's start to consume body data + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { self.task.eventLoop.assertInEventLoop() - let maybeForwardBuffer = self.state.receiveResponseBodyParts(buffer) + switch self.state.receiveResponseBodyParts(buffer) { + case .none: + break - guard let forwardBuffer = maybeForwardBuffer else { - return - } + case .signalBodyDemand(let executor): + executor.demandResponseBodyStream(self) - self.delegate.didReceiveBodyPart(task: self.task, forwardBuffer) - .hop(to: self.task.eventLoop) - .whenComplete { result in - // on task el - self.consumeMoreBodyData0(resultOfPreviousConsume: result) - } + case .redirect(let executor, let handler, let head, let newURL): + handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + executor.cancelRequest(self) + + case .forwardResponsePart(let part): + self.delegate.didReceiveBodyPart(task: self.task, part) + .hop(to: self.task.eventLoop) + .whenComplete { result in + // on task el + self.consumeMoreBodyData0(resultOfPreviousConsume: result) + } + } } private func succeedRequest0(_ buffer: CircularBuffer?) { diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift index 0a8f850ad..74c68fd1f 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests+XCTest.swift @@ -34,6 +34,9 @@ extension RequestBagTests { ("testChannelBecomingWritableDoesntCrashCancelledTask", testChannelBecomingWritableDoesntCrashCancelledTask), ("testHTTPUploadIsCancelledEvenThoughRequestSucceeds", testHTTPUploadIsCancelledEvenThoughRequestSucceeds), ("testRaceBetweenConnectionCloseAndDemandMoreData", testRaceBetweenConnectionCloseAndDemandMoreData), + ("testRedirectWith3KBBody", testRedirectWith3KBBody), + ("testRedirectWith4KBBodyAnnouncedInResponseHead", testRedirectWith4KBBodyAnnouncedInResponseHead), + ("testRedirectWith4KBBodyNotAnnouncedInResponseHead", testRedirectWith4KBBodyNotAnnouncedInResponseHead), ] } } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index ed50ae02d..c80f8846b 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -496,6 +496,199 @@ final class RequestBagTests: XCTestCase { XCTAssertNoThrow(try XCTUnwrap(delegate.backpressurePromise).succeed(())) XCTAssertEqual(delegate.hitDidReceiveResponse, 1) } + + func testRedirectWith3KBBody() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertTrue(executor.signalledDemandForResponseBody) + executor.resetResponseStreamDemandSignal() + + // "foo" is forwarded for consumption. We expect the RequestBag to consume "foo" with the + // delegate and call demandMoreBody afterwards. + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 0, count: 1024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 1, count: 1024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.succeedRequest([ByteBuffer(repeating: 2, count: 1024)]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveResponse, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertTrue(redirectTriggered) + } + + func testRedirectWith4KBBodyAnnouncedInResponseHead() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertTrue(executor.isCancelled) + + XCTAssertTrue(redirectTriggered) + } + + func testRedirectWith4KBBodyNotAnnouncedInResponseHead() { + let embeddedEventLoop = EmbeddedEventLoop() + defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let logger = Logger(label: "test") + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org")) + guard let request = maybeRequest else { return XCTFail("Expected to have a request") } + + let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) + var maybeRequestBag: RequestBag? + var redirectTriggered = false + XCTAssertNoThrow(maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + )) + guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } + + let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + executor.runRequest(bag) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertNil(delegate.backpressurePromise) + XCTAssertTrue(executor.signalledDemandForResponseBody) + executor.resetResponseStreamDemandSignal() + + // "foo" is forwarded for consumption. We expect the RequestBag to consume "foo" with the + // delegate and call demandMoreBody afterwards. + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 0, count: 2024)]) + XCTAssertTrue(executor.signalledDemandForResponseBody) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertFalse(executor.isCancelled) + XCTAssertFalse(executor.signalledDemandForResponseBody) + bag.receiveResponseBodyParts([ByteBuffer(repeating: 1, count: 2024)]) + XCTAssertFalse(executor.signalledDemandForResponseBody) + XCTAssertTrue(executor.isCancelled) + XCTAssertEqual(delegate.hitDidReceiveBodyPart, 0) + XCTAssertNil(delegate.backpressurePromise) + executor.resetResponseStreamDemandSignal() + + XCTAssertTrue(redirectTriggered) + } } class UploadCountingDelegate: HTTPClientResponseDelegate {