Skip to content

Commit 91b2640

Browse files
carolinacassLukasa
andauthored
Update collect to use content-length to make early checks
Motivation: not accumulate too many bytes Modifications: Implementing collect function to use NIOCore version to prevent overflowing Co-authored-by: Cory Benfield <[email protected]>
1 parent 9cdc429 commit 91b2640

10 files changed

+175
-33
lines changed

Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift

+56-15
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,6 @@ public struct HTTPClientResponse: Sendable {
3232
/// The body of this HTTP response.
3333
public var body: Body
3434

35-
init(
36-
bag: Transaction,
37-
version: HTTPVersion,
38-
status: HTTPResponseStatus,
39-
headers: HTTPHeaders
40-
) {
41-
self.version = version
42-
self.status = status
43-
self.headers = headers
44-
self.body = Body(TransactionBody(bag))
45-
}
46-
4735
@inlinable public init(
4836
version: HTTPVersion = .http1_1,
4937
status: HTTPResponseStatus = .ok,
@@ -55,6 +43,17 @@ public struct HTTPClientResponse: Sendable {
5543
self.headers = headers
5644
self.body = body
5745
}
46+
47+
init(
48+
bag: Transaction,
49+
version: HTTPVersion,
50+
status: HTTPResponseStatus,
51+
headers: HTTPHeaders,
52+
requestMethod: HTTPMethod
53+
) {
54+
let contentLength = HTTPClientResponse.expectedContentLength(requestMethod: requestMethod, headers: headers, status: status)
55+
self.init(version: version, status: status, headers: headers, body: .init(TransactionBody(bag, expectedContentLength: contentLength)))
56+
}
5857
}
5958

6059
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
@@ -83,6 +82,48 @@ extension HTTPClientResponse {
8382
@inlinable public func makeAsyncIterator() -> AsyncIterator {
8483
.init(storage: self.storage.makeAsyncIterator())
8584
}
85+
86+
@inlinable init(storage: Storage) {
87+
self.storage = storage
88+
}
89+
90+
/// Accumulates `Body` of ``ByteBuffer``s into a single ``ByteBuffer``.
91+
/// - Parameters:
92+
/// - maxBytes: The maximum number of bytes this method is allowed to accumulate
93+
/// - Throws: `NIOTooManyBytesError` if the the sequence contains more than `maxBytes`.
94+
/// - Returns: the number of bytes collected over time
95+
@inlinable public func collect(upTo maxBytes: Int) async throws -> ByteBuffer {
96+
switch self.storage {
97+
case .transaction(let transactionBody):
98+
if let contentLength = transactionBody.expectedContentLength {
99+
if contentLength > maxBytes {
100+
throw NIOTooManyBytesError()
101+
}
102+
}
103+
case .anyAsyncSequence:
104+
break
105+
}
106+
107+
/// calling collect function within here in order to ensure the correct nested type
108+
func collect<Body: AsyncSequence>(_ body: Body, maxBytes: Int) async throws -> ByteBuffer where Body.Element == ByteBuffer {
109+
try await body.collect(upTo: maxBytes)
110+
}
111+
return try await collect(self, maxBytes: maxBytes)
112+
}
113+
}
114+
}
115+
116+
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
117+
extension HTTPClientResponse {
118+
static func expectedContentLength(requestMethod: HTTPMethod, headers: HTTPHeaders, status: HTTPResponseStatus) -> Int? {
119+
if status == .notModified {
120+
return 0
121+
} else if requestMethod == .HEAD {
122+
return 0
123+
} else {
124+
let contentLength = headers["content-length"].first.flatMap { Int($0, radix: 10) }
125+
return contentLength
126+
}
86127
}
87128
}
88129

