Skip to content

Commit 374ed81

Browse files
committed
Lambda-http: vary type of response based on request origin
ApiGatewayV2, ApiGateway and Alb all expect different types of responses to be returned from the invoked lambda function. Thus, it makes sense to pass the request origin to the creation of the response, so that the correct type of LambdaResponse is returned from the function. This commit also adds support for the "cookies" attribute which can be used for returning multiple Set-cookie headers from a lambda invoked via ApiGatewayV2, since ApiGatewayV2 no longer seems to recognize the "multiValueHeaders" attribute. Closes: awslabs#267.
1 parent 13aa8f0 commit 374ed81

File tree

3 files changed

+192
-65
lines changed

3 files changed

+192
-65
lines changed

lambda-http/src/lib.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ pub mod request;
100100
mod response;
101101
mod strmap;
102102
pub use crate::{body::Body, ext::RequestExt, response::IntoResponse, strmap::StrMap};
103-
use crate::{request::LambdaRequest, response::LambdaResponse};
103+
use crate::{
104+
request::{LambdaRequest, RequestOrigin},
105+
response::LambdaResponse,
106+
};
104107
use std::{
105108
future::Future,
106109
pin::Pin,
@@ -149,7 +152,7 @@ where
149152

150153
#[doc(hidden)]
151154
pub struct TransformResponse<R, E> {
152-
is_alb: bool,
155+
request_origin: RequestOrigin,
153156
fut: Pin<Box<dyn Future<Output = Result<R, E>>>>,
154157
}
155158

@@ -160,9 +163,9 @@ where
160163
type Output = Result<LambdaResponse, E>;
161164
fn poll(mut self: Pin<&mut Self>, cx: &mut TaskContext) -> Poll<Self::Output> {
162165
match self.fut.as_mut().poll(cx) {
163-
Poll::Ready(result) => {
164-
Poll::Ready(result.map(|resp| LambdaResponse::from_response(self.is_alb, resp.into_response())))
165-
}
166+
Poll::Ready(result) => Poll::Ready(
167+
result.map(|resp| LambdaResponse::from_response(&self.request_origin, resp.into_response())),
168+
),
166169
Poll::Pending => Poll::Pending,
167170
}
168171
}
@@ -192,8 +195,8 @@ impl<H: Handler> LambdaHandler<LambdaRequest<'_>, LambdaResponse> for Adapter<H>
192195
type Error = H::Error;
193196
type Fut = TransformResponse<H::Response, Self::Error>;
194197
fn call(&mut self, event: LambdaRequest<'_>, context: Context) -> Self::Fut {
195-
let is_alb = event.is_alb();
198+
let request_origin = event.request_origin();
196199
let fut = Box::pin(self.handler.call(event.into(), context));
197-
TransformResponse { is_alb, fut }
200+
TransformResponse { request_origin, fut }
198201
}
199202
}

lambda-http/src/request.rs

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,30 @@ pub enum LambdaRequest<'a> {
9090
}
9191

9292
impl LambdaRequest<'_> {
93-
/// Return true if this request represents an ALB event
94-
///
95-
/// Alb responses have unique requirements for responses that
96-
/// vary only slightly from APIGateway responses. We serialize
97-
/// responses capturing a hint that the request was an alb triggered
98-
/// event.
99-
pub fn is_alb(&self) -> bool {
100-
matches!(self, LambdaRequest::Alb { .. })
93+
/// Return the `RequestOrigin` of the request to determine where the `LambdaRequest`
94+
/// originated from, so that the appropriate response can be selected based on what
95+
/// type of response the request origin expects.
96+
pub fn request_origin(&self) -> RequestOrigin {
97+
match self {
98+
LambdaRequest::ApiGatewayV2 { .. } => RequestOrigin::ApiGatewayV2,
99+
LambdaRequest::Alb { .. } => RequestOrigin::Alb,
100+
LambdaRequest::ApiGateway { .. } => RequestOrigin::ApiGateway,
101+
}
101102
}
102103
}
103104

