diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 75f16f52a..6199f33ff 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -30,7 +30,7 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public typealias Response = Progress private let filePath: String - private let io: NonBlockingFileIO + private(set) var fileIOThreadPool: NIOThreadPool? private let reportHead: ((HTTPResponseHead) -> Void)? private let reportProgress: ((Progress) -> Void)? @@ -47,14 +47,46 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. - public init( + public convenience init( path: String, - pool: NIOThreadPool = NIOThreadPool(numberOfThreads: 1), + pool: NIOThreadPool, reportHead: ((HTTPResponseHead) -> Void)? = nil, reportProgress: ((Progress) -> Void)? = nil ) throws { - pool.start() - self.io = NonBlockingFileIO(threadPool: pool) + try self.init(path: path, pool: .some(pool), reportHead: reportHead, reportProgress: reportProgress) + } + + /// Initializes a new file download delegate and uses the shared thread pool of the ``HTTPClient`` for file I/O. + /// + /// - parameters: + /// - path: Path to a file you'd like to write the download to. + /// - reportHead: A closure called when the response head is available. + /// - reportProgress: A closure called when a body chunk has been downloaded, with + /// the total byte count and download byte count passed to it as arguments. The callbacks + /// will be invoked in the same threading context that the delegate itself is invoked, + /// as controlled by `EventLoopPreference`. + public convenience init( + path: String, + reportHead: ((HTTPResponseHead) -> Void)? = nil, + reportProgress: ((Progress) -> Void)? = nil + ) throws { + try self.init(path: path, pool: nil, reportHead: reportHead, reportProgress: reportProgress) + } + + private init( + path: String, + pool: NIOThreadPool?, + reportHead: ((HTTPResponseHead) -> Void)? = nil, + reportProgress: ((Progress) -> Void)? = nil + ) throws { + if let pool = pool { + self.fileIOThreadPool = pool + } else { + // we should use the shared thread pool from the HTTPClient which + // we will get from the `HTTPClient.Task` + self.fileIOThreadPool = nil + } + self.filePath = path self.reportHead = reportHead @@ -79,16 +111,25 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { + let threadPool: NIOThreadPool = { + guard let pool = self.fileIOThreadPool else { + let pool = task.fileIOThreadPool + self.fileIOThreadPool = pool + return pool + } + return pool + }() + let io = NonBlockingFileIO(threadPool: threadPool) self.progress.receivedBytes += buffer.readableBytes self.reportProgress?(self.progress) let writeFuture: EventLoopFuture if let fileHandleFuture = self.fileHandleFuture { writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) } } else { - let fileHandleFuture = self.io.openFile( + let fileHandleFuture = io.openFile( path: self.filePath, mode: .write, flags: .allowFileCreation(), @@ -96,7 +137,7 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { ) self.fileHandleFuture = fileHandleFuture writeFuture = fileHandleFuture.flatMap { - self.io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) } } diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 9cf84a2cb..ab4f7815e 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -72,6 +72,11 @@ public class HTTPClient { let eventLoopGroupProvider: EventLoopGroupProvider let configuration: Configuration let poolManager: HTTPConnectionPool.Manager + + /// Shared thread pool used for file IO. It is lazily created on first access of ``Task/fileIOThreadPool``. + private var fileIOThreadPool: NIOThreadPool? + private let fileIOThreadPoolLock = Lock() + private var state: State private let stateLock = Lock() @@ -213,6 +218,16 @@ public class HTTPClient { } } + private func shutdownFileIOThreadPool(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { + self.fileIOThreadPoolLock.withLockVoid { + guard let fileIOThreadPool = fileIOThreadPool else { + callback(nil) + return + } + fileIOThreadPool.shutdownGracefully(queue: queue, callback) + } + } + private func shutdown(requiresCleanClose: Bool, queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { do { try self.stateLock.withLock { @@ -241,15 +256,28 @@ public class HTTPClient { let error: Error? = (requiresClean && unclean) ? HTTPClientError.uncleanShutdown : nil return (callback, error) } - - self.shutdownEventLoop(queue: queue) { error in - let reportedError = error ?? uncleanError - callback(reportedError) + self.shutdownFileIOThreadPool(queue: queue) { ioThreadPoolError in + self.shutdownEventLoop(queue: queue) { error in + let reportedError = error ?? ioThreadPoolError ?? uncleanError + callback(reportedError) + } } } } } + private func makeOrGetFileIOThreadPool() -> NIOThreadPool { + self.fileIOThreadPoolLock.withLock { + guard let fileIOThreadPool = fileIOThreadPool else { + let fileIOThreadPool = NIOThreadPool(numberOfThreads: ProcessInfo.processInfo.processorCount) + fileIOThreadPool.start() + self.fileIOThreadPool = fileIOThreadPool + return fileIOThreadPool + } + return fileIOThreadPool + } + } + /// Execute `GET` request using specified URL. /// /// - parameters: @@ -562,6 +590,7 @@ public class HTTPClient { case .testOnly_exact(_, delegateOn: let delegateEL): taskEL = delegateEL } + logger.trace("selected EventLoop for task given the preference", metadata: ["ahc-eventloop": "\(taskEL)", "ahc-el-preference": "\(eventLoopPreference)"]) @@ -574,7 +603,8 @@ public class HTTPClient { logger.debug("client is shutting down, failing request") return Task.failedTask(eventLoop: taskEL, error: HTTPClientError.alreadyShutdown, - logger: logger) + logger: logger, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool) } } @@ -597,7 +627,7 @@ public class HTTPClient { } }() - let task = Task(eventLoop: taskEL, logger: logger) + let task = Task(eventLoop: taskEL, logger: logger, makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool) do { let requestBag = try RequestBag( request: request, diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index aeef71ba4..c62c2f7d1 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -17,6 +17,7 @@ import Logging import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 +import NIOPosix import NIOSSL extension HTTPClient { @@ -502,7 +503,7 @@ public protocol HTTPClientResponseDelegate: AnyObject { } extension HTTPClientResponseDelegate { - /// Default implementation of ``HTTPClientResponseDelegate/didSendRequestHead(task:_:)-6khai``. + /// Default implementation of ``HTTPClientResponseDelegate/didSendRequest(task:)-9od5p``. /// /// By default, this does nothing. public func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) {} @@ -622,15 +623,27 @@ extension HTTPClient { private var _isCancelled: Bool = false private var _taskDelegate: HTTPClientTaskDelegate? private let lock = Lock() + private let makeOrGetFileIOThreadPool: () -> NIOThreadPool - init(eventLoop: EventLoop, logger: Logger) { + /// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access. + internal var fileIOThreadPool: NIOThreadPool { + self.makeOrGetFileIOThreadPool() + } + + init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.logger = logger + self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool } - static func failedTask(eventLoop: EventLoop, error: Error, logger: Logger) -> Task { - let task = self.init(eventLoop: eventLoop, logger: logger) + static func failedTask( + eventLoop: EventLoop, + error: Error, + logger: Logger, + makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool + ) -> Task { + let task = self.init(eventLoop: eventLoop, logger: logger, makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool) task.promise.fail(error) return task } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift index 3be2c79a6..9114df259 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests+XCTest.swift @@ -36,6 +36,7 @@ extension HTTPClientInternalTests { ("testConnectErrorCalloutOnCorrectEL", testConnectErrorCalloutOnCorrectEL), ("testInternalRequestURI", testInternalRequestURI), ("testHasSuffix", testHasSuffix), + ("testSharedThreadPoolIsIdenticalForAllDelegates", testSharedThreadPoolIsIdenticalForAllDelegates), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift index 492bb4c35..234185eb6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift @@ -554,4 +554,39 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertFalse(elements.hasSuffix([0, 0, 0])) } } + + /// test to verify that we actually share the same thread pool across all ``FileDownloadDelegate``s for a given ``HTTPClient`` + func testSharedThreadPoolIsIdenticalForAllDelegates() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) + XCTAssertNoThrow(try httpBin.shutdown()) + } + var request = try Request(url: "http://localhost:\(httpBin.port)/events/10/content-length") + request.headers.add(name: "Accept", value: "text/event-stream") + + let filePaths = (0..<10).map { _ in + TemporaryFileHelpers.makeTemporaryFilePath() + } + defer { + for filePath in filePaths { + TemporaryFileHelpers.removeTemporaryFile(at: filePath) + } + } + let delegates = try filePaths.map { + try FileDownloadDelegate(path: $0) + } + + let resultFutures = delegates.map { delegate in + httpClient.execute( + request: request, + delegate: delegate + ).futureResult + } + _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() + let threadPools = delegates.map { $0.fileIOThreadPool } + let firstThreadPool = threadPools.first ?? nil + XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 8f7d4dfce..7cd9ef83d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -283,6 +283,23 @@ enum TemporaryFileHelpers { return try body(path) } + internal static func makeTemporaryFilePath( + directory: String = temporaryDirectory + ) -> String { + let (fd, path) = self.openTemporaryFile() + close(fd) + try! FileManager.default.removeItem(atPath: path) + return path + } + + internal static func removeTemporaryFile( + at path: String + ) { + if FileManager.default.fileExists(atPath: path) { + try? FileManager.default.removeItem(atPath: path) + } + } + internal static func fileSize(path: String) throws -> Int? { return try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 9e7072c19..6993c0df9 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -771,6 +771,17 @@ final class RequestBagTests: XCTestCase { } } +extension HTTPClient.Task { + convenience init( + eventLoop: EventLoop, + logger: Logger + ) { + self.init(eventLoop: eventLoop, logger: logger) { + preconditionFailure("thread pool not needed in tests") + } + } +} + class UploadCountingDelegate: HTTPClientResponseDelegate { typealias Response = Void