@@ -132,10 +173,10 @@ extension HTTPClientResponse.Body.Storage.AsyncIterator: AsyncIteratorProtocol {
132173
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
133174
extension HTTPClientResponse.Body {
134175
init(_ body: TransactionBody) {
135-
self.init(.transaction(body))
176+
self.init(storage: .transaction(body))
136177
}
137178

138-
@usableFromInline init(_ storage: Storage) {
179+
@inlinable init(_ storage: Storage) {
139180
self.storage = storage
140181
}
141182

@@ -146,7 +187,7 @@ extension HTTPClientResponse.Body {
146187
@inlinable public static func stream<SequenceOfBytes>(
147188
_ sequenceOfBytes: SequenceOfBytes
148189
) -> Self where SequenceOfBytes: AsyncSequence & Sendable, SequenceOfBytes.Element == ByteBuffer {
149-
self.init(.anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition)))
190+
Self(storage: .anyAsyncSequence(AnyAsyncSequence(sequenceOfBytes.singleIteratorPrecondition)))
150191
}
151192

152193
public static func bytes(_ byteBuffer: ByteBuffer) -> Self {

Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ extension Transaction: HTTPExecutableRequest {
236236
bag: self,
237237
version: head.version,
238238
status: head.status,
239-
headers: head.headers
239+
headers: head.headers,
240+
requestMethod: self.requestHead.method
240241
)
241242
continuation.resume(returning: asyncResponse)
242243
}

Sources/AsyncHTTPClient/AsyncAwait/TransactionBody.swift

+3-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ import NIOCore
2020
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
2121
@usableFromInline final class TransactionBody: Sendable {
2222
@usableFromInline let transaction: Transaction
23+
@usableFromInline let expectedContentLength: Int?
2324

24-
init(_ transaction: Transaction) {
25+
init(_ transaction: Transaction, expectedContentLength: Int?) {
2526
self.transaction = transaction
27+
self.expectedContentLength = expectedContentLength
2628
}
2729

2830
deinit {

Sources/AsyncHTTPClient/HTTPClient.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ public class HTTPClient {
645645
"ahc-el-preference": "\(eventLoopPreference)"])
646646

647647
let failedTask: Task<Delegate.Response>? = self.stateLock.withLock {
648-
switch state {
648+
switch self.state {
649649
case .upAndRunning:
650650
return nil
651651
case .shuttingDown, .shutDown:

Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ extension AsyncAwaitEndToEndTests {
5151
("testRejectsInvalidCharactersInHeaderFieldValues_http1", testRejectsInvalidCharactersInHeaderFieldValues_http1),
5252
("testRejectsInvalidCharactersInHeaderFieldValues_http2", testRejectsInvalidCharactersInHeaderFieldValues_http2),
5353
("testUsingGetMethodInsteadOfWait", testUsingGetMethodInsteadOfWait),
54+
("testSimpleContentLengthErrorNoBody", testSimpleContentLengthErrorNoBody),
5455
]
5556
}
5657
}

Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift

+26-15
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
114114
) else { return }
115115
XCTAssertEqual(response.headers["content-length"], ["4"])
116116
guard let body = await XCTAssertNoThrowWithResult(
117-
try await response.body.collect()
117+
try await response.body.collect(upTo: 1024)
118118
) else { return }
119119
XCTAssertEqual(body, ByteBuffer(string: "1234"))
120120
}
@@ -137,7 +137,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
137137
) else { return }
138138
XCTAssertEqual(response.headers["content-length"], [])
139139
guard let body = await XCTAssertNoThrowWithResult(
140-
try await response.body.collect()
140+
try await response.body.collect(upTo: 1024)
141141
) else { return }
142142
XCTAssertEqual(body, ByteBuffer(string: "1234"))
143143
}
@@ -160,7 +160,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
160160
) else { return }
161161
XCTAssertEqual(response.headers["content-length"], [])
162162
guard let body = await XCTAssertNoThrowWithResult(
163-
try await response.body.collect()
163+
try await response.body.collect(upTo: 1024)
164164
) else { return }
165165
XCTAssertEqual(body, ByteBuffer(string: "1234"))
166166
}
@@ -183,7 +183,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
183183
) else { return }
184184
XCTAssertEqual(response.headers["content-length"], ["4"])
185185
guard let body = await XCTAssertNoThrowWithResult(
186-
try await response.body.collect()
186+
try await response.body.collect(upTo: 1024)
187187
) else { return }
188188
XCTAssertEqual(body, ByteBuffer(string: "1234"))
189189
}
@@ -210,7 +210,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
210210
) else { return }
211211
XCTAssertEqual(response.headers["content-length"], [])
212212
guard let body = await XCTAssertNoThrowWithResult(
213-
try await response.body.collect()
213+
try await response.body.collect(upTo: 1024)
214214
) else { return }
215215
XCTAssertEqual(body, ByteBuffer(string: "1234"))
216216
}
@@ -233,7 +233,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
233233
) else { return }
234234
XCTAssertEqual(response.headers["content-length"], [])
235235
guard let body = await XCTAssertNoThrowWithResult(
236-
try await response.body.collect()
236+
try await response.body.collect(upTo: 1024)
237237
) else { return }
238238
XCTAssertEqual(body, ByteBuffer(string: "1234"))
239239
}
@@ -580,7 +580,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
580580
) else {
581581
return
582582
}
583-
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect()) else { return }
583+
guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return }
584584
var maybeRequestInfo: RequestInfo?
585585
XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body))
586586
guard let requestInfo = maybeRequestInfo else { return }
@@ -641,7 +641,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
641641
) else { return }
642642
XCTAssertEqual(response1.headers["content-length"], [])
643643
guard let body = await XCTAssertNoThrowWithResult(
644-
try await response1.body.collect()
644+
try await response1.body.collect(upTo: 1024)
645645
) else { return }
646646
XCTAssertEqual(body, ByteBuffer(string: "1234"))
647647

