Skip to content

Commit 5556ab2

Browse files
l3kubahildebrand
andauthored
Lambda-http: vary type of response based on request origin (#269)
* Lambda-http: vary type of response based on request origin ApiGatewayV2, ApiGateway and Alb all expect different types of responses to be returned from the invoked lambda function. Thus, it makes sense to pass the request origin to the creation of the response, so that the correct type of LambdaResponse is returned from the function. This commit also adds support for the "cookies" attribute which can be used for returning multiple Set-cookie headers from a lambda invoked via ApiGatewayV2, since ApiGatewayV2 no longer seems to recognize the "multiValueHeaders" attribute. Closes: #267. * Fix Serialize import * Fix missing reference on self * Fix import order * Add missing comma for fmt check Co-authored-by: Blake Hildebrand <[email protected]>
1 parent 6033ce3 commit 5556ab2

File tree

3 files changed

+193
-66
lines changed

3 files changed

+193
-66
lines changed

lambda-http/src/lib.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,10 @@ pub mod request;
7575
mod response;
7676
mod strmap;
7777
pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap};
78-
use crate::{request::LambdaRequest, response::LambdaResponse};
78+
use crate::{
79+
request::{LambdaRequest, RequestOrigin},
80+
response::LambdaResponse,
81+
};
7982
use std::{
8083
future::Future,
8184
pin::Pin,
@@ -124,7 +127,7 @@ where
124127

125128
#[doc(hidden)]
126129
pub struct TransformResponse<R, E> {
127-
is_alb: bool,
130+
request_origin: RequestOrigin,
128131
fut: Pin<Box<dyn Future<Output = Result<R, E>> + Send + Sync>>,
129132
}
130133

@@ -135,9 +138,9 @@ where
135138
type Output = Result<LambdaResponse, E>;
136139
fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll<Self::Output> {
137140
match self.fut.as_mut().poll(cx) {
138-
Poll::Ready(result) => {
139-
Poll::Ready(result.map(|resp| LambdaResponse::from_response(self.is_alb, resp.into_response())))
140-
}
141+
Poll::Ready(result) => Poll::Ready(
142+
result.map(|resp| LambdaResponse::from_response(&self.request_origin, resp.into_response())),
143+
),
141144
Poll::Pending => Poll::Pending,
142145
}
143146
}
@@ -166,9 +169,10 @@ impl<H: Handler> Handler for Adapter<H> {
166169
impl<H: Handler> LambdaHandler<LambdaRequest<'_>, LambdaResponse> for Adapter<H> {
167170
type Error = H::Error;
168171
type Fut = TransformResponse<H::Response, Self::Error>;
172+
169173
fn call(&self, event: LambdaRequest<'_>, context: Context) -> Self::Fut {
170-
let is_alb = event.is_alb();
174+
let request_origin = event.request_origin();
171175
let fut = Box::pin(self.handler.call(event.into(), context));
172-
TransformResponse { is_alb, fut }
176+
TransformResponse { request_origin, fut }
173177
}
174178
}

lambda-http/src/request.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,30 @@ pub enum LambdaRequest<'a> {
9191
}
9292

9393
impl LambdaRequest<'_> {
94-
/// Return true if this request represents an ALB event
95-
///
96-
/// Alb responses have unique requirements for responses that
97-
/// vary only slightly from APIGateway responses. We serialize
98-
/// responses capturing a hint that the request was an alb triggered
99-
/// event.
100-
pub fn is_alb(&self) -> bool {
101-
matches!(self, LambdaRequest::Alb { .. })
94+
/// Return the `RequestOrigin` of the request to determine where the `LambdaRequest`
95+
/// originated from, so that the appropriate response can be selected based on what
96+
/// type of response the request origin expects.
97+
pub fn request_origin(&self) -> RequestOrigin {
98+
match self {
99+
LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2,
100+
LambdaRequest::Alb { .. } => RequestOrigin::Alb,
101+
LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway,
102+
}
102103
}
103104
}
104105

106+
/// Represents the origin from which the lambda was requested from.
107+
#[doc(hidden)]
108+
#[derive(Debug)]
109+
pub enum RequestOrigin {
110+
/// API Gateway v2 request origin
111+
ApiGatewayV2,
112+
/// API Gateway request origin
113+
ApiGateway,
114+
/// ALB request origin
115+
Alb,
116+
}
117+
105118
/// See [context-variable-reference](https://docs.aws.amazon.com/apigateway/latest/developerguide/api-gateway-mapping-template-reference.html) for more detail.
106119
#[derive(Deserialize, Debug, Clone)]
107120
#[serde(rename_all = "camelCase")]

lambda-http/src/response.rs

Lines changed: 161 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,67 @@
11
//! Response types
22
3-
use crate::body::Body;
3+
use crate::{body::Body, request::RequestOrigin};
44
use http::{
5-
header::{HeaderMap, HeaderValue, CONTENT_TYPE},
5+
header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE},
66
Response,
77
};
88
use serde::{
9-
ser::{Error as SerError, SerializeMap},
9+
ser::{Error as SerError, SerializeMap, SerializeSeq},
1010
Serialize, Serializer,
1111
};
1212