105+
/// Represents the origin from which the lambda was requested from.
106+
#[doc(hidden)]
107+
#[derive(Debug)]
108+
pub enum RequestOrigin {
109+
/// API Gateway v2 request origin
110+
ApiGatewayV2,
111+
/// API Gateway request origin
112+
ApiGateway,
113+
/// ALB request origin
114+
Alb,
115+
}
116+
104117
#[derive(Deserialize, Debug, Clone)]
105118
#[serde(rename_all = "camelCase")]
106119
pub struct ApiGatewayV2RequestContext {

lambda-http/src/response.rs

Lines changed: 161 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,70 @@
11
//! Response types
22
33
use http::{
4-
header::{HeaderMap, HeaderValue, CONTENT_TYPE},
4+
header::{HeaderMap, HeaderValue, CONTENT_TYPE, SET_COOKIE},
55
Response,
66
};
77
use serde::{
8-
ser::{Error as SerError, SerializeMap},
8+
ser::{Error as SerError, SerializeMap, SerializeSeq},
99
Serializer,
1010
};
1111
use serde_derive::Serialize;
1212

1313
use crate::body::Body;
14+
use crate::request::RequestOrigin;
1415

15-
/// Representation of API Gateway response
16+
/// Representation of Lambda response
17+
#[doc(hidden)]
18+
#[derive(Serialize, Debug)]
19+
#[serde(untagged)]
20+
pub enum LambdaResponse {
21+
ApiGatewayV2(ApiGatewayV2Response),
22+
Alb(AlbResponse),
23+
ApiGateway(ApiGatewayResponse),
24+
}
25+
26+
/// Representation of API Gateway v2 lambda response
1627
#[doc(hidden)]
1728
#[derive(Serialize, Debug)]
1829
#[serde(rename_all = "camelCase")]
19-
pub struct LambdaResponse {
20-
pub status_code: u16,
21-
// ALB requires a statusDescription i.e. "200 OK" field but API Gateway returns an error
22-
// when one is provided. only populate this for ALB responses
30+
pub struct ApiGatewayV2Response {
31+
status_code: u16,
32+
#[serde(serialize_with = "serialize_headers")]
33+
headers: HeaderMap<HeaderValue>,
34+
#[serde(serialize_with = "serialize_headers_slice")]
35+
cookies: Vec<HeaderValue>,
2336
#[serde(skip_serializing_if = "Option::is_none")]
24-
pub status_description: Option<String>,
37+
body: Option<Body>,
38+
is_base64_encoded: bool,
39+
}
40+
41+
/// Representation of ALB lambda response
42+
#[doc(hidden)]
43+
#[derive(Serialize, Debug)]
44+
#[serde(rename_all = "camelCase")]
45+
pub struct AlbResponse {
46+
status_code: u16,
47+
status_description: String,
2548
#[serde(serialize_with = "serialize_headers")]
26-
pub headers: HeaderMap<HeaderValue>,
27-
#[serde(serialize_with = "serialize_multi_value_headers")]
28-
pub multi_value_headers: HeaderMap<HeaderValue>,
49+
headers: HeaderMap<HeaderValue>,
2950
#[serde(skip_serializing_if = "Option::is_none")]
30-
pub body: Option<Body>,
31-
// This field is optional for API Gateway but required for ALB
32-
pub is_base64_encoded: bool,
51+
body: Option<Body>,
52+
is_base64_encoded: bool,
3353
}
3454

35-
#[cfg(test)]
36-
impl Default for LambdaResponse {
37-
fn default() -> Self {
38-
Self {
39-
status_code: 200,
40-
status_description: Default::default(),
41-
headers: Default::default(),
42-
multi_value_headers: Default::default(),
43-
body: Default::default(),
44-
is_base64_encoded: Default::default(),
45-
}
46-
}
55+
/// Representation of API Gateway lambda response
56+
#[doc(hidden)]
57+
#[derive(Serialize, Debug)]
58+
#[serde(rename_all = "camelCase")]
59+
pub struct ApiGatewayResponse {
60+
status_code: u16,
61+
#[serde(serialize_with = "serialize_headers")]
62+
headers: HeaderMap<HeaderValue>,
63+
#[serde(serialize_with = "serialize_multi_value_headers")]
64+
multi_value_headers: HeaderMap<HeaderValue>,
65+
#[serde(skip_serializing_if = "Option::is_none")]
66+
body: Option<Body>,
67+
is_base64_encoded: bool,
4768
}
4869

4970
/// Serialize a http::HeaderMap into a serde str => str map
@@ -75,9 +96,21 @@ where
7596
map.end()
7697
}
7798

99+
/// Serialize a &[HeaderValue] into a Vec<str>
100+
fn serialize_headers_slice<S>(headers: &[HeaderValue], serializer: S) -> Result<S::Ok, S::Error>
101+
where
102+
S: Serializer,
103+
{
104+
let mut seq = serializer.serialize_seq(Some(headers.len()))?;
105+
for header in headers {
106+
seq.serialize_element(header.to_str().map_err(S::Error::custom)?)?;
107+
}
108+
seq.end()
109+
}
110+
78111
/// tranformation from http type to internal type
79112
impl LambdaResponse {
80-
pub(crate) fn from_response<T>(is_alb: bool, value: Response<T>) -> Self
113+
pub(crate) fn from_response<T>(request_origin: &RequestOrigin, value: Response<T>) -> Self
81114
where
82115
T: Into<Body>,
83116
{
@@ -87,21 +120,43 @@ impl LambdaResponse {
87120
b @ Body::Text(_) => (false, Some(b)),
88121
b @ Body::Binary(_) => (true, Some(b)),
89122
};
90-
Self {
91-
status_code: parts.status.as_u16(),
92-
status_description: if is_alb {
93-
Some(format!(
123+
124+
let mut headers = parts.headers;
125+
let status_code = parts.status.as_u16();
126+
127+
match request_origin {
128+
RequestOrigin::ApiGatewayV2 => {
129+
// ApiGatewayV2 expects the set-cookies headers to be in the "cookies" attribute,
130+
// so remove them from the headers.
131+
let cookies: Vec<HeaderValue> = headers.get_all(SET_COOKIE).iter().cloned().collect();
132+
headers.remove(SET_COOKIE);
133+
134+
LambdaResponse::ApiGatewayV2(ApiGatewayV2Response {
135+
body,
136+
status_code,
137+
is_base64_encoded,
138+
cookies,
139+
headers,
140+
})
141+
}
142+
RequestOrigin::ApiGateway => LambdaResponse::ApiGateway(ApiGatewayResponse {
143+
body,
144+
status_code,
145+
is_base64_encoded,
146+
headers: headers.clone(),
147+
multi_value_headers: headers,
148+
}),
149+
RequestOrigin::Alb => LambdaResponse::Alb(AlbResponse {
150+
body,
151+
status_code,
152+
is_base64_encoded,
153+
headers,
154+
status_description: format!(
94155
"{} {}",
95-
parts.status.as_u16(),
156+
status_code,
96157
parts.status.canonical_reason().unwrap_or_default()
97-
))
98-
} else {
99-
None
100-
},
101-
body,
102-
headers: parts.headers.clone(),
103-
multi_value_headers: parts.headers,
104-
is_base64_encoded,
158+
),
159+
}),
105160
}
106161
}
107162
}
@@ -161,10 +216,42 @@ impl IntoResponse for serde_json::Value {
161216

162217
#[cfg(test)]
163218
mod tests {
164-
use super::{Body, IntoResponse, LambdaResponse};
219+
use super::{
220+
AlbResponse, ApiGatewayResponse, ApiGatewayV2Response, Body, IntoResponse, LambdaResponse, RequestOrigin,
221+
};
165222
use http::{header::CONTENT_TYPE, Response};
166223
use serde_json::{self, json};
167224

225+
fn api_gateway_response() -> ApiGatewayResponse {
226+
ApiGatewayResponse {
227+
status_code: 200,
228+
headers: Default::default(),
229+
multi_value_headers: Default::default(),
230+
body: Default::default(),
231+
is_base64_encoded: Default::default(),
232+
}
233+
}
234+
235+
fn alb_response() -> AlbResponse {
236+
AlbResponse {
237+
status_code: 200,
238+
status_description: "200 OK".to_string(),
239+
headers: Default::default(),
240+
body: Default::default(),
241+
is_base64_encoded: Default::default(),
242+
}
243+
}
244+
245+
fn api_gateway_v2_response() -> ApiGatewayV2Response {
246+
ApiGatewayV2Response {
247+
status_code: 200,
248+
headers: Default::default(),
249+
body: Default::default(),
250+
cookies: Default::default(),
251+
is_base64_encoded: Default::default(),
252+
}
253+
}
254+
168255
#[test]
169256
fn json_into_response() {
170257
let response = json!({ "hello": "lambda"}).into_response();
@@ -191,32 +278,39 @@ mod tests {
191278
}
192279

193280
#[test]
194-
fn default_response() {
195-
assert_eq!(LambdaResponse::default().status_code, 200)
281+
fn serialize_body_for_api_gateway() {
282+
let mut resp = api_gateway_response();
283+
resp.body = Some("foo".into());
284+
assert_eq!(
285+
serde_json::to_string(&resp).expect("failed to serialize response"),
286+
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
287+
);
196288
}
197289

198290
#[test]
199-
fn serialize_default() {
291+
fn serialize_body_for_alb() {
292+
let mut resp = alb_response();
293+
resp.body = Some("foo".into());
200294
assert_eq!(
201-
serde_json::to_string(&LambdaResponse::default()).expect("failed to serialize response"),
202-
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"isBase64Encoded":false}"#
295+
serde_json::to_string(&resp).expect("failed to serialize response"),
296+
r#"{"statusCode":200,"statusDescription":"200 OK","headers":{},"body":"foo","isBase64Encoded":false}"#
203297
);
204298
}
205299

206300
#[test]
207-
fn serialize_body() {
208-
let mut resp = LambdaResponse::default();
301+
fn serialize_body_for_api_gateway_v2() {
302+
let mut resp = api_gateway_v2_response();
209303
resp.body = Some("foo".into());
210304
assert_eq!(
211305
serde_json::to_string(&resp).expect("failed to serialize response"),
212-
r#"{"statusCode":200,"headers":{},"multiValueHeaders":{},"body":"foo","isBase64Encoded":false}"#
306+
r#"{"statusCode":200,"headers":{},"cookies":[],"body":"foo","isBase64Encoded":false}"#
213307
);
214308
}
215309

216310
#[test]
217311
fn serialize_multi_value_headers() {
218312
let res = LambdaResponse::from_response(
219-
false,
313+
&RequestOrigin::ApiGateway,
220314
Response::builder()
221315
.header("multi", "a")
222316
.header("multi", "b")
@@ -229,4 +323,21 @@ mod tests {
229323
r#"{"statusCode":200,"headers":{"multi":"a"},"multiValueHeaders":{"multi":["a","b"]},"isBase64Encoded":false}"#
230324
)
231325
}
326+
327+
#[test]
328+
fn serialize_cookies() {
329+
let res = LambdaResponse::from_response(
330+
&RequestOrigin::ApiGatewayV2,
331+
Response::builder()
332+
.header("set-cookie", "cookie1=a")
333+
.header("set-cookie", "cookie2=b")
334+
.body(Body::from(()))
335+
.expect("failed to create response"),
336+
);
337+
let json = serde_json::to_string(&res).expect("failed to serialize to json");
338+
assert_eq!(
339+
json,
340+
r#"{"statusCode":200,"headers":{},"cookies":["cookie1=a","cookie2=b"],"isBase64Encoded":false}"#
341+
)
342+
}
232343
}

0 commit comments

Comments
 (0)