Skip to content

Commit 34e3403

Browse files
authored
Task refactoring (#47)
* refactor Task to encapsulate promise
1 parent d32eea0 commit 34e3403

File tree

3 files changed

+46
-43
lines changed

3 files changed

+46
-43
lines changed

Sources/NIOHTTPClient/HTTPHandler.swift

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -236,31 +236,28 @@ internal extension URL {
236236
}
237237
}
238238

239-
public extension HTTPClient {
240-
final class Task<Response> {
239+
extension HTTPClient {
240+
public final class Task<Response> {
241241
public let eventLoop: EventLoop
242-
let future: EventLoopFuture<Response>
242+
let promise: EventLoopPromise<Response>
243243

244244
private var channel: Channel?
245245
private var cancelled: Bool
246246
private let lock: Lock
247247

248-
init(eventLoop: EventLoop, future: EventLoopFuture<Response>) {
248+
public init(eventLoop: EventLoop) {
249249
self.eventLoop = eventLoop
250-
self.future = future
250+
self.promise = eventLoop.makePromise()
251251
self.cancelled = false
252252
self.lock = Lock()
253253
}
254254

255-
func setChannel(_ channel: Channel) -> Channel {
256-
return self.lock.withLock {
257-
self.channel = channel
258-
return channel
259-
}
255+
public var futureResult: EventLoopFuture<Response> {
256+
return self.promise.futureResult
260257
}
261258

262259
public func wait() throws -> Response {
263-
return try self.future.wait()
260+
return try self.promise.futureResult.wait()
264261
}
265262

266263
public func cancel() {
@@ -272,8 +269,19 @@ public extension HTTPClient {
272269
}
273270
}
274271

275-
public func cascade(promise: EventLoopPromise<Response>) {
276-
self.future.cascade(to: promise)
272+
func setChannel(_ channel: Channel) -> Channel {
273+
return self.lock.withLock {
274+
self.channel = channel
275+
return channel
276+
}
277+
}
278+
279+
func succeed(_ value: Response) {
280+
self.promise.succeed(value)
281+
}
282+
283+
func fail(_ error: Error) {
284+
self.promise.fail(error)
277285
}
278286
}
279287
}
@@ -296,17 +304,15 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
296304

297305
let task: HTTPClient.Task<T.Response>
298306
let delegate: T
299-
let promise: EventLoopPromise<T.Response>
300307
let redirectHandler: RedirectHandler<T.Response>?
301308

302309
var state: State = .idle
303310
var pendingRead = false
304311
var mayRead = true
305312

306-
init(task: HTTPClient.Task<T.Response>, delegate: T, promise: EventLoopPromise<T.Response>, redirectHandler: RedirectHandler<T.Response>?) {
313+
init(task: HTTPClient.Task<T.Response>, delegate: T, redirectHandler: RedirectHandler<T.Response>?) {
307314
self.task = task
308315
self.delegate = delegate
309-
self.promise = promise
310316
self.redirectHandler = redirectHandler
311317
}
312318

@@ -347,13 +353,13 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
347353
self.delegate.didSendRequest(task: self.task)
348354

349355
let channel = context.channel
350-
self.promise.futureResult.whenComplete { _ in
356+
self.task.futureResult.whenComplete { _ in
351357
channel.close(promise: nil)
352358
}
353359
case .failure(let error):
354360
self.state = .end
355361
self.delegate.didReceiveError(task: self.task, error)
356-
self.promise.fail(error)
362+
self.task.fail(error)
357363
context.close(promise: nil)
358364
}
359365
}
@@ -410,14 +416,14 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
410416
switch self.state {
411417
case .redirected(let head, let redirectURL):
412418
self.state = .end
413-
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.promise)
419+
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
414420
context.close(promise: nil)
415421
default:
416422
self.state = .end
417423
do {
418-
self.promise.succeed(try self.delegate.didFinishRequest(task: self.task))
424+
self.task.succeed(try self.delegate.didFinishRequest(task: self.task))
419425
} catch {
420-
self.promise.fail(error)
426+
self.task.fail(error)
421427
}
422428
}
423429
}
@@ -433,7 +439,7 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
433439
case .failure(let error):
434440
self.state = .end
435441
self.delegate.didReceiveError(task: self.task, error)
436-
self.promise.fail(error)
442+
self.task.fail(error)
437443
}
438444
}
439445

@@ -442,12 +448,12 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
442448
self.state = .end
443449
let error = HTTPClientError.readTimeout
444450
self.delegate.didReceiveError(task: self.task, error)
445-
self.promise.fail(error)
451+
self.task.fail(error)
446452
} else if (event as? TaskCancelEvent) != nil {
447453
self.state = .end
448454
let error = HTTPClientError.cancelled
449455
self.delegate.didReceiveError(task: self.task, error)
450-
self.promise.fail(error)
456+
self.task.fail(error)
451457
} else {
452458
context.fireUserInboundEventTriggered(event)
453459
}
@@ -461,7 +467,7 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
461467
self.state = .end
462468
let error = HTTPClientError.remoteConnectionClosed
463469
self.delegate.didReceiveError(task: self.task, error)
464-
self.promise.fail(error)
470+
self.task.fail(error)
465471
}
466472
}
467473

