diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index ee0f2613e..faa31d06d 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -925,6 +925,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { case uncleanShutdown case traceRequestWithBody case invalidHeaderFieldNames([String]) + case bodyLengthMismatch } private var code: Code @@ -969,10 +970,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached) /// Redirect Cycle detected. public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected) - /// Unclean shutdown + /// Unclean shutdown. public static let uncleanShutdown = HTTPClientError(code: .uncleanShutdown) - /// A body was sent in a request with method TRACE + /// A body was sent in a request with method TRACE. public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody) - /// Header field names contain invalid characters + /// Header field names contain invalid characters. public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) } + /// Body length is not equal to `Content-Length`. + public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 733bfbb2b..3d5231085 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -641,7 +641,7 @@ internal class TaskHandler: RemovableChann case head case redirected(HTTPResponseHead, URL) case body - case end + case endOrError } let task: HTTPClient.Task @@ -651,6 +651,8 @@ internal class TaskHandler: RemovableChann let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request. var state: State = .idle + var expectedBodyLength: Int? + var actualBodyLength: Int = 0 var pendingRead = false var mayRead = true var closing = false { @@ -785,7 +787,7 @@ extension TaskHandler: ChannelDuplexHandler { } catch { promise?.fail(error) self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) - self.state = .end + self.state = .endOrError return } @@ -799,12 +801,23 @@ extension TaskHandler: ChannelDuplexHandler { assert(head.version == HTTPVersion(major: 1, minor: 1), "Sending a request in HTTP version \(head.version) which is unsupported by the above `if`") + let contentLengths = head.headers[canonicalForm: "content-length"] + assert(contentLengths.count <= 1) + + self.expectedBodyLength = contentLengths.first.flatMap { Int($0) } + context.write(wrapOutboundOut(.head(head))).map { self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead) }.flatMap { self.writeBody(request: request, context: context) }.flatMap { context.eventLoop.assertInEventLoop() + if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength { + self.state = .endOrError + let error = HTTPClientError.bodyLengthMismatch + self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) + return context.eventLoop.makeFailedFuture(error) + } return context.writeAndFlush(self.wrapOutboundOut(.end(nil))) }.map { context.eventLoop.assertInEventLoop() @@ -813,10 +826,10 @@ extension TaskHandler: ChannelDuplexHandler { }.flatMapErrorThrowing { error in context.eventLoop.assertInEventLoop() switch self.state { - case .end: + case .endOrError: break default: - self.state = .end + self.state = .endOrError self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } throw error @@ -833,9 +846,11 @@ extension TaskHandler: ChannelDuplexHandler { let promise = self.task.eventLoop.makePromise(of: Void.self) // All writes have to be switched to the channel EL if channel and task ELs differ if context.eventLoop.inEventLoop { + self.actualBodyLength += part.readableBytes context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) } else { context.eventLoop.execute { + self.actualBodyLength += part.readableBytes context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) } } @@ -898,12 +913,12 @@ extension TaskHandler: ChannelDuplexHandler { case .end: switch self.state { case .redirected(let head, let redirectURL): - self.state = .end + self.state = .endOrError self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess { self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise) } default: - self.state = .end + self.state = .endOrError self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest) } } @@ -918,14 +933,14 @@ extension TaskHandler: ChannelDuplexHandler { context.read() } case .failure(let error): - self.state = .end + self.state = .endOrError self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { if (event as? IdleStateHandler.IdleStateEvent) == .read { - self.state = .end + self.state = .endOrError let error = HTTPClientError.readTimeout self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } else { @@ -935,7 +950,7 @@ extension TaskHandler: ChannelDuplexHandler { func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise?) { if (event as? TaskCancelEvent) != nil { - self.state = .end + self.state = .endOrError let error = HTTPClientError.cancelled self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) promise?.succeed(()) @@ -946,10 +961,10 @@ extension TaskHandler: ChannelDuplexHandler { func channelInactive(context: ChannelHandlerContext) { switch self.state { - case .end: + case .endOrError: break case .body, .head, .idle, .redirected, .sent: - self.state = .end + self.state = .endOrError let error = HTTPClientError.remoteConnectionClosed self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } @@ -960,7 +975,7 @@ extension TaskHandler: ChannelDuplexHandler { switch error { case NIOSSLError.uncleanShutdown: switch self.state { - case .end: + case .endOrError: /// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection, /// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error. break @@ -969,11 +984,11 @@ extension TaskHandler: ChannelDuplexHandler { /// We can also ignore this error like `.end`. break default: - self.state = .end + self.state = .endOrError self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } default: - self.state = .end + self.state = .endOrError self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index f0b841fdc..f8b0bd59c 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -949,8 +949,8 @@ class HTTPClientInternalTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) - XCTAssertNoThrow(try elg.syncShutdownGracefully()) XCTAssertNoThrow(try httpBin.shutdown()) + XCTAssertNoThrow(try elg.syncShutdownGracefully()) } let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)//get") diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index a7eee57b5..3d0e279d9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -188,6 +188,7 @@ internal final class HTTPBin { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) let serverChannel: Channel let isShutdown: NIOAtomic = .makeAtomic(value: false) + var connections: NIOAtomic var connectionCount: NIOAtomic = .makeAtomic(value: 0) private let activeConnCounterHandler: CountActiveConnectionsHandler var activeConnections: Int { @@ -233,6 +234,9 @@ internal final class HTTPBin { let activeConnCounterHandler = CountActiveConnectionsHandler() self.activeConnCounterHandler = activeConnCounterHandler + let connections = NIOAtomic.makeAtomic(value: 0) + self.connections = connections + self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .serverChannelInitializer { channel in @@ -261,10 +265,10 @@ internal final class HTTPBin { }.flatMap { if ssl { return HTTPBin.configureTLS(channel: channel).flatMap { - channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge)) + channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1))) } } else { - return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge)) + return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1))) } } } @@ -357,9 +361,6 @@ internal struct HTTPResponseBuilder { } } -let globalRequestCounter = NIOAtomic.makeAtomic(value: 0) -let globalConnectionCounter = NIOAtomic.makeAtomic(value: 0) - internal struct RequestInfo: Codable { var data: String var requestNumber: Int @@ -378,13 +379,13 @@ internal final class HttpBinHandler: ChannelInboundHandler { let maxChannelAge: TimeAmount? var shouldClose = false var isServingRequest = false - let myConnectionNumber: Int - var currentRequestNumber: Int = -1 + let connectionId: Int + var requestId: Int = 0 - init(channelPromise: EventLoopPromise? = nil, maxChannelAge: TimeAmount? = nil) { + init(channelPromise: EventLoopPromise? = nil, maxChannelAge: TimeAmount? = nil, connectionId: Int) { self.channelPromise = channelPromise self.maxChannelAge = maxChannelAge - self.myConnectionNumber = globalConnectionCounter.add(1) + self.connectionId = connectionId } func handlerAdded(context: ChannelHandlerContext) { @@ -424,7 +425,7 @@ internal final class HttpBinHandler: ChannelInboundHandler { switch self.unwrapInboundIn(data) { case .head(let req): self.responseHeaders = HTTPHeaders() - self.currentRequestNumber = globalRequestCounter.add(1) + self.requestId += 1 self.parseAndSetOptions(from: req) let urlComponents = URLComponents(string: req.uri)! switch urlComponents.percentEncodedPath { @@ -552,8 +553,15 @@ internal final class HttpBinHandler: ChannelInboundHandler { context.write(wrapOutboundOut(.head(response.head)), promise: nil) if let body = response.body { let requestInfo = RequestInfo(data: String(buffer: body), - requestNumber: self.currentRequestNumber, - connectionNumber: self.myConnectionNumber) + requestNumber: self.requestId, + connectionNumber: self.connectionId) + let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, + allocator: context.channel.allocator) + context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) + } else { + let requestInfo = RequestInfo(data: "", + requestNumber: self.requestId, + connectionNumber: self.connectionId) let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, allocator: context.channel.allocator) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 3e95a5879..5d8ba873f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -110,6 +110,8 @@ extension HTTPClientTests { ("testAllMethodsLog", testAllMethodsLog), ("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground), ("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL), + ("testContentLengthTooLongFails", testContentLengthTooLongFails), + ("testContentLengthTooShortFails", testContentLengthTooShortFails), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 7cbf81a53..faf3099e0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1713,7 +1713,8 @@ class HTTPClientTests: XCTestCase { // req 1 and 2 cannot share the same connection (close header) XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) - XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + XCTAssertEqual(stats1.requestNumber, 1) + XCTAssertEqual(stats2.requestNumber, 1) // req 2 and 3 should share the same connection (keep-alive is default) XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) @@ -1742,7 +1743,8 @@ class HTTPClientTests: XCTestCase { // req 1 and 2 cannot share the same connection (close header) XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) - XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + XCTAssertEqual(stats1.requestNumber, 1) + XCTAssertEqual(stats2.requestNumber, 1) // req 2 and 3 should share the same connection (keep-alive is default) XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) @@ -1773,7 +1775,7 @@ class HTTPClientTests: XCTestCase { // req 1 and 2 cannot share the same connection (close header) XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) - XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + XCTAssertEqual(stats2.requestNumber, 1) // req 2 and 3 should share the same connection (keep-alive is default) XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) @@ -1805,7 +1807,7 @@ class HTTPClientTests: XCTestCase { // req 1 and 2 cannot share the same connection (close header) XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber) - XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber) + XCTAssertEqual(stats2.requestNumber, 1) // req 2 and 3 should share the same connection (keep-alive is default) XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber) @@ -2036,6 +2038,7 @@ class HTTPClientTests: XCTestCase { defer { XCTAssertNoThrow(try httpClient.syncShutdown()) XCTAssertNoThrow(try httpServer.stop()) + XCTAssertNoThrow(try elg.syncShutdownGracefully()) } let delegate = TestDelegate(eventLoop: second) @@ -2051,4 +2054,59 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try future.wait()) } + + func testContentLengthTooLongFails() throws { + let url = self.defaultHTTPBinURLPrefix + "/post" + XCTAssertThrowsError( + try self.defaultClient.execute(request: + Request(url: url, + body: .stream(length: 10) { streamWriter in + let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + DispatchQueue(label: "content-length-test").async { + streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) + } + return promise.futureResult + })).wait()) { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) + } + // Quickly try another request and check that it works. + let response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait() + guard var body = response.body else { + XCTFail("Body missing: \(response)") + return + } + guard let info = try body.readJSONDecodable(RequestInfo.self, length: body.readableBytes) else { + XCTFail("Cannot parse body: \(body.readableBytesView.map { $0 })") + return + } + XCTAssertEqual(info.connectionNumber, 1) + XCTAssertEqual(info.requestNumber, 1) + } + + // currently gets stuck because of #250 the server just never replies + func testContentLengthTooShortFails() throws { + let url = self.defaultHTTPBinURLPrefix + "/post" + let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" + XCTAssertThrowsError( + try self.defaultClient.execute(request: + Request(url: url, + body: .stream(length: 1) { streamWriter in + streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) + })).wait()) { error in + XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) + } + // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the + // stream (and reuse the connection) that could cause problems. + let response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait() + guard var body = response.body else { + XCTFail("Body missing: \(response)") + return + } + guard let info = try body.readJSONDecodable(RequestInfo.self, length: body.readableBytes) else { + XCTFail("Cannot parse body: \(body.readableBytesView.map { $0 })") + return + } + XCTAssertEqual(info.connectionNumber, 1) + XCTAssertEqual(info.requestNumber, 1) + } } diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift index b4d88344c..51e6ff016 100644 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift @@ -33,6 +33,7 @@ extension RequestValidationTests { ("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody), ("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames), ("testValidHeaderFieldNames", testValidHeaderFieldNames), + ("testMultipleContentLengthOnNilStreamLength", testMultipleContentLengthOnNilStreamLength), ] } } diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift index 5f4db05b7..fe3cb9330 100644 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift @@ -104,4 +104,14 @@ class RequestValidationTests: XCTestCase { XCTAssertNoThrow(try headers.validate(method: .GET, body: nil)) } + + func testMultipleContentLengthOnNilStreamLength() { + var headers = HTTPHeaders([("Content-Length", "1"), ("Content-Length", "2")]) + var buffer = ByteBufferAllocator().buffer(capacity: 10) + buffer.writeBytes([UInt8](repeating: 12, count: 10)) + let body: HTTPClient.Body = .stream { writer in + writer.write(.byteBuffer(buffer)) + } + XCTAssertThrowsError(try headers.validate(method: .PUT, body: body)) + } }