Skip to content

Commit b723844

Browse files
committed
Add HTTP1.1 connection
1 parent 44207e1 commit b723844

7 files changed

+1112
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,384 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the AsyncHTTPClient open source project
4+
//
5+
// Copyright (c) 2021 Apple Inc. and the AsyncHTTPClient project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import Logging
16+
import NIO
17+
import NIOHTTP1
18+
19+
final class HTTP1ClientChannelHandler: ChannelDuplexHandler {
20+
typealias OutboundIn = HTTPRequestTask
21+
typealias OutboundOut = HTTPClientRequestPart
22+
typealias InboundIn = HTTPClientResponsePart
23+
24+
var channelContext: ChannelHandlerContext!
25+
26+
var state: HTTP1ConnectionStateMachine = .init() {
27+
didSet {
28+
self.channelContext.eventLoop.assertInEventLoop()
29+
30+
self.logger.trace("Connection state did change", metadata: [
31+
"state": "\(String(describing: self.state))",
32+
])
33+
}
34+
}
35+
36+
private var task: HTTPRequestTask?
37+
private var idleReadTimeoutTimer: Scheduled<Void>?
38+
39+
let connection: HTTP1Connection
40+
let logger: Logger
41+
42+
init(connection: HTTP1Connection, logger: Logger) {
43+
self.connection = connection
44+
self.logger = logger
45+
}
46+
47+
func handlerAdded(context: ChannelHandlerContext) {
48+
self.channelContext = context
49+
50+
if context.channel.isActive {
51+
let action = self.state.channelActive(isWritable: context.channel.isWritable)
52+
self.run(action, context: context)
53+
}
54+
}
55+
56+
// MARK: Channel Inbound Handler
57+
58+
func channelActive(context: ChannelHandlerContext) {
59+
let action = self.state.channelActive(isWritable: context.channel.isWritable)
60+
self.run(action, context: context)
61+
}
62+
63+
func channelInactive(context: ChannelHandlerContext) {
64+
let action = self.state.channelInactive()
65+
self.run(action, context: context)
66+
}
67+
68+
func channelWritabilityChanged(context: ChannelHandlerContext) {
69+
self.logger.trace("Channel writability changed", metadata: [
70+
"writable": "\(context.channel.isWritable)",
71+
])
72+
73+
let action = self.state.writabilityChanged(writable: context.channel.isWritable)
74+
self.run(action, context: context)
75+
}
76+
77+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
78+
let httpPart = unwrapInboundIn(data)
79+
80+
self.logger.trace("Message received", metadata: [
81+
"message": "\(httpPart)",
82+
])
83+
84+
let action: HTTP1ConnectionStateMachine.Action
85+
switch httpPart {
86+
case .head(let head):
87+
action = self.state.receivedHTTPResponseHead(head)
88+
case .body(let buffer):
89+
action = self.state.receivedHTTPResponseBodyPart(buffer)
90+
case .end:
91+
action = self.state.receivedHTTPResponseEnd()
92+
}
93+
94+
self.run(action, context: context)
95+
}
96+
97+
func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
98+
context.close(mode: mode, promise: promise)
99+
}
100+
101+
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
102+
self.logger.trace("Write")
103+
104+
let task = self.unwrapOutboundIn(data)
105+
self.task = task
106+
107+
let action = self.state.runNewRequest(idleReadTimeout: task.idleReadTimeout)
108+
self.run(action, context: context)
109+
}
110+
111+
func read(context: ChannelHandlerContext) {
112+
self.logger.trace("Read")
113+
114+
let action = self.state.readEventCaught()
115+
self.run(action, context: context)
116+
}
117+
118+
func errorCaught(context: ChannelHandlerContext, error: Error) {
119+
self.logger.trace("Error caught", metadata: [
120+
"error": "\(error)",
121+
])
122+
123+
let action = self.state.errorHappened(error)
124+
self.run(action, context: context)
125+
}
126+
127+
func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
128+
switch event {
129+
case HTTPConnectionEvent.cancelRequest:
130+
let action = self.state.cancelRequestForClose()
131+
self.run(action, context: context)
132+
default:
133+
context.fireUserInboundEventTriggered(event)
134+
}
135+
}
136+
137+
// MARK: - Run Actions
138+
139+
func run(_ action: HTTP1ConnectionStateMachine.Action, context: ChannelHandlerContext) {
140+
switch action {
141+
case .verifyRequest:
142+
do {
143+
guard self.task!.willExecuteRequest(self) else {
144+
throw HTTPClientError.cancelled
145+
}
146+
147+
let head = try self.verifyRequest(request: self.task!.request)
148+
let action = self.state.requestVerified(head)
149+
self.run(action, context: context)
150+
} catch {
151+
let action = self.state.requestVerificationFailed(error)
152+
self.run(action, context: context)
153+
}
154+
155+
case .sendRequestHead(let head, startBody: let startBody, let idleReadTimeout):
156+
if startBody {
157+
context.write(wrapOutboundOut(.head(head)), promise: nil)
158+
context.flush()
159+
160+
self.task!.requestHeadSent(head)
161+
self.task!.startRequestBodyStream()
162+
} else {
163+
context.write(wrapOutboundOut(.head(head)), promise: nil)
164+
context.write(wrapOutboundOut(.end(nil)), promise: nil)
165+
context.flush()
166+
167+
self.task!.requestHeadSent(head)
168+
}
169+
170+
if let idleReadTimeout = idleReadTimeout {
171+
self.resetIdleReadTimeoutTimer(idleReadTimeout, context: context)
172+
}
173+
174+
case .sendBodyPart(let part):
175+
context.writeAndFlush(wrapOutboundOut(.body(part)), promise: nil)
176+
177+
case .sendRequestEnd:
178+
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: nil)
179+
180+
case .pauseRequestBodyStream:
181+
self.task!.pauseRequestBodyStream()
182+
183+
case .resumeRequestBodyStream:
184+
self.task!.resumeRequestBodyStream()
185+
186+
case .fireChannelActive:
187+
context.fireChannelActive()
188+
189+
case .fireChannelInactive:
190+
context.fireChannelInactive()
191+
192+
case .fireChannelError(let error, let close):
193+
context.fireErrorCaught(error)
194+
if close {
195+
context.close(promise: nil)
196+
}
197+
198+
case .read:
199+
context.read()
200+
201+
case .close:
202+
context.close(promise: nil)
203+
204+
case .wait:
205+
break
206+
207+
case .forwardResponseHead(let head):
208+
self.task!.receiveResponseHead(head)
209+
210+
case .forwardResponseBodyPart(let buffer, let resetReadTimeout):
211+
self.task!.receiveResponseBodyPart(buffer)
212+
213+
if let resetReadTimeout = resetReadTimeout {
214+
self.resetIdleReadTimeoutTimer(resetReadTimeout, context: context)
215+
}
216+
217+
case .forwardResponseEnd(let readPending, let clearReadTimeoutTimer, let closeConnection):
218+
// The order here is very important...
219+
// We first nil our own task property! `taskCompleted` will potentially lead to
220+
// situations in which we get a new request right away. We should finish the task
221+
// after the connection was notified, that we finished. A
222+
// `HTTPClient.shutdown(requiresCleanShutdown: true)` will fail if we do it the
223+
// other way around.
224+
225+
let task = self.task!
226+
self.task = nil
227+
228+
if clearReadTimeoutTimer {
229+
self.clearIdleReadTimeoutTimer()
230+
}
231+
232+
if closeConnection {
233+
context.close(promise: nil)
234+
task.receiveResponseEnd()
235+
} else {
236+
if readPending {
237+
context.read()
238+
}
239+
240+
self.connection.taskCompleted()
241+
task.receiveResponseEnd()
242+
}
243+
244+
case .forwardError(let error, closeConnection: let close, fireChannelError: let fire):
245+
let task = self.task!
246+
self.task = nil
247+
if close {
248+
context.close(promise: nil)
249+
} else {
250+
self.connection.taskCompleted()
251+
}
252+
253+
if fire {
254+
context.fireErrorCaught(error)
255+
}
256+
257+
task.fail(error)
258+
}
259+
}
260+
261+
// MARK: - Private Methods -
262+
263+
private func verifyRequest(request: HTTPClient.Request) throws -> HTTPRequestHead {
264+
var headers = request.headers
265+
266+
if !headers.contains(name: "host") {
267+
let port = request.port
268+
var host = request.host
269+
if !(port == 80 && request.scheme == "http"), !(port == 443 && request.scheme == "https") {
270+
host += ":\(port)"
271+
}
272+
headers.add(name: "host", value: host)
273+
}
274+
275+
try headers.validate(method: request.method, body: request.body)
276+
277+
let head = HTTPRequestHead(
278+
version: .http1_1,
279+
method: request.method,
280+
uri: request.uri,
281+
headers: headers
282+
)
283+
284+
// 3. preparing to send body
285+
286+
// This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
287+
// in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
288+
assert(head.version == HTTPVersion(major: 1, minor: 1),
289+
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")
290+
291+
return head
292+
}
293+
294+
private func resetIdleReadTimeoutTimer(_ idleReadTimeout: TimeAmount, context: ChannelHandlerContext) {
295+
if let oldTimer = self.idleReadTimeoutTimer {
296+
oldTimer.cancel()
297+
}
298+
299+
self.idleReadTimeoutTimer = context.channel.eventLoop.scheduleTask(in: idleReadTimeout) {
300+
let action = self.state.idleReadTimeoutTriggered()
301+
self.run(action, context: context)
302+
}
303+
}
304+
305+
private func clearIdleReadTimeoutTimer() {
306+
guard let oldTimer = self.idleReadTimeoutTimer else {
307+
preconditionFailure("Expected an idleReadTimeoutTimer to exist.")
308+
}
309+
310+
self.idleReadTimeoutTimer = nil
311+
oldTimer.cancel()
312+
}
313+
}
314+
315+
extension HTTP1ClientChannelHandler: HTTP1RequestExecutor {
316+
func writeRequestBodyPart(_ data: IOData, task: HTTPRequestTask) {
317+
guard self.channelContext.eventLoop.inEventLoop else {
318+
return self.channelContext.eventLoop.execute {
319+
self.writeRequestBodyPart(data, task: task)
320+
}
321+
}
322+
323+
guard self.task === task else {
324+
// very likely we got threading issues here...
325+
return
326+
}
327+
328+
let action = self.state.requestStreamPartReceived(data)
329+
self.run(action, context: self.channelContext)
330+
}
331+
332+
func finishRequestBodyStream(task: HTTPRequestTask) {
333+
// ensure the message is received on correct eventLoop
334+
guard self.channelContext.eventLoop.inEventLoop else {
335+
return self.channelContext.eventLoop.execute {
336+
self.finishRequestBodyStream(task: task)
337+
}
338+
}
339+
340+
guard self.task === task else {
341+
// very likely we got threading issues here...
342+
return
343+
}
344+
345+
let action = self.state.requestStreamFinished()
346+
self.run(action, context: self.channelContext)
347+
}
348+
349+
func demandResponseBodyStream(task: HTTPRequestTask) {
350+
// ensure the message is received on correct eventLoop
351+
guard self.channelContext.eventLoop.inEventLoop else {
352+
return self.channelContext.eventLoop.execute {
353+
self.demandResponseBodyStream(task: task)
354+
}
355+
}
356+
357+
guard self.task === task else {
358+
// very likely we got threading issues here...
359+
return
360+
}
361+
362+
self.logger.trace("Downstream requests more response body data")
363+
364+
let action = self.state.forwardMoreBodyParts()
365+
self.run(action, context: self.channelContext)
366+
}
367+
368+
func cancelRequest(task: HTTPRequestTask) {
369+
// ensure the message is received on correct eventLoop
370+
guard self.channelContext.eventLoop.inEventLoop else {
371+
return self.channelContext.eventLoop.execute {
372+
self.cancelRequest(task: task)
373+
}
374+
}
375+
376+
guard self.task === task else {
377+
// very likely we got threading issues here...
378+
return
379+
}
380+
381+
let action = self.state.requestCancelled()
382+
self.run(action, context: self.channelContext)
383+
}
384+
}

0 commit comments

Comments
 (0)