@@ -476,12 +482,12 @@ internal class TaskHandler<T: HTTPClientResponseDelegate>: ChannelInboundHandler
476482
default:
477483
self.state = .end
478484
self.delegate.didReceiveError(task: self.task, error)
479-
self.promise.fail(error)
485+
self.task.fail(error)
480486
}
481487
default:
482488
self.state = .end
483489
self.delegate.didReceiveError(task: self.task, error)
484-
self.promise.fail(error)
490+
self.task.fail(error)
485491
}
486492
}
487493
}
@@ -556,6 +562,6 @@ internal struct RedirectHandler<T> {
556562
request.headers.remove(name: "Proxy-Authorization")
557563
}
558564

559-
return self.execute(request).cascade(promise: promise)
565+
return self.execute(request).futureResult.cascade(to: promise)
560566
}
561567
}

Sources/NIOHTTPClient/SwiftNIOHTTP.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,12 @@ public class HTTPClient {
105105

106106
public func execute(request: Request, timeout: Timeout? = nil) -> EventLoopFuture<Response> {
107107
let accumulator = ResponseAccumulator(request: request)
108-
return self.execute(request: request, delegate: accumulator, timeout: timeout).future
108+
return self.execute(request: request, delegate: accumulator, timeout: timeout).futureResult
109109
}
110110

111111
public func execute<T: HTTPClientResponseDelegate>(request: Request, delegate: T, timeout: Timeout? = nil) -> Task<T.Response> {
112112
let timeout = timeout ?? configuration.timeout
113113
let eventLoop = self.eventLoopGroup.next()
114-
let promise: EventLoopPromise<T.Response> = eventLoop.makePromise()
115114

116115
let redirectHandler: RedirectHandler<T.Response>?
117116
if self.configuration.followRedirects {
@@ -122,7 +121,7 @@ public class HTTPClient {
122121
redirectHandler = nil
123122
}
124123

125-
let task = Task(eventLoop: eventLoop, future: promise.futureResult)
124+
let task = Task<T.Response>(eventLoop: eventLoop)
126125

127126
var bootstrap = ClientBootstrap(group: self.eventLoopGroup)
128127
.channelOption(ChannelOptions.socket(SocketOptionLevel(IPPROTO_TCP), TCP_NODELAY), value: 1)
@@ -143,7 +142,7 @@ public class HTTPClient {
143142
return channel.eventLoop.makeSucceededFuture(())
144143
}
145144
}.flatMap {
146-
let taskHandler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: redirectHandler)
145+
let taskHandler = TaskHandler(task: task, delegate: delegate, redirectHandler: redirectHandler)
147146
return channel.pipeline.addHandler(taskHandler)
148147
}
149148
}
@@ -161,7 +160,7 @@ public class HTTPClient {
161160
channel.writeAndFlush(request)
162161
}
163162
.whenFailure { error in
164-
promise.fail(error)
163+
task.fail(error)
165164
}
166165

