diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift index b68e8db8b..07904b681 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -32,18 +32,6 @@ public struct HTTPClientResponse: Sendable { /// The body of this HTTP response. public var body: Body - init( - bag: Transaction, - version: HTTPVersion, - status: HTTPResponseStatus, - headers: HTTPHeaders - ) { - self.version = version - self.status = status - self.headers = headers - self.body = Body(TransactionBody(bag)) - } - @inlinable public init( version: HTTPVersion = .http1_1, status: HTTPResponseStatus = .ok, @@ -55,6 +43,17 @@ public struct HTTPClientResponse: Sendable { self.headers = headers self.body = body } + + init( + bag: Transaction, + version: HTTPVersion, + status: HTTPResponseStatus, + headers: HTTPHeaders, + requestMethod: HTTPMethod + ) { + let contentLength = HTTPClientResponse.expectedContentLength(requestMethod: requestMethod, headers: headers, status: status) + self.init(version: version, status: status, headers: headers, body: .init(TransactionBody(bag, expectedContentLength: contentLength))) + } } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) @@ -83,6 +82,48 @@ extension HTTPClientResponse { @inlinable public func makeAsyncIterator() -> AsyncIterator { .init(storage: self.storage.makeAsyncIterator()) } + + @inlinable init(storage: Storage) { + self.storage = storage + } + + /// Accumulates `Body` of ``ByteBuffer``s into a single ``ByteBuffer``. + /// - Parameters: + /// - maxBytes: The maximum number of bytes this method is allowed to accumulate + /// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`. + /// - Returns: the number of bytes collected over time + @inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer { + switch self.storage { + case .transaction(let transactionBody): + if let contentLength = transactionBody.expectedContentLength { + if contentLength > maxBytes { + throw NIOTooManyBytesError() + } + } + case .anyAsyncSequence: + break + } + + /// calling collect function within here in order to ensure the correct nested type + func collect(_ body: Body, maxBytes: Int) async throws -> ByteBuffer where Body.Element == ByteBuffer { + try await body.collect(upTo: maxBytes) + } + return try await collect(self, maxBytes: maxBytes) + } + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientResponse { + static func expectedContentLength(requestMethod: HTTPMethod, headers: HTTPHeaders, status: HTTPResponseStatus) -> Int? { + if status == .notModified { + return 0 + } else if requestMethod == .HEAD { + return 0 + } else { + let contentLength = headers["content-length"].first.flatMap { Int($0, radix: 10) } + return contentLength + } } } @@ -132,10 +173,10 @@ extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse.Body { init(_ body: TransactionBody) { - self.init(.transaction(body)) + self.init(storage: .transaction(body)) } - @usableFromInline init(_ storage: Storage) { + @inlinable init(_ storage: Storage) { self.storage = storage } @@ -146,7 +187,7 @@ extension HTTPClientResponse.Body { @inlinable public static func stream( _ sequenceOfBytes: SequenceOfBytes ) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer { - self.init(.anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) + Self(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition))) } public static func bytes(_ byteBuffer: ByteBuffer) -> Self { diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index d81fbfd28..8846f36a5 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -236,7 +236,8 @@ extension Transaction: HTTPExecutableRequest { bag: self, version: head.version, status: head.status, - headers: head.headers + headers: head.headers, + requestMethod: self.requestHead.method ) continuation.resume(returning: asyncResponse) } diff --git a/Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift b/Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift index 497a3cc72..23a8e505e 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift @@ -20,9 +20,11 @@ import NIOCore @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) @usableFromInline final class TransactionBody: Sendable { @usableFromInline let transaction: Transaction + @usableFromInline let expectedContentLength: Int? - init(_ transaction: Transaction) { + init(_ transaction: Transaction, expectedContentLength: Int?) { self.transaction = transaction + self.expectedContentLength = expectedContentLength } deinit { diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index beb2ea458..9f25cc15c 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -645,7 +645,7 @@ public class HTTPClient { "ahc-el-preference": "\(eventLoopPreference)"]) let failedTask: Task? = self.stateLock.withLock { - switch state { + switch self.state { case .upAndRunning: return nil case .shuttingDown, .shutDown: diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift index ce0e2846d..85ac04f5c 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift @@ -51,6 +51,7 @@ extension AsyncAwaitEndToEndTests { ("testRejectsInvalidCharactersInHeaderFieldValues_http1", testRejectsInvalidCharactersInHeaderFieldValues_http1), ("testRejectsInvalidCharactersInHeaderFieldValues_http2", testRejectsInvalidCharactersInHeaderFieldValues_http2), ("testUsingGetMethodInsteadOfWait", testUsingGetMethodInsteadOfWait), + ("testSimpleContentLengthErrorNoBody", testSimpleContentLengthErrorNoBody), ] } } diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index e80957079..af99fb0a4 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -114,7 +114,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -137,7 +137,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -160,7 +160,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -183,7 +183,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -210,7 +210,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -233,7 +233,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect() + try await response.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -580,7 +580,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } - guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return } + guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } var maybeRequestInfo: RequestInfo? XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) guard let requestInfo = maybeRequestInfo else { return } @@ -641,7 +641,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response1.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response1.body.collect() + try await response1.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) @@ -650,7 +650,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { ) else { return } XCTAssertEqual(response2.headers["content-length"], []) guard let body = await XCTAssertNoThrowWithResult( - try await response2.body.collect() + try await response2.body.collect(upTo: 1024) ) else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } @@ -803,13 +803,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertEqual(response.version, .http2) } } -} -extension AsyncSequence where Element == ByteBuffer { - func collect() async rethrows -> ByteBuffer { - try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in - var nextBuffer = nextBuffer - accumulatingBuffer.writeBuffer(&nextBuffer) + func testSimpleContentLengthErrorNoBody() { + guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return } + XCTAsyncTest { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body") + guard let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) else { return } + await XCTAssertThrowsError( + try await response.body.collect(upTo: 3) + ) { + XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError()) + } } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests+XCTest.swift new file mode 100644 index 000000000..0a1a7cab6 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests+XCTest.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTPClientResponseTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTPClientResponseTests { + static var allTests: [(String, (HTTPClientResponseTests) -> () throws -> Void)] { + return [ + ("testSimpleResponse", testSimpleResponse), + ("testSimpleResponseNotModified", testSimpleResponseNotModified), + ("testSimpleResponseHeadRequestMethod", testSimpleResponseHeadRequestMethod), + ("testResponseNoContentLengthHeader", testResponseNoContentLengthHeader), + ("testResponseInvalidInteger", testResponseInvalidInteger), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift new file mode 100644 index 000000000..bf0ecfeb9 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift @@ -0,0 +1,45 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +@testable import AsyncHTTPClient +import Logging +import NIOCore +import XCTest + +final class HTTPClientResponseTests: XCTestCase { + func testSimpleResponse() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .ok) + XCTAssertEqual(response, 1025) + } + + func testSimpleResponseNotModified() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .notModified) + XCTAssertEqual(response, 0) + } + + func testSimpleResponseHeadRequestMethod() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .HEAD, headers: ["content-length": "1025"], status: .ok) + XCTAssertEqual(response, 0) + } + + func testResponseNoContentLengthHeader() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: [:], status: .ok) + XCTAssertEqual(response, nil) + } + + func testResponseInvalidInteger() { + let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "none"], status: .ok) + XCTAssertEqual(response, nil) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index c617555c6..1a3cbd968 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -945,6 +945,11 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // We're forcing this closed now. self.shouldClose = true self.resps.append(builder) + case "/content-length-without-body": + var headers = self.responseHeaders + headers.replaceOrAdd(name: "content-length", value: "1234") + context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + return default: context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index ca8478326..886bf2b95 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -44,6 +44,7 @@ struct LinuxMain { testCase(HTTPClientNIOTSTests.allTests), testCase(HTTPClientReproTests.allTests), testCase(HTTPClientRequestTests.allTests), + testCase(HTTPClientResponseTests.allTests), testCase(HTTPClientSOCKSTests.allTests), testCase(HTTPClientTests.allTests), testCase(HTTPClientUncleanSSLConnectionShutdownTests.allTests),