Skip to content

Collect function fix #672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 56 additions & 15 deletions Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,6 @@ public struct HTTPClientResponse: Sendable {
/// The body of this HTTP response.
public var body: Body

init(
bag: Transaction,
version: HTTPVersion,
status: HTTPResponseStatus,
headers: HTTPHeaders
) {
self.version = version
self.status = status
self.headers = headers
self.body = Body(TransactionBody(bag))
}

@inlinable public init(
version: HTTPVersion = .http1_1,
status: HTTPResponseStatus = .ok,
Expand All @@ -55,6 +43,17 @@ public struct HTTPClientResponse: Sendable {
self.headers = headers
self.body = body
}

init(
bag: Transaction,
version: HTTPVersion,
status: HTTPResponseStatus,
headers: HTTPHeaders,
requestMethod: HTTPMethod
) {
let contentLength = HTTPClientResponse.expectedContentLength(requestMethod: requestMethod, headers: headers, status: status)
self.init(version: version, status: status, headers: headers, body: .init(TransactionBody(bag, expectedContentLength: contentLength)))
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
Expand Down Expand Up @@ -83,6 +82,48 @@ extension HTTPClientResponse {
@inlinable public func makeAsyncIterator() -> AsyncIterator {
.init(storage: self.storage.makeAsyncIterator())
}

@inlinable init(storage: Storage) {
self.storage = storage
}

/// Accumulates `Body` of ``ByteBuffer``s into a single ``ByteBuffer``.
/// - Parameters:
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
/// - Returns: the number of bytes collected over time
@inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer {
switch self.storage {
case .transaction(let transactionBody):
if let contentLength = transactionBody.expectedContentLength {
if contentLength > maxBytes {
throw NIOTooManyBytesError()
}
}
case .anyAsyncSequence:
break
}

/// calling collect function within here in order to ensure the correct nested type
func collect<Body: AsyncSequence>(_ body: Body, maxBytes: Int) async throws -> ByteBuffer where Body.Element == ByteBuffer {
try await body.collect(upTo: maxBytes)
}
return try await collect(self, maxBytes: maxBytes)
}
}
}

@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientResponse {
static func expectedContentLength(requestMethod: HTTPMethod, headers: HTTPHeaders, status: HTTPResponseStatus) -> Int? {
if status == .notModified {
return 0
} else if requestMethod == .HEAD {
return 0
} else {
let contentLength = headers["content-length"].first.flatMap { Int($0, radix: 10) }
return contentLength
}
}
}

Expand Down Expand Up @@ -132,10 +173,10 @@ extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol {
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
extension HTTPClientResponse.Body {
init(_ body: TransactionBody) {
self.init(.transaction(body))
self.init(storage: .transaction(body))
}

@usableFromInline init(_ storage: Storage) {
@inlinable init(_ storage: Storage) {
self.storage = storage
}

Expand All @@ -146,7 +187,7 @@ extension HTTPClientResponse.Body {
@inlinable public static func stream<SequenceOfBytes>(
_ sequenceOfBytes: SequenceOfBytes
) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer {
self.init(.anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition)))
Self(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition)))
}

public static func bytes(_ byteBuffer: ByteBuffer) -> Self {
Expand Down
3 changes: 2 additions & 1 deletion Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ extension Transaction: HTTPExecutableRequest {
bag: self,
version: head.version,
status: head.status,
headers: head.headers
headers: head.headers,
requestMethod: self.requestHead.method
)
continuation.resume(returning: asyncResponse)
}
Expand Down
4 changes: 3 additions & 1 deletion Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import NIOCore
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
@usableFromInline final class TransactionBody: Sendable {
@usableFromInline let transaction: Transaction
@usableFromInline let expectedContentLength: Int?

init(_ transaction: Transaction) {
init(_ transaction: Transaction, expectedContentLength: Int?) {
self.transaction = transaction
self.expectedContentLength = expectedContentLength
}

deinit {
Expand Down
2 changes: 1 addition & 1 deletion Sources/AsyncHTTPClient/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -645,7 +645,7 @@ public class HTTPClient {
"ahc-el-preference": "\(eventLoopPreference)"])

let failedTask: Task<Delegate.Response>? = self.stateLock.withLock {
switch state {
switch self.state {
case .upAndRunning:
return nil
case .shuttingDown, .shutDown:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ extension AsyncAwaitEndToEndTests {
("testRejectsInvalidCharactersInHeaderFieldValues_http1", testRejectsInvalidCharactersInHeaderFieldValues_http1),
("testRejectsInvalidCharactersInHeaderFieldValues_http2", testRejectsInvalidCharactersInHeaderFieldValues_http2),
("testUsingGetMethodInsteadOfWait", testUsingGetMethodInsteadOfWait),
("testSimpleContentLengthErrorNoBody", testSimpleContentLengthErrorNoBody),
]
}
}
41 changes: 26 additions & 15 deletions Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], ["4"])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand All @@ -137,7 +137,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand All @@ -160,7 +160,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand All @@ -183,7 +183,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], ["4"])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand All @@ -210,7 +210,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand All @@ -233,7 +233,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response.body.collect()
try await response.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand Down Expand Up @@ -580,7 +580,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else {
return
}
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return }
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return }
var maybeRequestInfo: RequestInfo?
XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body))
guard let requestInfo = maybeRequestInfo else { return }
Expand Down Expand Up @@ -641,7 +641,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response1.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response1.body.collect()
try await response1.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))

