Skip to content

Commit efc4cd1

Browse files
authored
Propagate Connection Closed Information up to top-level (fix #465) (#545)
This PR implements a mechanism to propagate connection loss information from the Lambda runtime client to the runtime loop, enabling termination without backtrace when the connection to the Lambda control plane (or a Mock Server) is lost. The changes are: - When the connection is lost, `ChannelHandlerDelegate.channelInnactive()` now correctly calls `resume(throwing:)` on the ending continuation, for all states (`.waitingForNextInvocation ` and `.sentResponse`). This eliminates the hangs on connection lost.. - I added top-level error handling on `LambdaRuntime._run()` - Add a unit test to check that either `LambdaruntimeError.connectionToControlPlaneLost`, a `ChannelError`, or an `IOError` is thrown when the server closes the connection
1 parent 323b3f2 commit efc4cd1

File tree

9 files changed

+232
-32
lines changed

9 files changed

+232
-32
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ internal struct LambdaHTTPServer {
452452
await self.responsePool.push(
453453
LocalServerResponse(
454454
id: requestId,
455-
status: .ok,
455+
status: .accepted,
456456
// the local server has no mecanism to collect headers set by the lambda function
457457
headers: HTTPHeaders(),
458458
body: body,

Sources/AWSLambdaRuntime/Lambda.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ public enum Lambda {
4141
var logger = logger
4242
do {
4343
while !Task.isCancelled {
44+
45+
logger.trace("Waiting for next invocation")
4446
let (invocation, writer) = try await runtimeClient.nextInvocation()
4547
logger[metadataKey: "aws-request-id"] = "\(invocation.metadata.requestID)"
4648

@@ -76,14 +78,18 @@ public enum Lambda {
7678
logger: logger
7779
)
7880
)
81+
logger.trace("Handler finished processing invocation")
7982
} catch {
83+
logger.trace("Handler failed processing invocation", metadata: ["Handler error": "\(error)"])
8084
try await writer.reportError(error)
8185
continue
8286
}
87+
logger.handler.metadata.removeValue(forKey: "aws-request-id")
8388
}
8489
} catch is CancellationError {
8590
// don't allow cancellation error to propagate further
8691
}
92+
8793
}
8894

8995
/// The default EventLoop the Lambda is scheduled on.

Sources/AWSLambdaRuntime/LambdaRuntime.swift

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,29 @@ public final class LambdaRuntime<Handler>: Sendable where Handler: StreamingLamb
9494
let ip = String(ipAndPort[0])
9595
guard let port = Int(ipAndPort[1]) else { throw LambdaRuntimeError(code: .invalidPort) }
9696

97-
try await LambdaRuntimeClient.withRuntimeClient(
98-
configuration: .init(ip: ip, port: port),
99-
eventLoop: self.eventLoop,
100-
logger: self.logger
101-
) { runtimeClient in
102-
try await Lambda.runLoop(
103-
runtimeClient: runtimeClient,
104-
handler: handler,
97+
do {
98+
try await LambdaRuntimeClient.withRuntimeClient(
99+
configuration: .init(ip: ip, port: port),
100+
eventLoop: self.eventLoop,
105101
logger: self.logger
106-
)
102+
) { runtimeClient in
103+
try await Lambda.runLoop(
104+
runtimeClient: runtimeClient,
105+
handler: handler,
106+
logger: self.logger
107+
)
108+
}
109+
} catch {
110+
// catch top level errors that have not been handled until now
111+
// this avoids the runtime to crash and generate a backtrace
112+
self.logger.error("LambdaRuntime.run() failed with error", metadata: ["error": "\(error)"])
113+
if let error = error as? LambdaRuntimeError,
114+
error.code != .connectionToControlPlaneLost
115+
{
116+
// if the error is a LambdaRuntimeError but not a connection error,
117+
// we rethrow it to preserve existing behaviour
118+
throw error
119+
}
107120
}
108121

109122
} else {

Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
9797
private let configuration: Configuration
9898

9999
private var connectionState: ConnectionState = .disconnected
100+
100101
private var lambdaState: LambdaState = .idle(previousRequestID: nil)
101102
private var closingState: ClosingState = .notClosing
102103

@@ -118,10 +119,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
118119
} catch {
119120
result = .failure(error)
120121
}
121-
122122
await runtime.close()
123-
124-
//try? await runtime.close()
125123
return try result.get()
126124
}
127125

@@ -163,12 +161,16 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
163161

164162
@usableFromInline
165163
func nextInvocation() async throws -> (Invocation, Writer) {
166-
try await withTaskCancellationHandler {
164+
165+
try Task.checkCancellation()
166+
167+
return try await withTaskCancellationHandler {
167168
switch self.lambdaState {
168169
case .idle:
169170
self.lambdaState = .waitingForNextInvocation
170171
let handler = try await self.makeOrGetConnection()
171172
let invocation = try await handler.nextInvocation()
173+
172174
guard case .waitingForNextInvocation = self.lambdaState else {
173175
fatalError("Invalid state: \(self.lambdaState)")
174176
}
@@ -283,7 +285,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
283285
case (.connecting(let array), .notClosing):
284286
self.connectionState = .disconnected
285287
for continuation in array {
286-
continuation.resume(throwing: LambdaRuntimeError(code: .lostConnectionToControlPlane))
288+
continuation.resume(throwing: LambdaRuntimeError(code: .connectionToControlPlaneLost))
287289
}
288290

289291
case (.connecting(let array), .closing(let continuation)):
@@ -363,7 +365,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
363365
)
364366
channel.closeFuture.whenComplete { result in
365367
self.assumeIsolated { runtimeClient in
368+
// close the channel
366369
runtimeClient.channelClosed(channel)
370+
runtimeClient.connectionState = .disconnected
367371
}
368372
}
369373

@@ -382,6 +386,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
382386
return handler
383387
}
384388
} catch {
389+
385390
switch self.connectionState {
386391
case .disconnected, .connected:
387392
fatalError("Unexpected state: \(self.connectionState)")
@@ -430,7 +435,6 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate {
430435
}
431436

432437
isolated.connectionState = .disconnected
433-
434438
}
435439
}
436440
}
@@ -884,8 +888,16 @@ extension LambdaChannelHandler: ChannelInboundHandler {
884888
func channelInactive(context: ChannelHandlerContext) {
885889
// fail any pending responses with last error or assume peer disconnected
886890
switch self.state {
887-
case .connected(_, .waitingForNextInvocation(let continuation)):
888-
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
891+
case .connected(_, let lambdaState):
892+
switch lambdaState {
893+
case .waitingForNextInvocation(let continuation):
894+
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
895+
case .sentResponse(let continuation):
896+
continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel)
897+
case .idle, .sendingResponse, .waitingForResponse:
898+
break
899+
}
900+
self.state = .disconnected
889901
default:
890902
break
891903
}

Sources/AWSLambdaRuntime/LambdaRuntimeError.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ package struct LambdaRuntimeError: Error {
2525

2626
case writeAfterFinishHasBeenSent
2727
case finishAfterFinishHasBeenSent
28-
case lostConnectionToControlPlane
2928
case unexpectedStatusCodeForRequest
3029

3130
case nextInvocationMissingHeaderRequestID

Sources/MockServer/MockHTTPServer.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ struct HttpServer {
216216
} else if requestHead.uri.hasSuffix("/response") {
217217
responseStatus = .accepted
218218
} else if requestHead.uri.hasSuffix("/error") {
219-
responseStatus = .ok
219+
responseStatus = .accepted
220220
} else {
221221
responseStatus = .notFound
222222
}

Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift

Lines changed: 93 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ struct LambdaRuntimeClientTests {
4242
.success((self.requestId, self.event))
4343
}
4444

45-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
45+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
4646
#expect(self.requestId == requestId)
4747
#expect(self.event == response)
48-
return .success(())
48+
return .success(nil)
4949
}
5050

5151
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
@@ -102,9 +102,9 @@ struct LambdaRuntimeClientTests {
102102
.success((self.requestId, self.event))
103103
}
104104

105-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
105+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
106106
#expect(self.requestId == requestId)
107-
return .success(())
107+
return .success(nil)
108108
}
109109

110110
mutating func captureHeaders(_ headers: HTTPHeaders) {
@@ -197,10 +197,10 @@ struct LambdaRuntimeClientTests {
197197
.success((self.requestId, self.event))
198198
}
199199

200-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError> {
200+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
201201
#expect(self.requestId == requestId)
202202
#expect(self.event == response)
203-
return .success(())
203+
return .success(nil)
204204
}
205205

206206
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
@@ -238,4 +238,91 @@ struct LambdaRuntimeClientTests {
238238
}
239239
}
240240
}
241+
242+
struct DisconnectAfterSendingResponseBehavior: LambdaServerBehavior {
243+
func getInvocation() -> GetInvocationResult {
244+
.success((UUID().uuidString, "hello"))
245+
}
246+
247+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
248+
// Return "delayed-disconnect" to trigger server closing the connection
249+
// after having accepted the first response
250+
.success("delayed-disconnect")
251+
}
252+
253+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
254+
Issue.record("should not report error")
255+
return .failure(.internalServerError)
256+
}
257+
258+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
259+
Issue.record("should not report init error")
260+
return .failure(.internalServerError)
261+
}
262+
}
263+
264+
struct DisconnectBehavior: LambdaServerBehavior {
265+
func getInvocation() -> GetInvocationResult {
266+
.success(("disconnect", "0"))
267+
}
268+
269+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError> {
270+
.success(nil)
271+
}
272+
273+
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError> {
274+
Issue.record("should not report error")
275+
return .failure(.internalServerError)
276+
}
277+
278+
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError> {
279+
Issue.record("should not report init error")
280+
return .failure(.internalServerError)
281+
}
282+
}
283+
284+
@Test(
285+
"Server closing the connection when waiting for next invocation throws an error",
286+
arguments: [DisconnectBehavior(), DisconnectAfterSendingResponseBehavior()] as [any LambdaServerBehavior]
287+
)
288+
func testChannelCloseFutureWithWaitingForNextInvocation(behavior: LambdaServerBehavior) async throws {
289+
try await withMockServer(behaviour: behavior) { port in
290+
let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port)
291+
292+
try await LambdaRuntimeClient.withRuntimeClient(
293+
configuration: configuration,
294+
eventLoop: NIOSingletons.posixEventLoopGroup.next(),
295+
logger: self.logger
296+
) { runtimeClient in
297+
do {
298+
299+
// simulate traffic until the server reports it has closed the connection
300+
// or a timeout, whichever comes first
301+
// result is ignored here, either there is a connection error or a timeout
302+
let _ = try await timeout(deadline: .seconds(1)) {
303+
while true {
304+
let (_, writer) = try await runtimeClient.nextInvocation()
305+
try await writer.writeAndFinish(ByteBuffer(string: "hello"))
306+
}
307+
}
308+
// result is ignored here, we should never reach this line
309+
Issue.record("Connection reset test did not throw an error")
310+
311+
} catch is CancellationError {
312+
Issue.record("Runtime client did not send connection closed error")
313+
} catch let error as LambdaRuntimeError {
314+
logger.trace("LambdaRuntimeError - expected")
315+
#expect(error.code == .connectionToControlPlaneLost)
316+
} catch let error as ChannelError {
317+
logger.trace("ChannelError - expected")
318+
#expect(error == .ioOnClosedChannel)
319+
} catch let error as IOError {
320+
logger.trace("IOError - expected")
321+
#expect(error.errnoCode == ECONNRESET || error.errnoCode == EPIPE)
322+
} catch {
323+
Issue.record("Unexpected error type: \(error)")
324+
}
325+
}
326+
}
327+
}
241328
}

