diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 89cbcfcf9..d17612c38 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -224,6 +224,36 @@ final class ConnectionPool { fileprivate var closePromise: EventLoopPromise var closeFuture: EventLoopFuture + + func removeIdleConnectionHandlersForLease() -> EventLoopFuture { + return self.channel.eventLoop.flatSubmit { + self.removeHandler(IdleStateHandler.self).flatMap { () -> EventLoopFuture in + self.channel.pipeline.handler(type: IdlePoolConnectionHandler.self).flatMap { idleHandler in + self.channel.pipeline.removeHandler(idleHandler).flatMapError { _ in + self.channel.eventLoop.makeSucceededFuture(()) + }.map { + idleHandler.hasNotSentClose && self.channel.isActive + } + }.flatMapError { error in + // These handlers are only added on connection release, they are not added + // when a connection is made to be instantly leased, so we ignore this error + if let channelError = error as? ChannelPipelineError, channelError == .notFound { + return self.channel.eventLoop.makeSucceededFuture(self.channel.isActive) + } else { + return self.channel.eventLoop.makeFailedFuture(error) + } + } + }.flatMap { channelIsUsable in + if channelIsUsable { + return self.channel.eventLoop.makeSucceededFuture(self) + } else { + return self.channel.eventLoop.makeFailedFuture(InactiveChannelError()) + } + } + } + } + + struct InactiveChannelError: Error {} } /// A connection provider of `HTTP/1.1` connections with a given `Key` (host, scheme, port) @@ -294,7 +324,14 @@ final class ConnectionPool { let action = self.stateLock.withLock { self.state.connectionAction(for: preference) } switch action { case .leaseConnection(let connection): - return connection.channel.eventLoop.makeSucceededFuture(connection) + return connection.removeIdleConnectionHandlersForLease().flatMapError { _ in + connection.closeFuture.flatMap { // We ensure close actions are run first + let defaultEventLoop = self.stateLock.withLock { + self.state.defaultEventLoop + } + return self.makeConnection(on: preference.bestEventLoop ?? defaultEventLoop) + } + } case .makeConnection(let eventLoop): return self.makeConnection(on: eventLoop) case .leaseFutureConnection(let futureConnection): @@ -453,7 +490,7 @@ final class ConnectionPool { fileprivate struct State { /// The default `EventLoop` to use for this `HTTP1ConnectionProvider` - private let defaultEventLoop: EventLoop + let defaultEventLoop: EventLoop /// The maximum number of connections to a certain (host, scheme, port) tuple. private let maximumConcurrentConnections: Int = 8 @@ -476,7 +513,11 @@ final class ConnectionPool { fileprivate var activity: Activity = .opened - fileprivate var pending: Int = 0 + fileprivate var pending: Int = 0 { + didSet { + assert(self.pending >= 0) + } + } private let parentPool: ConnectionPool diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 99ec37f2e..da6ca3952 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -305,7 +305,7 @@ public class HTTPClient { redirectHandler = nil } - let task = Task(eventLoop: taskEL) + let task = Task(eventLoop: taskEL, poolingTimeout: self.configuration.maximumAllowedIdleTimeInConnectionPool) self.stateLock.withLock { self.tasks[task.id] = task } @@ -322,7 +322,6 @@ public class HTTPClient { connection.flatMap { connection -> EventLoopFuture in let channel = connection.channel let addedFuture: EventLoopFuture - switch self.configuration.decompression { case .disabled: addedFuture = channel.eventLoop.makeSucceededFuture(()) @@ -408,6 +407,8 @@ public class HTTPClient { public var redirectConfiguration: RedirectConfiguration /// Default client timeout, defaults to no timeouts. public var timeout: Timeout + /// Timeout of pooled connections + public var maximumAllowedIdleTimeInConnectionPool: TimeAmount? /// Upstream proxy, defaults to no proxy. public var proxy: Proxy? /// Enables automatic body decompression. Supported algorithms are gzip and deflate. @@ -418,30 +419,68 @@ public class HTTPClient { public init(tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, timeout: Timeout = Timeout(), + maximumAllowedIdleTimeInConnectionPool: TimeAmount, proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false, decompression: Decompression = .disabled) { self.tlsConfiguration = tlsConfiguration self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration() self.timeout = timeout + self.maximumAllowedIdleTimeInConnectionPool = maximumAllowedIdleTimeInConnectionPool self.proxy = proxy self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown self.decompression = decompression } + public init(tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + maximumAllowedIdleTimeInConnectionPool: .seconds(60), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + } + public init(certificateVerification: CertificateVerification, redirectConfiguration: RedirectConfiguration? = nil, timeout: Timeout = Timeout(), + maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), proxy: Proxy? = nil, ignoreUncleanSSLShutdown: Bool = false, decompression: Decompression = .disabled) { self.tlsConfiguration = TLSConfiguration.forClient(certificateVerification: certificateVerification) self.redirectConfiguration = redirectConfiguration ?? RedirectConfiguration() self.timeout = timeout + self.maximumAllowedIdleTimeInConnectionPool = maximumAllowedIdleTimeInConnectionPool self.proxy = proxy self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown self.decompression = decompression } + + public init(certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled) { + self.init( + certificateVerification: certificateVerification, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + maximumAllowedIdleTimeInConnectionPool: .seconds(60), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + } } /// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership. @@ -490,6 +529,19 @@ public class HTTPClient { public static func delegateAndChannel(on eventLoop: EventLoop) -> EventLoopPreference { return EventLoopPreference(.delegateAndChannel(on: eventLoop)) } + + var bestEventLoop: EventLoop? { + switch self.preference { + case .delegate(on: let el): + return el + case .delegateAndChannel(on: let el): + return el + case .testOnly_exact(channelOn: let el, delegateOn: _): + return el + case .indifferent: + return nil + } + } } /// Specifies decompression settings. diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 6320266c9..04c8cb7af 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -497,13 +497,15 @@ extension HTTPClient { var cancelled: Bool let lock: Lock let id = UUID() + let poolingTimeout: TimeAmount? - init(eventLoop: EventLoop) { + init(eventLoop: EventLoop, poolingTimeout: TimeAmount? = nil) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.completion = self.promise.futureResult.map { _ in } self.cancelled = false self.lock = Lock() + self.poolingTimeout = poolingTimeout } static func failedTask(eventLoop: EventLoop, error: Error) -> Task { @@ -571,6 +573,19 @@ extension HTTPClient { connection.removeHandler(IdleStateHandler.self) }.flatMap { connection.removeHandler(TaskHandler.self) + }.flatMap { + let idlePoolConnectionHandler = IdlePoolConnectionHandler() + return connection.channel.pipeline.addHandler(idlePoolConnectionHandler, position: .last).flatMap { + connection.channel.pipeline.addHandler(IdleStateHandler(writeTimeout: self.poolingTimeout), position: .before(idlePoolConnectionHandler)) + } + }.flatMapError { error in + if let error = error as? ChannelError, error == .ioOnClosedChannel { + // We may get this error if channel is released because it is + // closed, it is safe to ignore it + return connection.channel.eventLoop.makeSucceededFuture(()) + } else { + return connection.channel.eventLoop.makeFailedFuture(error) + } }.map { connection.release() }.flatMapError { error in @@ -1008,3 +1023,21 @@ internal struct RedirectHandler { } } } + +class IdlePoolConnectionHandler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = NIOAny + + let _hasNotSentClose: NIOAtomic = .makeAtomic(value: true) + var hasNotSentClose: Bool { + return self._hasNotSentClose.load() + } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if let idleEvent = event as? IdleStateHandler.IdleStateEvent, idleEvent == .write { + self._hasNotSentClose.store(false) + context.close(promise: nil) + } else { + context.fireUserInboundEventTriggered(event) + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index a66cfcbf2..870ddfc37 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -36,6 +36,7 @@ extension HTTPClientInternalTests { ("testResponseConnectionCloseGet", testResponseConnectionCloseGet), ("testWeNoticeRemoteClosuresEvenWhenConnectionIsIdleInPool", testWeNoticeRemoteClosuresEvenWhenConnectionIsIdleInPool), ("testWeTolerateConnectionsGoingAwayWhilstPoolIsShuttingDown", testWeTolerateConnectionsGoingAwayWhilstPoolIsShuttingDown), + ("testRaceBetweenAsynchronousCloseAndChannelUsabilityDetection", testRaceBetweenAsynchronousCloseAndChannelUsabilityDetection), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index a016b2b31..9288814c6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -623,4 +623,106 @@ class HTTPClientInternalTests: XCTestCase { } XCTAssertNoThrow(try client.syncShutdown()) } + + func testRaceBetweenAsynchronousCloseAndChannelUsabilityDetection() { + final class DelayChannelCloseUntilToldHandler: ChannelOutboundHandler { + typealias OutboundIn = Any + + enum State { + case idling + case delayedClose + case closeDone + } + + var state: State = .idling + let doTheCloseNowFuture: EventLoopFuture + let sawTheClosePromise: EventLoopPromise + + init(doTheCloseNowFuture: EventLoopFuture, + sawTheClosePromise: EventLoopPromise) { + self.doTheCloseNowFuture = doTheCloseNowFuture + self.sawTheClosePromise = sawTheClosePromise + } + + func handlerRemoved(context: ChannelHandlerContext) { + XCTAssertEqual(.closeDone, self.state) + } + + func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + XCTAssertEqual(.idling, self.state) + self.state = .delayedClose + self.sawTheClosePromise.succeed(()) + // let's hold the close until the future's complete + self.doTheCloseNowFuture.whenSuccess { + context.close(mode: mode).map { + XCTAssertEqual(.delayedClose, self.state) + self.state = .closeDone + }.cascade(to: promise) + } + } + } + + let web = HTTPBin() + defer { + XCTAssertNoThrow(try web.shutdown()) + } + + let client = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + XCTAssertNoThrow(try client.syncShutdown()) + } + + let req = try! HTTPClient.Request(url: "http://localhost:\(web.serverChannel.localAddress!.port!)/get", + method: .GET, + body: nil) + + // Let's start by getting a connection so we can mess with the Channel :). + var maybeConnection: ConnectionPool.Connection? + XCTAssertNoThrow(try maybeConnection = client.pool.getConnection(for: req, + preference: .indifferent, + on: client.eventLoopGroup.next(), + deadline: nil).wait()) + guard let connection = maybeConnection else { + XCTFail("couldn't make connection") + return + } + + let channel = connection.channel + let doActualCloseNowPromise = channel.eventLoop.makePromise(of: Void.self) + let sawTheClosePromise = channel.eventLoop.makePromise(of: Void.self) + + XCTAssertNoThrow(try channel.pipeline.addHandler(DelayChannelCloseUntilToldHandler(doTheCloseNowFuture: doActualCloseNowPromise.futureResult, + sawTheClosePromise: sawTheClosePromise), + position: .first).wait()) + client.pool.release(connection) + + XCTAssertNoThrow(try client.execute(request: req).wait()) + + // Now, let's pretend the timeout happened + channel.pipeline.fireUserInboundEventTriggered(IdleStateHandler.IdleStateEvent.write) + + // The Channel's closure should have already been initialised now but still, let's make sure the close + // was initiated + XCTAssertNoThrow(try sawTheClosePromise.futureResult.wait()) + // The Channel should still be active though because we delayed the close through our handler above. + XCTAssertTrue(channel.isActive) + + // When asking for a connection again, we should _not_ get the same one back because we did most of the close, + // similar to what the SSLHandler would do. + let connection2Future = client.pool.getConnection(for: req, + preference: .indifferent, + on: client.eventLoopGroup.next(), + deadline: nil) + doActualCloseNowPromise.succeed(()) + + XCTAssertNoThrow(try maybeConnection = connection2Future.wait()) + guard let connection2 = maybeConnection else { + XCTFail("couldn't get second connection") + return + } + + XCTAssert(connection !== connection2) + client.pool.release(connection2) + XCTAssertTrue(connection2.channel.isActive) + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index e0c843ca4..62e5fe96e 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -167,6 +167,11 @@ internal final class HTTPBin { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) let serverChannel: Channel let isShutdown: NIOAtomic = .makeAtomic(value: false) + var connectionCount: NIOAtomic = .makeAtomic(value: 0) + private let activeConnCounterHandler: CountActiveConnectionsHandler + var activeConnections: Int { + return self.activeConnCounterHandler.currentlyActiveConnections + } enum BindTarget { case unixDomainSocket(String) @@ -204,10 +209,15 @@ internal final class HTTPBin { socketAddress = try! SocketAddress(unixDomainSocketPath: path) } + let activeConnCounterHandler = CountActiveConnectionsHandler() + self.activeConnCounterHandler = activeConnCounterHandler + self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .childChannelInitializer { channel in + .serverChannelInitializer { channel in + channel.pipeline.addHandler(activeConnCounterHandler) + }.childChannelInitializer { channel in guard !refusesConnections else { return channel.eventLoop.makeFailedFuture(HTTPBinError.refusedConnection) } @@ -537,6 +547,27 @@ internal final class HttpBinHandler: ChannelInboundHandler { } } +final class CountActiveConnectionsHandler: ChannelInboundHandler { + typealias InboundIn = Channel + + private let activeConns = NIOAtomic.makeAtomic(value: 0) + + public var currentlyActiveConnections: Int { + return self.activeConns.load() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let channel = self.unwrapInboundIn(data) + + _ = self.activeConns.add(1) + channel.closeFuture.whenComplete { _ in + _ = self.activeConns.sub(1) + } + + context.fireChannelRead(data) + } +} + internal class HttpBinForSSLUncleanShutdown { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) let serverChannel: Channel diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index f03604e80..07cca50d7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -93,6 +93,8 @@ extension HTTPClientTests { ("testUDSSocketAndPath", testUDSSocketAndPath), ("testUseExistingConnectionOnDifferentEL", testUseExistingConnectionOnDifferentEL), ("testWeRecoverFromServerThatClosesTheConnectionOnUs", testWeRecoverFromServerThatClosesTheConnectionOnUs), + ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), + ("testRacePoolIdleConnectionsAndGet", testRacePoolIdleConnectionsAndGet), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 2bf3f8909..049ce78d4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1631,4 +1631,29 @@ class HTTPClientTests: XCTestCase { XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load()) XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load()) } + + func testPoolClosesIdleConnections() { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(maximumAllowedIdleTimeInConnectionPool: .milliseconds(100))) + defer { + XCTAssertNoThrow(try httpBin.shutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + XCTAssertNoThrow(try httpClient.get(url: "http://localhost:\(httpBin.port)/get").wait()) + Thread.sleep(forTimeInterval: 0.2) + XCTAssertEqual(httpBin.activeConnections, 0) + } + + func testRacePoolIdleConnectionsAndGet() { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew, configuration: .init(maximumAllowedIdleTimeInConnectionPool: .milliseconds(10))) + defer { + XCTAssertNoThrow(try httpBin.shutdown()) + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + } + for _ in 1...500 { + XCTAssertNoThrow(try httpClient.get(url: "http://localhost:\(httpBin.port)/get").wait()) + Thread.sleep(forTimeInterval: 0.01 + .random(in: -0.05...0.05)) + } + } }