Expand All @@ -650,7 +650,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
) else { return }
XCTAssertEqual(response2.headers["content-length"], [])
guard let body = await XCTAssertNoThrowWithResult(
try await response2.body.collect()
try await response2.body.collect(upTo: 1024)
) else { return }
XCTAssertEqual(body, ByteBuffer(string: "1234"))
}
Expand Down Expand Up @@ -803,13 +803,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
XCTAssertEqual(response.version, .http2)
}
}
}

extension AsyncSequence where Element == ByteBuffer {
func collect() async rethrows -> ByteBuffer {
try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in
var nextBuffer = nextBuffer
accumulatingBuffer.writeBuffer(&nextBuffer)
func testSimpleContentLengthErrorNoBody() {
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
XCTAsyncTest {
let bin = HTTPBin(.http2(compress: false))
defer { XCTAssertNoThrow(try bin.shutdown()) }
let client = makeDefaultHTTPClient()
defer { XCTAssertNoThrow(try client.syncShutdown()) }
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body")
guard let response = await XCTAssertNoThrowWithResult(
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
) else { return }
await XCTAssertThrowsError(
try await response.body.collect(upTo: 3)
) {
XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError())
}
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientResponseTests+XCTest.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
//
// HTTPClientResponseTests+XCTest.swift
//
import XCTest

///
/// NOTE: This file was generated by generate_linux_tests.rb
///
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
///

extension HTTPClientResponseTests {
static var allTests: [(String, (HTTPClientResponseTests) -> () throws -> Void)] {
return [
("testSimpleResponse", testSimpleResponse),
("testSimpleResponseNotModified", testSimpleResponseNotModified),
("testSimpleResponseHeadRequestMethod", testSimpleResponseHeadRequestMethod),
("testResponseNoContentLengthHeader", testResponseNoContentLengthHeader),
("testResponseInvalidInteger", testResponseInvalidInteger),
]
}
}
45 changes: 45 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@testable import AsyncHTTPClient
import Logging
import NIOCore
import XCTest

final class HTTPClientResponseTests: XCTestCase {
func testSimpleResponse() {
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .ok)
XCTAssertEqual(response, 1025)
}

func testSimpleResponseNotModified() {
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .notModified)
XCTAssertEqual(response, 0)
}

func testSimpleResponseHeadRequestMethod() {
let response = HTTPClientResponse.expectedContentLength(requestMethod: .HEAD, headers: ["content-length": "1025"], status: .ok)
XCTAssertEqual(response, 0)
}

func testResponseNoContentLengthHeader() {
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: [:], status: .ok)
XCTAssertEqual(response, nil)
}

func testResponseInvalidInteger() {
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "none"], status: .ok)
XCTAssertEqual(response, nil)
}
}
5 changes: 5 additions & 0 deletions Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,11 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
// We're forcing this closed now.
self.shouldClose = true
self.resps.append(builder)
case "/content-length-without-body":
var headers = self.responseHeaders
headers.replaceOrAdd(name: "content-length", value: "1234")
context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil)
return
default:
context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
Expand Down
1 change: 1 addition & 0 deletions Tests/LinuxMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct LinuxMain {
testCase(HTTPClientNIOTSTests.allTests),
testCase(HTTPClientReproTests.allTests),
testCase(HTTPClientRequestTests.allTests),
testCase(HTTPClientResponseTests.allTests),
testCase(HTTPClientSOCKSTests.allTests),
testCase(HTTPClientTests.allTests),
testCase(HTTPClientUncleanSSLConnectionShutdownTests.allTests),
Expand Down