167166
return task

Tests/NIOHTTPClientTests/SwiftNIOHTTPTests.swift

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import Foundation
1616
import NIO
1717
import NIOConcurrencyHelpers
1818
import NIOFoundationCompat
19-
@testable import NIOHTTP1
19+
import NIOHTTP1
2020
@testable import NIOHTTPClient
2121
import NIOSSL
2222
import XCTest
@@ -40,11 +40,10 @@ class SwiftHTTPTests: XCTestCase {
4040
func testHTTPPartsHandler() throws {
4141
let channel = EmbeddedChannel()
4242
let recorder = RecordingHandler<HTTPClientResponsePart, HTTPClientRequestPart>()
43-
let promise: EventLoopPromise<Void> = channel.eventLoop.makePromise()
44-
let task = Task(eventLoop: channel.eventLoop, future: promise.futureResult)
43+
let task = Task<Void>(eventLoop: channel.eventLoop)
4544

4645
try channel.pipeline.addHandler(recorder).wait()
47-
try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), promise: promise, redirectHandler: nil)).wait()
46+
try channel.pipeline.addHandler(TaskHandler(task: task, delegate: TestHTTPDelegate(), redirectHandler: nil)).wait()
4847

4948
var request = try Request(url: "http://localhost/get")
5049
request.headers.add(name: "X-Test-Header", value: "X-Test-Value")
@@ -68,9 +67,8 @@ class SwiftHTTPTests: XCTestCase {
6867
func testHTTPPartsHandlerMultiBody() throws {
6968
let channel = EmbeddedChannel()
7069
let delegate = TestHTTPDelegate()
71-
let promise: EventLoopPromise<Void> = channel.eventLoop.makePromise()
72-
let task = Task(eventLoop: channel.eventLoop, future: promise.futureResult)
73-
let handler = TaskHandler(task: task, delegate: delegate, promise: promise, redirectHandler: nil)
70+
let task = Task<Void>(eventLoop: channel.eventLoop)
71+
let handler = TaskHandler(task: task, delegate: delegate, redirectHandler: nil)
7472

7573
try channel.pipeline.addHandler(handler).wait()
7674

@@ -357,7 +355,7 @@ class SwiftHTTPTests: XCTestCase {
357355
let delegate = CopyingDelegate { part in
358356
writer.write(.byteBuffer(part))
359357
}
360-
return httpClient.execute(request: request, delegate: delegate).future
358+
return httpClient.execute(request: request, delegate: delegate).futureResult
361359
} catch {
362360
return httpClient.eventLoopGroup.next().makeFailedFuture(error)
363361
}
@@ -393,7 +391,7 @@ class SwiftHTTPTests: XCTestCase {
393391
let delegate = CopyingDelegate { _ in
394392
httpClient.eventLoopGroup.next().makeFailedFuture(HTTPClientError.invalidProxyResponse)
395393
}
396-
return httpClient.execute(request: request, delegate: delegate).future
394+
return httpClient.execute(request: request, delegate: delegate).futureResult
397395
} catch {
398396
return httpClient.eventLoopGroup.next().makeFailedFuture(error)
399397
}
@@ -442,7 +440,7 @@ class SwiftHTTPTests: XCTestCase {
442440

443441
let request = try Request(url: "http://localhost:\(httpBin.port)/custom")
444442
let delegate = BackpressureTestDelegate(promise: httpClient.eventLoopGroup.next().makePromise())
445-
let future = httpClient.execute(request: request, delegate: delegate).future
443+
let future = httpClient.execute(request: request, delegate: delegate).futureResult
446444

447445
let channel = try promise.futureResult.wait()
448446

0 commit comments

Comments
 (0)