Tests/AWSLambdaRuntimeTests/MockLambdaServer.swift

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ final class HTTPHandler: ChannelInboundHandler {
160160
var responseStatus: HTTPResponseStatus
161161
var responseBody: String?
162162
var responseHeaders: [(String, String)]?
163+
var disconnectAfterSend = false
163164

164165
// Handle post-init-error first to avoid matching the less specific post-error suffix.
165166
if request.head.uri.hasSuffix(Consts.postInitErrorURL) {
@@ -202,8 +203,11 @@ final class HTTPHandler: ChannelInboundHandler {
202203
behavior.captureHeaders(request.head.headers)
203204

204205
switch behavior.processResponse(requestId: String(requestId), response: requestBody) {
205-
case .success:
206+
case .success(let next):
206207
responseStatus = .accepted
208+
if next == "delayed-disconnect" {
209+
disconnectAfterSend = true
210+
}
207211
case .failure(let error):
208212
responseStatus = .init(statusCode: error.rawValue)
209213
}
@@ -223,14 +227,21 @@ final class HTTPHandler: ChannelInboundHandler {
223227
} else {
224228
responseStatus = .notFound
225229
}
226-
self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody)
230+
self.writeResponse(
231+
context: context,
232+
status: responseStatus,
233+
headers: responseHeaders,
234+
body: responseBody,
235+
closeConnection: disconnectAfterSend
236+
)
227237
}
228238

229239
func writeResponse(
230240
context: ChannelHandlerContext,
231241
status: HTTPResponseStatus,
232242
headers: [(String, String)]? = nil,
233-
body: String? = nil
243+
body: String? = nil,
244+
closeConnection: Bool = false
234245
) {
235246
var headers = HTTPHeaders(headers ?? [])
236247
headers.add(name: "Content-Length", value: "\(body?.utf8.count ?? 0)")
@@ -253,14 +264,19 @@ final class HTTPHandler: ChannelInboundHandler {
253264
}
254265

255266
let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop)
256-
257267
let keepAlive = self.keepAlive
258268
context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in
269+
let context = loopBoundContext.value
270+
if closeConnection {
271+
context.close(promise: nil)
272+
return
273+
}
274+
259275
if case .failure(let error) = result {
260276
logger.error("write error \(error)")
261277
}
278+
262279
if !keepAlive {
263-
let context = loopBoundContext.value
264280
context.close().whenFailure { error in
265281
logger.error("close error \(error)")
266282
}
@@ -271,7 +287,7 @@ final class HTTPHandler: ChannelInboundHandler {
271287

272288
protocol LambdaServerBehavior: Sendable {
273289
func getInvocation() -> GetInvocationResult
274-
func processResponse(requestId: String, response: String?) -> Result<Void, ProcessResponseError>
290+
func processResponse(requestId: String, response: String?) -> Result<String?, ProcessResponseError>
275291
func processError(requestId: String, error: ErrorResponse) -> Result<Void, ProcessErrorError>
276292
func processInitError(error: ErrorResponse) -> Result<Void, ProcessErrorError>
277293

0 commit comments

Comments
 (0)