diff --git a/Package.swift b/Package.swift index f2e606a93..b579b608a 100644 --- a/Package.swift +++ b/Package.swift @@ -21,9 +21,9 @@ let package = Package( .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.27.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.29.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.13.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.3.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.9.1"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.5.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), ], @@ -31,12 +31,12 @@ let package = Package( .target( name: "AsyncHTTPClient", dependencies: ["NIO", "NIOHTTP1", "NIOSSL", "NIOConcurrencyHelpers", "NIOHTTPCompression", - "NIOFoundationCompat", "NIOTransportServices", "Logging"] + "NIOFoundationCompat", "NIOTransportServices", "Logging", "NIOSOCKS"] ), .testTarget( name: "AsyncHTTPClientTests", dependencies: ["NIO", "NIOConcurrencyHelpers", "NIOSSL", "AsyncHTTPClient", "NIOFoundationCompat", - "NIOTestUtils", "Logging"] + "NIOTestUtils", "Logging", "NIOSOCKS"] ), ] ) diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index ec549d993..51a1c8c4d 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -18,6 +18,7 @@ import NIO import NIOConcurrencyHelpers import NIOHTTP1 import NIOHTTPCompression +import NIOSOCKS import NIOSSL import NIOTLS import NIOTransportServices @@ -883,7 +884,7 @@ extension HTTPClient.Configuration { } extension ChannelPipeline { - func syncAddProxyHandler(host: String, port: Int, authorization: HTTPClient.Authorization?) throws { + func syncAddHTTPProxyHandler(host: String, port: Int, authorization: HTTPClient.Authorization?) throws { let encoder = HTTPRequestEncoder() let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .forwardBytes)) let handler = HTTPClientProxyHandler(host: host, port: port, authorization: authorization) { channel in @@ -900,6 +901,12 @@ extension ChannelPipeline { try sync.addHandler(handler) } + func syncAddSOCKSProxyHandler(host: String, port: Int) throws { + let handler = SOCKSClientHandler(targetAddress: .domain(host, port: port)) + let sync = self.syncOperations + try sync.addHandler(handler) + } + func syncAddLateSSLHandlerIfNeeded(for key: ConnectionPool.Key, sslContext: NIOSSLContext, handshakePromise: EventLoopPromise) { diff --git a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift index ebdfbfa24..e2b891c8b 100644 --- a/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPClientProxyHandler.swift @@ -14,8 +14,9 @@ import NIO import NIOHTTP1 +import NIOSOCKS -public extension HTTPClient.Configuration { +extension HTTPClient.Configuration { /// Proxy server configuration /// Specifies the remote address of an HTTP proxy. /// @@ -26,31 +27,60 @@ public extension HTTPClient.Configuration { /// If a `TLSConfiguration` is used in conjunction with `HTTPClient.Configuration.Proxy`, /// TLS will be established _after_ successful proxy, between your client /// and the destination server. - struct Proxy { + public struct Proxy { + enum ProxyType: Hashable { + case http(HTTPClient.Authorization?) + case socks + } + /// Specifies Proxy server host. public var host: String /// Specifies Proxy server port. public var port: Int /// Specifies Proxy server authorization. - public var authorization: HTTPClient.Authorization? + public var authorization: HTTPClient.Authorization? { + set { + precondition(self.type == .http(self.authorization), "SOCKS authorization support is not yet implemented.") + self.type = .http(newValue) + } + + get { + switch self.type { + case .http(let authorization): + return authorization + case .socks: + return nil + } + } + } + + var type: ProxyType - /// Create proxy. + /// Create a HTTP proxy. /// /// - parameters: /// - host: proxy server host. /// - port: proxy server port. public static func server(host: String, port: Int) -> Proxy { - return .init(host: host, port: port, authorization: nil) + return .init(host: host, port: port, type: .http(nil)) } - /// Create proxy. + /// Create a HTTP proxy. /// /// - parameters: /// - host: proxy server host. /// - port: proxy server port. /// - authorization: proxy server authorization. public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy { - return .init(host: host, port: port, authorization: authorization) + return .init(host: host, port: port, type: .http(authorization)) + } + + /// Create a SOCKSv5 proxy. + /// - parameter host: The SOCKSv5 proxy address. + /// - parameter port: The SOCKSv5 proxy port, defaults to 1080. + /// - returns: A new instance of `Proxy` configured to connect to a `SOCKSv5` server. + public static func socksServer(host: String, port: Int = 1080) -> Proxy { + return .init(host: host, port: port, type: .socks) } } } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 4850c51d8..a9c1a9e22 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -342,8 +342,8 @@ extension HTTPClient { } /// HTTP authentication - public struct Authorization { - private enum Scheme { + public struct Authorization: Hashable { + private enum Scheme: Hashable { case Basic(String) case Bearer(String) } diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index 6069222b1..174bc593e 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -148,9 +148,14 @@ extension NIOClientTCPBootstrap { return bootstrap.channelInitializer { channel in do { if let proxy = configuration.proxy { - try channel.pipeline.syncAddProxyHandler(host: host, - port: port, - authorization: proxy.authorization) + switch proxy.type { + case .http: + try channel.pipeline.syncAddHTTPProxyHandler(host: host, + port: port, + authorization: proxy.authorization) + case .socks: + try channel.pipeline.syncAddSOCKSProxyHandler(host: host, port: port) + } } else if requiresTLS { // We only add the handshake verifier if we need TLS and we're not going through a proxy. // If we're going through a proxy we add it later (outside of this method). diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests+XCTest.swift new file mode 100644 index 000000000..40ef6e0ff --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests+XCTest.swift @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// HTTPClient+SOCKSTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension HTTPClientSOCKSTests { + static var allTests: [(String, (HTTPClientSOCKSTests) -> () throws -> Void)] { + return [ + ("testProxySOCKS", testProxySOCKS), + ("testProxySOCKSBogusAddress", testProxySOCKSBogusAddress), + ("testProxySOCKSFailureNoServer", testProxySOCKSFailureNoServer), + ("testProxySOCKSFailureInvalidServer", testProxySOCKSFailureInvalidServer), + ("testProxySOCKSMisbehavingServer", testProxySOCKSMisbehavingServer), + ] + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift new file mode 100644 index 000000000..3479d86b9 --- /dev/null +++ b/Tests/AsyncHTTPClientTests/HTTPClient+SOCKSTests.swift @@ -0,0 +1,137 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift +import Logging +import NIO +import NIOSOCKS +import XCTest + +class HTTPClientSOCKSTests: XCTestCase { + typealias Request = HTTPClient.Request + + var clientGroup: EventLoopGroup! + var serverGroup: EventLoopGroup! + var defaultHTTPBin: HTTPBin! + var defaultClient: HTTPClient! + var backgroundLogStore: CollectEverythingLogHandler.LogStore! + + var defaultHTTPBinURLPrefix: String { + return "http://localhost:\(self.defaultHTTPBin.port)/" + } + + override func setUp() { + XCTAssertNil(self.clientGroup) + XCTAssertNil(self.serverGroup) + XCTAssertNil(self.defaultHTTPBin) + XCTAssertNil(self.defaultClient) + XCTAssertNil(self.backgroundLogStore) + + self.clientGroup = getDefaultEventLoopGroup(numberOfThreads: 1) + self.serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.defaultHTTPBin = HTTPBin() + self.backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger(label: "\(#function)", factory: { _ in + CollectEverythingLogHandler(logStore: self.backgroundLogStore!) + }) + backgroundLogger.logLevel = .trace + self.defaultClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger) + } + + override func tearDown() { + if let defaultClient = self.defaultClient { + XCTAssertNoThrow(try defaultClient.syncShutdown()) + self.defaultClient = nil + } + + XCTAssertNotNil(self.defaultHTTPBin) + XCTAssertNoThrow(try self.defaultHTTPBin.shutdown()) + self.defaultHTTPBin = nil + + XCTAssertNotNil(self.clientGroup) + XCTAssertNoThrow(try self.clientGroup.syncShutdownGracefully()) + self.clientGroup = nil + + XCTAssertNotNil(self.serverGroup) + XCTAssertNoThrow(try self.serverGroup.syncShutdownGracefully()) + self.serverGroup = nil + + XCTAssertNotNil(self.backgroundLogStore) + self.backgroundLogStore = nil + } + + func testProxySOCKS() throws { + let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!") + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(proxy: .socksServer(host: "localhost"))) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try socksBin.shutdown()) + } + + var response: HTTPClient.Response? + XCTAssertNoThrow(response = try localClient.get(url: "http://localhost/socks/test").wait()) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(ByteBuffer(string: "it works!"), response?.body) + } + + func testProxySOCKSBogusAddress() throws { + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(proxy: .socksServer(host: "127.0.."))) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) + } + + // there is no socks server, so we should fail + func testProxySOCKSFailureNoServer() throws { + let localHTTPBin = HTTPBin() + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(proxy: .socksServer(host: "localhost", port: localHTTPBin.port))) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) + } + + // speak to a server that doesn't speak SOCKS + func testProxySOCKSFailureInvalidServer() throws { + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(proxy: .socksServer(host: "localhost"))) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + } + XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) + } + + // test a handshake failure with a misbehaving server + func testProxySOCKSMisbehavingServer() throws { + let socksBin = try MockSOCKSServer(expectedURL: "/socks/test", expectedResponse: "it works!", misbehave: true) + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(proxy: .socksServer(host: "localhost"))) + + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try socksBin.shutdown()) + } + + // the server will send a bogus message in response to the clients request + XCTAssertThrowsError(try localClient.get(url: "http://localhost/socks/test").wait()) + } +} diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift new file mode 100644 index 000000000..38fa706df --- /dev/null +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -0,0 +1,130 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import AsyncHTTPClient +import NIO +import NIOHTTP1 +import NIOSOCKS +import XCTest + +struct MockSOCKSError: Error, Hashable { + var description: String +} + +class MockSOCKSServer { + let channel: Channel + + init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #file, line: UInt = #line) throws { + let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) + let bootstrap = ServerBootstrap(group: elg) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + let handshakeHandler = SOCKSServerHandshakeHandler() + return channel.pipeline.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler, misbehave: misbehave), + TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), + ]) + } + self.channel = try bootstrap.bind(host: "localhost", port: 1080).wait() + } + + func shutdown() throws { + try self.channel.close().wait() + } +} + +class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = ClientMessage + + let handshakeHandler: SOCKSServerHandshakeHandler + let misbehave: Bool + + init(handshakeHandler: SOCKSServerHandshakeHandler, misbehave: Bool) { + self.handshakeHandler = handshakeHandler + self.misbehave = misbehave + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + guard context.channel.isActive else { + return + } + + let message = self.unwrapInboundIn(data) + switch message { + case .greeting: + context.writeAndFlush(.init( + ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))), promise: nil) + case .authenticationData: + context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any.")) + case .request(let request): + guard !self.misbehave else { + context.writeAndFlush( + .init(ServerMessage.authenticationData(context.channel.allocator.buffer(string: "bad server!"), complete: true)), promise: nil + ) + return + } + context.writeAndFlush(.init( + ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil) + context.channel.pipeline.addHandlers([ + ByteToMessageHandler(HTTPRequestDecoder()), + HTTPResponseEncoder(), + ], position: .after(self)).whenSuccess { + context.channel.pipeline.removeHandler(self, promise: nil) + context.channel.pipeline.removeHandler(self.handshakeHandler, promise: nil) + } + } + } +} + +class TestHTTPServer: ChannelInboundHandler { + typealias InboundIn = HTTPServerRequestPart + typealias OutboundOut = HTTPServerResponsePart + + let expectedURL: String + let expectedResponse: String + let file: String + let line: UInt + var requestCount = 0 + + init(expectedURL: String, expectedResponse: String, file: String, line: UInt) { + self.expectedURL = expectedURL + self.expectedResponse = expectedResponse + self.file = file + self.line = line + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let message = self.unwrapInboundIn(data) + switch message { + case .head(let head): + guard self.requestCount == 0 else { + return + } + XCTAssertEqual(head.uri, self.expectedURL) + self.requestCount += 1 + case .body: + break + case .end: + context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) + context.write(self.wrapOutboundOut(.body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse)))), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + } + } + + func errorCaught(context: ChannelHandlerContext, error: Error) { + context.fireErrorCaught(error) + context.close(promise: nil) + } +} diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 0db0dd9ce..094b0ee2a 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -31,6 +31,7 @@ import XCTest testCase(HTTPClientCookieTests.allTests), testCase(HTTPClientInternalTests.allTests), testCase(HTTPClientNIOTSTests.allTests), + testCase(HTTPClientSOCKSTests.allTests), testCase(HTTPClientTests.allTests), testCase(LRUCacheTests.allTests), testCase(RequestValidationTests.allTests),