13-
/// Representation of API Gateway response
13+
/// Representation of Lambda response
14+
#[doc(hidden)]
15+
#[derive(Serialize, Debug)]
16+
#[serde(untagged)]
17+
pub enum LambdaResponse {
18+
ApiGatewayV2(ApiGatewayV2Response),
19+
Alb(AlbResponse),
20+
ApiGateway(ApiGatewayResponse),
21+
}
22+
23+
/// Representation of API Gateway v2 lambda response
1424
#[doc(hidden)]
1525
#[derive(Serialize, Debug)]
1626
#[serde(rename_all = "camelCase")]
17-
pub struct LambdaResponse {
18-
pub status_code: u16,
19-
// ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error
20-
// when one is provided. only populate this for ALB responses
27+
pub struct ApiGatewayV2Response {
28+
status_code: u16,
29+
#[serde(serialize_with = "serialize_headers")]
30+
headers: HeaderMap<HeaderValue>,
31+
#[serde(serialize_with = "serialize_headers_slice")]
32+
cookies: Vec<HeaderValue>,
2133
#[serde(skip_serializing_if = "Option::is_none")]
22-
pub status_description: Option<String>,
34+
body: Option<Body>,
35+
is_base64_encoded: bool,
36+
}
37+
38+
/// Representation of ALB lambda response
39+
#[doc(hidden)]
40+
#[derive(Serialize, Debug)]
41+
#[serde(rename_all = "camelCase")]
42+
pub struct AlbResponse {
43+
status_code: u16,
44+
status_description: String,
2345
#[serde(serialize_with = "serialize_headers")]
24-
pub headers: HeaderMap<HeaderValue>,
25-
#[serde(serialize_with = "serialize_multi_value_headers")]
26-
pub multi_value_headers: HeaderMap<HeaderValue>,
46+
headers: HeaderMap<HeaderValue>,
2747
#[serde(skip_serializing_if = "Option::is_none")]
28-
pub body: Option<Body>,
29-
// This field is optional for API Gateway but required for ALB
30-
pub is_base64_encoded: bool,
48+
body: Option<Body>,
49+
is_base64_encoded: bool,
3150
}
3251

33-
#[cfg(test)]
34-
impl Default for LambdaResponse {
35-
fn default() -> Self {
36-
Self {
37-
status_code: 200,
38-
status_description: Default::default(),
39-
headers: Default::default(),
40-
multi_value_headers: Default::default(),
41-
body: Default::default(),
42-
is_base64_encoded: Default::default(),
43-
}
44-
}
52+
/// Representation of API Gateway lambda response
53+
#[doc(hidden)]
54+
#[derive(Serialize, Debug)]
55+
#[serde(rename_all = "camelCase")]
56+
pub struct ApiGatewayResponse {
57+
status_code: u16,
58+
#[serde(serialize_with = "serialize_headers")]
59+
headers: HeaderMap<HeaderValue>,
60+
#[serde(serialize_with = "serialize_multi_value_headers")]
61+
multi_value_headers: HeaderMap<HeaderValue>,
62+
#[serde(skip_serializing_if = "Option::is_none")]
63+
body: Option<Body>,
64+
is_base64_encoded: bool,
4565
}
4666

4767
/// Serialize a http::HeaderMap into a serde str => str map
@@ -73,9 +93,21 @@ where
7393
map.end()
7494
}
7595

96+
/// Serialize a &[HeaderValue] into a Vec<str>
97+
fn serialize_headers_slice<S>(headers: &[HeaderValue], serializer: S) -> Result<S::Ok, S::Error>
98+
where
99+
S: Serializer,
100+
{
101+
let mut seq = serializer.serialize_seq(Some(headers.len()))?;
102+
for header in headers {
103+
seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?;
104+
}
105+
seq.end()
106+
}
107+
76108
/// tranformation from http type to internal type
77109
impl LambdaResponse {
78-
pub(crate) fn from_response<T>(is_alb: bool, value: Response<T>) -> Self
110+
pub(crate) fn from_response<T>(request_origin: &RequestOrigin, value: Response<T>) -> Self
79111
where
80112
T: Into<Body>,
81113
{
@@ -85,21 +117,43 @@ impl LambdaResponse {
85117
b @ Body::Text(_) => (false, Some(b)),
86118
b @ Body::Binary(_) => (true, Some(b)),
87119
};
88-
Self {
89-
status_code: parts.status.as_u16(),
90-
status_description: if is_alb {
91-
Some(format!(
120+
121+
let mut headers = parts.headers;
122+
let status_code = parts.status.as_u16();
123+
124+
match request_origin {
125+
RequestOrigin::ApiGatewayV2 => {
126+
// ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
127+
// so remove them from the headers.
128+
let cookies: Vec<HeaderValue> = headers.get_all(SET_COOKIE).iter().cloned().collect();
129+
headers.remove(SET_COOKIE);
130+
131+
LambdaResponse::ApiGatewayV2(ApiGatewayV2Response {
132+
body,
133+
status_code,
134+
is_base64_encoded,
135+
cookies,
136+
headers,
137+
})
138+
}
139+
RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse {
140+
body,
141+
status_code,
142+
is_base64_encoded,
143+
headers: headers.clone(),
144+
multi_value_headers: headers,
145+
}),
146+
RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse {
147+
body,
148+
status_code,
149+
is_base64_encoded,
150+
headers,
151+
status_description: format!(
92152
"{} {}",
93-
parts.status.as_u16(),
153+
status_code,
94154
parts.status.canonical_reason().unwrap_or_default()
95-
))
96-
} else {
97-
None
98-
},
99-
body,
100-
headers: parts.headers.clone(),
101-
multi_value_headers: parts.headers,
102-
is_base64_encoded,
155+
),
156+
}),
103157
}
104158
}
105159
}
@@ -159,10 +213,42 @@ impl IntoResponse for serde_json::Value {
159213

160214
#[cfg(test)]
161215
mod tests {
162-
use super::{Body, IntoResponse, LambdaResponse};
216+
use super::{
217+
AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin,
218+
};
163219
use http::{header::CONTENT_TYPE, Response};
164220
use serde_json::{self, json};
165221

222+
fn api_gateway_response() -> ApiGatewayResponse {
223+
ApiGatewayResponse {
224+
status_code: 200,
225+
headers: Default::default(),
226+
multi_value_headers: Default::default(),
227+
body: Default::default(),
228+
is_base64_encoded: Default::default(),
229+
}
230+
}
231+
232+
fn alb_response() -> AlbResponse {
233+
AlbResponse {
234+
status_code: 200,
235+
status_description: "200 OK".to_string(),
236+
headers: Default::default(),
237+
body: Default::default(),
238+
is_base64_encoded: Default::default(),
239+
}
240+
}
241+
242+
fn api_gateway_v2_response() -> ApiGatewayV2Response {
243+
ApiGatewayV2Response {
244+
status_code: 200,
245+
headers: Default::default(),
246+
body: Default::default(),
247+
cookies: Default::default(),
248+
is_base64_encoded: Default::default(),
249+
}
250+
}
251+
166252
#[test]
167253
fn json_into_response() {
168254
let response = json!({ "hello": "lambda"}).into_response();
@@ -189,32 +275,39 @@ mod tests {
189275
}
190276

191277
#[test]
192-
fn default_response() {
193-
assert_eq!(LambdaResponse::default().status_code, 200)
278+
fn serialize_body_for_api_gateway() {
279+
let mut resp = api_gateway_response();
280+
resp.body = Some("foo".into());
281+
assert_eq!(
282+
serde_json::to_string(&resp).expect("failed to serialize response"),
283+
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
284+
);
194285
}
195286

196287
#[test]
197-
fn serialize_default() {
288+
fn serialize_body_for_alb() {
289+
let mut resp = alb_response();
290+
resp.body = Some("foo".into());
198291
assert_eq!(
199-
serde_json::to_string(&LambdaResponse::default()).expect("failed to serialize response"),
200-
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"isBase64Encoded":false}"#
292+
serde_json::to_string(&resp).expect("failed to serialize response"),
293+
r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"#
201294
);
202295
}
203296

204297
#[test]
205-
fn serialize_body() {
206-
let mut resp = LambdaResponse::default();
298+
fn serialize_body_for_api_gateway_v2() {
299+
let mut resp = api_gateway_v2_response();
207300
resp.body = Some("foo".into());
208301
assert_eq!(
209302
serde_json::to_string(&resp).expect("failed to serialize response"),
210-
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
303+
r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"#
211304
);
212305
}
213306

214307
#[test]
215308
fn serialize_multi_value_headers() {
216309
let res = LambdaResponse::from_response(
217-
false,
310+
&RequestOrigin::ApiGateway,
218311
Response::builder()
219312
.header("multi", "a")
220313
.header("multi", "b")
@@ -227,4 +320,21 @@ mod tests {
227320
r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
228321
)
229322
}
323+
324+
#[test]
325+
fn serialize_cookies() {
326+
let res = LambdaResponse::from_response(
327+
&RequestOrigin::ApiGatewayV2,
328+
Response::builder()
329+
.header("set-cookie", "cookie1=a")
330+
.header("set-cookie", "cookie2=b")
331+
.body(Body::from(()))
332+
.expect("failed to create response"),
333+
);
334+
let json = serde_json::to_string(&res).expect("failed to serialize to json");
335+
assert_eq!(
336+
json,
337+
r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"#
338+
)
339+
}
230340
}

0 commit comments

Comments
 (0)