diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index c15b64a1e..3cdfeeee7 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -419,7 +419,26 @@ public class ResponseAccumulator: HTTPClientResponseDelegate { /// `HTTPClientResponseDelegate` allows an implementation to receive notifications about request processing and to control how response parts are processed. /// You can implement this protocol if you need fine-grained control over an HTTP request/response, for example, if you want to inspect the response /// headers before deciding whether to accept a response body, or if you want to stream your request body. Pass an instance of your conforming -/// class to the `HTTPClient.execute()` method and this package will call each delegate method appropriately as the request takes place. +/// class to the `HTTPClient.execute()` method and this package will call each delegate method appropriately as the request takes place./ +/// +/// ### Backpressure +/// +/// A `HTTPClientResponseDelegate` can be used to exert backpressure on the server response. This is achieved by way of the futures returned from +/// `didReceiveHead` and `didReceiveBodyPart`. The following functions are part of the "backpressure system" in the delegate: +/// +/// - `didReceiveHead` +/// - `didReceiveBodyPart` +/// - `didFinishRequest` +/// - `didReceiveError` +/// +/// The first three methods are strictly _exclusive_, with that exclusivity managed by the futures returned by `didReceiveHead` and +/// `didReceiveBodyPart`. What this means is that until the returned future is completed, none of these three methods will be called +/// again. This allows delegates to rate limit the server to a capacity it can manage. `didFinishRequest` does not return a future, +/// as we are expecting no more data from the server at this time. +/// +/// `didReceiveError` is somewhat special: it signals the end of this regime. `didRecieveError` is not exclusive: it may be called at +/// any time, even if a returned future is not yet completed. `didReceiveError` is terminal, meaning that once it has been called none +/// of these four methods will be called again. This can be used as a signal to abandon all outstanding work. /// /// - note: This delegate is strongly held by the `HTTPTaskHandler` /// for the duration of the `Request` processing and will be @@ -463,6 +482,11 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// + /// This function will not be called until the future returned by `didReceiveHead` has completed. + /// + /// This function will not be called for subsequent body parts until the previous future returned by a + /// call to this function completes. + /// /// - parameters: /// - task: Current request context. /// - buffer: Received body `Part`. @@ -471,6 +495,10 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// Called when error was thrown during request execution. Will be called zero or one time only. Request processing will be stopped after that. /// + /// This function may be called at any time: it does not respect the backpressure exerted by `didReceiveHead` and `didReceiveBodyPart`. + /// All outstanding work may be cancelled when this is received. Once called, no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, + /// or `didFinishRequest`. + /// /// - parameters: /// - task: Current request context. /// - error: Error that occured during response processing. @@ -478,6 +506,9 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// Called when the complete HTTP request is finished. You must return an instance of your `Response` associated type. Will be called once, except if an error occurred. /// + /// This function will not be called until all futures returned by `didReceiveHead` and `didReceiveBodyPart` have completed. Once called, + /// no further calls will be made to `didReceiveHead`, `didReceiveBodyPart`, or `didReceiveError`. + /// /// - parameters: /// - task: Current request context. /// - returns: Result of processing. @@ -678,6 +709,7 @@ internal class TaskHandler: RemovableChann case bodySentWaitingResponseHead case bodySentResponseHeadReceived case redirected(HTTPResponseHead, URL) + case bufferedEnd case endOrError } @@ -688,10 +720,11 @@ internal class TaskHandler: RemovableChann let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request. var state: State = .idle + var responseReadBuffer: ResponseReadBuffer = ResponseReadBuffer() var expectedBodyLength: Int? var actualBodyLength: Int = 0 var pendingRead = false - var mayRead = true + var outstandingDelegateRead = false var closing = false { didSet { assert(self.closing || !oldValue, @@ -861,10 +894,12 @@ extension TaskHandler: ChannelDuplexHandler { preconditionFailure("should not happen, state is \(self.state)") case .redirected: break - case .endOrError: + case .bufferedEnd, .endOrError: // If the state is .endOrError, it means that request was failed and there is nothing to do here: // we cannot write .end since channel is most likely closed, and we should not fail the future, // since the task would already be failed, no need to fail the writer too. + // If the state is .bufferedEnd the issue is the same, we just haven't fully delivered the response to + // the user yet. return context.eventLoop.makeSucceededFuture(()) } @@ -937,7 +972,7 @@ extension TaskHandler: ChannelDuplexHandler { } self.actualBodyLength += part.readableBytes context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) - case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .endOrError: + case .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .bufferedEnd, .endOrError: let error = HTTPClientError.writeAfterRequestSent self.errorCaught(context: context, error: error) promise.fail(error) @@ -945,11 +980,11 @@ extension TaskHandler: ChannelDuplexHandler { } public func read(context: ChannelHandlerContext) { - if self.mayRead { + if self.outstandingDelegateRead { + self.pendingRead = true + } else { self.pendingRead = false context.read() - } else { - self.pendingRead = true } } @@ -968,7 +1003,7 @@ extension TaskHandler: ChannelDuplexHandler { case .sendingBodyResponseHeadReceived, .bodySentResponseHeadReceived, .redirected: // should be prevented by NIO HTTP1 pipeline, aee testHTTPResponseDoubleHead preconditionFailure("should not happen") - case .endOrError: + case .bufferedEnd, .endOrError: return } @@ -979,26 +1014,18 @@ extension TaskHandler: ChannelDuplexHandler { if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) { self.state = .redirected(head, redirectURL) } else { - self.mayRead = false - self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) - .whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } + self.handleReadForDelegate(response, context: context) } - case .body(let body): + case .body: switch self.state { - case .redirected, .endOrError: + case .redirected, .bufferedEnd, .endOrError: break default: - self.mayRead = false - self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart) - .whenComplete { result in - self.handleBackpressureResult(context: context, result: result) - } + self.handleReadForDelegate(response, context: context) } case .end: switch self.state { - case .endOrError: + case .bufferedEnd, .endOrError: break case .redirected(let head, let redirectURL): self.state = .endOrError @@ -1006,18 +1033,67 @@ extension TaskHandler: ChannelDuplexHandler { self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise) } default: - self.state = .endOrError - self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest) + self.state = .bufferedEnd + self.handleReadForDelegate(response, context: context) + } + } + } + + private func handleReadForDelegate(_ read: HTTPClientResponsePart, context: ChannelHandlerContext) { + if self.outstandingDelegateRead { + self.responseReadBuffer.appendPart(read) + return + } + + // No outstanding delegate read, so we can deliver this directly. + self.deliverReadToDelegate(read, context: context) + } + + private func deliverReadToDelegate(_ read: HTTPClientResponsePart, context: ChannelHandlerContext) { + precondition(!self.outstandingDelegateRead) + self.outstandingDelegateRead = true + + if case .endOrError = self.state { + // No further read delivery should occur, we already delivered an error. + return + } + + switch read { + case .head(let head): + self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } + case .body(let body): + self.callOutToDelegate(value: body, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart) + .whenComplete { result in + self.handleBackpressureResult(context: context, result: result) + } + case .end: + self.state = .endOrError + self.outstandingDelegateRead = false + + if self.pendingRead { + // We must read here, as we will be removed from the channel shortly. + self.pendingRead = false + context.read() } + + self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest) } } private func handleBackpressureResult(context: ChannelHandlerContext, result: Result) { context.eventLoop.assertInEventLoop() + self.outstandingDelegateRead = false + switch result { case .success: - self.mayRead = true - if self.pendingRead { + if let nextRead = self.responseReadBuffer.nextRead() { + // We can deliver this directly. + self.deliverReadToDelegate(nextRead, context: context) + } else if self.pendingRead { + self.pendingRead = false context.read() } case .failure(let error): @@ -1046,7 +1122,7 @@ extension TaskHandler: ChannelDuplexHandler { switch self.state { case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected: self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed) - case .endOrError: + case .bufferedEnd, .endOrError: break } context.fireChannelInactive() @@ -1056,7 +1132,7 @@ extension TaskHandler: ChannelDuplexHandler { switch error { case NIOSSLError.uncleanShutdown: switch self.state { - case .endOrError: + case .bufferedEnd, .endOrError: /// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection, /// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error. break @@ -1070,7 +1146,7 @@ extension TaskHandler: ChannelDuplexHandler { } default: switch self.state { - case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected: + case .idle, .sendingBodyWaitingResponseHead, .sendingBodyResponseHeadReceived, .bodySentWaitingResponseHead, .bodySentResponseHeadReceived, .redirected, .bufferedEnd: self.state = .endOrError self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) case .endOrError: diff --git a/Sources/AsyncHTTPClient/ResponseReadBuffer.swift b/Sources/AsyncHTTPClient/ResponseReadBuffer.swift new file mode 100644 index 000000000..37cf11665 --- /dev/null +++ b/Sources/AsyncHTTPClient/ResponseReadBuffer.swift @@ -0,0 +1,32 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIO +import NIOHTTP1 + +struct ResponseReadBuffer { + private var responseParts: CircularBuffer + + init() { + self.responseParts = CircularBuffer(initialCapacity: 16) + } + + mutating func appendPart(_ part: HTTPClientResponsePart) { + self.responseParts.append(part) + } + + mutating func nextRead() -> HTTPClientResponsePart? { + return self.responseParts.popFirst() + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index df0816b49..f10d99677 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -21,6 +21,7 @@ import NIOHTTP1 import NIOHTTPCompression import NIOSSL import NIOTransportServices +import XCTest /// Are we testing NIO Transport services func isTestingNIOTS() -> Bool { @@ -100,6 +101,52 @@ class CountingDelegate: HTTPClientResponseDelegate { } } +class DelayOnHeadDelegate: HTTPClientResponseDelegate { + typealias Response = ByteBuffer + + let promise: EventLoopPromise + + private var data: ByteBuffer + + private var mayReceiveData = false + + private var expectError = false + + init(promise: EventLoopPromise) { + self.promise = promise + self.data = ByteBuffer() + + self.promise.futureResult.whenSuccess { + self.mayReceiveData = true + } + self.promise.futureResult.whenFailure { (_: Error) in + self.expectError = true + } + } + + func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { + XCTAssertFalse(self.expectError) + return self.promise.futureResult.hop(to: task.eventLoop) + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { + XCTAssertTrue(self.mayReceiveData) + XCTAssertFalse(self.expectError) + self.data.writeImmutableBuffer(buffer) + return self.promise.futureResult.hop(to: task.eventLoop) + } + + func didFinishRequest(task: HTTPClient.Task) throws -> Response { + XCTAssertTrue(self.mayReceiveData) + XCTAssertFalse(self.expectError) + return self.data + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + XCTAssertTrue(self.expectError) + } +} + internal final class RecordingHandler: ChannelDuplexHandler { typealias InboundIn = Input typealias OutboundIn = Output @@ -464,6 +511,21 @@ internal final class HttpBinHandler: ChannelInboundHandler { context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } + func writeChunked(context: ChannelHandlerContext) { + // This tests receiving chunks very fast: please do not insert delays here! + let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) + + context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + for i in 0..<10 { + let msg = "id: \(i)" + var buf = context.channel.allocator.buffer(capacity: msg.count) + buf.writeString(msg) + context.write(wrapOutboundOut(.body(.byteBuffer(buf))), promise: nil) + } + + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isServingRequest = true switch self.unwrapInboundIn(data) { @@ -579,6 +641,19 @@ internal final class HttpBinHandler: ChannelInboundHandler { return case "/events/10/content-length": self.writeEvents(context: context, isContentLengthRequired: true) + case "/chunked": + self.writeChunked(context: context) + return + case "/close-on-response": + var headers = self.responseHeaders + headers.replaceOrAdd(name: "connection", value: "close") + + var builder = HTTPResponseBuilder(.http1_1, status: .ok, headers: headers) + builder.body = ByteBuffer(string: "some body content") + + // We're forcing this closed now. + self.shouldClose = true + self.resps.append(builder) default: context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 46b6d2bbf..4344fd72e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -130,6 +130,9 @@ extension HTTPClientTests { ("testWeCloseConnectionsWhenConnectionCloseSetByServer", testWeCloseConnectionsWhenConnectionCloseSetByServer), ("testBiDirectionalStreaming", testBiDirectionalStreaming), ("testSynchronousHandshakeErrorReporting", testSynchronousHandshakeErrorReporting), + ("testFileDownloadChunked", testFileDownloadChunked), + ("testCloseWhileBackpressureIsExertedIsFine", testCloseWhileBackpressureIsExertedIsFine), + ("testErrorAfterCloseWhileBackpressureExerted", testErrorAfterCloseWhileBackpressureExerted), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index ec9a798a9..7e0948aef 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -2844,4 +2844,70 @@ class HTTPClientTests: XCTestCase { } } } + + func testFileDownloadChunked() throws { + var request = try Request(url: self.defaultHTTPBinURLPrefix + "chunked") + request.headers.add(name: "Accept", value: "text/event-stream") + + let progress = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let delegate = try FileDownloadDelegate(path: path) + + let progress = try self.defaultClient.execute( + request: request, + delegate: delegate + ) + .wait() + + try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) + + return progress + } + + XCTAssertEqual(nil, progress.totalBytes) + XCTAssertEqual(50, progress.receivedBytes) + } + + func testCloseWhileBackpressureIsExertedIsFine() throws { + let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") + let backpressurePromise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + + let resultFuture = self.defaultClient.execute( + request: request, delegate: DelayOnHeadDelegate(promise: backpressurePromise) + ) + + self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(50)) { + backpressurePromise.succeed(()) + } + + // The full response must be correctly delivered. + var data = try resultFuture.wait() + guard let info = try data.readJSONDecodable(RequestInfo.self, length: data.readableBytes) else { + XCTFail("Could not parse response") + return + } + XCTAssertEqual(info.data, "some body content") + } + + func testErrorAfterCloseWhileBackpressureExerted() throws { + enum ExpectedError: Error { + case expected + } + + let request = try Request(url: self.defaultHTTPBinURLPrefix + "close-on-response") + let backpressurePromise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + + let resultFuture = self.defaultClient.execute( + request: request, delegate: DelayOnHeadDelegate(promise: backpressurePromise) + ) + + self.defaultClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(50)) { + backpressurePromise.fail(ExpectedError.expected) + } + + // The task must be failed. + XCTAssertThrowsError(try resultFuture.wait()) { error in + XCTAssertEqual(error as? ExpectedError, .expected) + } + } } diff --git a/scripts/soundness.sh b/scripts/soundness.sh index bc972b760..0e0c5221c 100755 --- a/scripts/soundness.sh +++ b/scripts/soundness.sh @@ -18,7 +18,7 @@ here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" function replace_acceptable_years() { # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/' + sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/2019/YEARS/' -e 's/2020/YEARS/' -e 's/2021/YEARS/' } printf "=> Checking linux tests... "