@@ -650,7 +650,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
650650
) else { return }
651651
XCTAssertEqual(response2.headers["content-length"], [])
652652
guard let body = await XCTAssertNoThrowWithResult(
653-
try await response2.body.collect()
653+
try await response2.body.collect(upTo: 1024)
654654
) else { return }
655655
XCTAssertEqual(body, ByteBuffer(string: "1234"))
656656
}
@@ -803,13 +803,24 @@ final class AsyncAwaitEndToEndTests: XCTestCase {
803803
XCTAssertEqual(response.version, .http2)
804804
}
805805
}
806-
}
807806

808-
extension AsyncSequence where Element == ByteBuffer {
809-
func collect() async rethrows -> ByteBuffer {
810-
try await self.reduce(into: ByteBuffer()) { accumulatingBuffer, nextBuffer in
811-
var nextBuffer = nextBuffer
812-
accumulatingBuffer.writeBuffer(&nextBuffer)
807+
func testSimpleContentLengthErrorNoBody() {
808+
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
809+
XCTAsyncTest {
810+
let bin = HTTPBin(.http2(compress: false))
811+
defer { XCTAssertNoThrow(try bin.shutdown()) }
812+
let client = makeDefaultHTTPClient()
813+
defer { XCTAssertNoThrow(try client.syncShutdown()) }
814+
let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:))
815+
let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body")
816+
guard let response = await XCTAssertNoThrowWithResult(
817+
try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)
818+
) else { return }
819+
await XCTAssertThrowsError(
820+
try await response.body.collect(upTo: 3)
821+
) {
822+
XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError())
823+
}
813824
}
814825
}
815826
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the AsyncHTTPClient open source project
4+
//
5+
// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
//
15+
// HTTPClientResponseTests+XCTest.swift
16+
//
17+
import XCTest
18+
19+
///
20+
/// NOTE: This file was generated by generate_linux_tests.rb
21+
///
22+
/// Do NOT edit this file directly as it will be regenerated automatically when needed.
23+
///
24+
25+
extension HTTPClientResponseTests {
26+
static var allTests: [(String, (HTTPClientResponseTests) -> () throws -> Void)] {
27+
return [
28+
("testSimpleResponse", testSimpleResponse),
29+
("testSimpleResponseNotModified", testSimpleResponseNotModified),
30+
("testSimpleResponseHeadRequestMethod", testSimpleResponseHeadRequestMethod),
31+
("testResponseNoContentLengthHeader", testResponseNoContentLengthHeader),
32+
("testResponseInvalidInteger", testResponseInvalidInteger),
33+
]
34+
}
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the AsyncHTTPClient open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
@testable import AsyncHTTPClient
16+
import Logging
17+
import NIOCore
18+
import XCTest
19+
20+
final class HTTPClientResponseTests: XCTestCase {
21+
func testSimpleResponse() {
22+
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .ok)
23+
XCTAssertEqual(response, 1025)
24+
}
25+
26+
func testSimpleResponseNotModified() {
27+
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .notModified)
28+
XCTAssertEqual(response, 0)
29+
}
30+
31+
func testSimpleResponseHeadRequestMethod() {
32+
let response = HTTPClientResponse.expectedContentLength(requestMethod: .HEAD, headers: ["content-length": "1025"], status: .ok)
33+
XCTAssertEqual(response, 0)
34+
}
35+
36+
func testResponseNoContentLengthHeader() {
37+
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: [:], status: .ok)
38+
XCTAssertEqual(response, nil)
39+
}
40+
41+
func testResponseInvalidInteger() {
42+
let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "none"], status: .ok)
43+
XCTAssertEqual(response, nil)
44+
}
45+
}

Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+5
Original file line numberDiff line numberDiff line change
@@ -945,6 +945,11 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
945945
// We're forcing this closed now.
946946
self.shouldClose = true
947947
self.resps.append(builder)
948+
case "/content-length-without-body":
949+
var headers = self.responseHeaders
950+
headers.replaceOrAdd(name: "content-length", value: "1234")
951+
context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil)
952+
return
948953
default:
949954
context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil)
950955
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)

Tests/LinuxMain.swift

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ struct LinuxMain {
4444
testCase(HTTPClientNIOTSTests.allTests),
4545
testCase(HTTPClientReproTests.allTests),
4646
testCase(HTTPClientRequestTests.allTests),
47+
testCase(HTTPClientResponseTests.allTests),
4748
testCase(HTTPClientSOCKSTests.allTests),
4849
testCase(HTTPClientTests.allTests),
4950
testCase(HTTPClientUncleanSSLConnectionShutdownTests.allTests),

0 commit comments

Comments
 (0)