|
| 1 | +//! Lambda integration. |
| 2 | +
|
| 3 | +use std::convert::Infallible; |
| 4 | +use std::future::Future; |
| 5 | +use std::pin::Pin; |
| 6 | +use std::task::{Context, Poll}; |
| 7 | + |
| 8 | +use aws_lambda_events::encodings::Base64Data; |
| 9 | +use bytes::Bytes; |
| 10 | +use futures::Stream; |
| 11 | +use http::header::CONTENT_TYPE; |
| 12 | +use http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Uri}; |
| 13 | +use http_body_util::{BodyExt, Full}; |
| 14 | +use lambda_runtime::service_fn; |
| 15 | +use lambda_runtime::tower::ServiceExt; |
| 16 | +use lambda_runtime::{FunctionResponse, LambdaEvent}; |
| 17 | +use serde::{Deserialize, Serialize}; |
| 18 | +use tracing::debug; |
| 19 | + |
| 20 | +use crate::endpoint::{Endpoint, Error, HandleOptions, ProtocolMode}; |
| 21 | + |
| 22 | +#[allow(clippy::declare_interior_mutable_const)] |
| 23 | +const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server"); |
| 24 | +const X_RESTATE_SERVER_VALUE: HeaderValue = |
| 25 | + HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION"))); |
| 26 | + |
| 27 | +/// Represents an incoming request from AWS Lambda when using Lambda Function URLs. |
| 28 | +/// |
| 29 | +/// This struct is used to deserialize the JSON payload from Lambda. |
| 30 | +#[doc(hidden)] |
| 31 | +#[derive(Clone, Debug, Default, Deserialize, PartialEq)] |
| 32 | +#[serde(rename_all = "camelCase")] |
| 33 | +pub struct LambdaRequest { |
| 34 | + /// The HTTP method of the request. |
| 35 | + // #[serde(with = "http_method")] |
| 36 | + #[serde(with = "http_serde::method")] |
| 37 | + pub http_method: Method, |
| 38 | + /// The path of the request. |
| 39 | + #[serde(default)] |
| 40 | + #[serde(with = "http_serde::uri")] |
| 41 | + pub path: Uri, |
| 42 | + /// The headers of the request. |
| 43 | + #[serde(with = "http_serde::header_map", default)] |
| 44 | + pub headers: HeaderMap, |
| 45 | + /// Whether the request body is Base64 encoded. |
| 46 | + pub is_base64_encoded: bool, |
| 47 | + /// The request body, if any. |
| 48 | + pub body: Option<Base64Data>, |
| 49 | +} |
| 50 | + |
| 51 | +/// Represents a response to be sent back to AWS Lambda. |
| 52 | +/// |
| 53 | +/// This struct is serialized to JSON to form the response payload for Lambda. |
| 54 | +#[doc(hidden)] |
| 55 | +#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize)] |
| 56 | +#[serde(rename_all = "camelCase")] |
| 57 | +pub struct LambdaResponse { |
| 58 | + /// The HTTP status code. |
| 59 | + pub status_code: u16, |
| 60 | + /// An optional status description. |
| 61 | + #[serde(default)] |
| 62 | + pub status_description: Option<String>, |
| 63 | + /// The response headers. |
| 64 | + #[serde(with = "http_serde::header_map", default)] |
| 65 | + pub headers: HeaderMap, |
| 66 | + /// The optional response body, Base64 encoded. |
| 67 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 68 | + pub body: Option<Base64Data>, |
| 69 | + /// Whether the response body is Base64 encoded. This should generally be `true` |
| 70 | + /// when a body is present. |
| 71 | + #[serde(default)] |
| 72 | + pub is_base64_encoded: bool, |
| 73 | +} |
| 74 | + |
| 75 | +impl LambdaResponse { |
| 76 | + fn builder() -> LambdaResponseBuilder { |
| 77 | + LambdaResponseBuilder { |
| 78 | + status_code: 200, |
| 79 | + status_description: None, |
| 80 | + headers: HeaderMap::default(), |
| 81 | + body: None, |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + fn from_message<M: ToString>(code: u16, message: M) -> Self { |
| 86 | + Self::builder() |
| 87 | + .status_code(code) |
| 88 | + .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE) |
| 89 | + .header(CONTENT_TYPE, "text/plain".parse().unwrap()) |
| 90 | + .body(Bytes::from(message.to_string())) |
| 91 | + .build() |
| 92 | + } |
| 93 | +} |
| 94 | + |
| 95 | +impl From<LambdaResponse> for FunctionResponse<LambdaResponse, ClosedStream> { |
| 96 | + fn from(response: LambdaResponse) -> Self { |
| 97 | + FunctionResponse::BufferedResponse(response) |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +struct LambdaResponseBuilder { |
| 102 | + status_code: u16, |
| 103 | + status_description: Option<String>, |
| 104 | + headers: HeaderMap, |
| 105 | + body: Option<Base64Data>, |
| 106 | +} |
| 107 | + |
| 108 | +impl LambdaResponseBuilder { |
| 109 | + pub fn status_code(mut self, status_code: u16) -> Self { |
| 110 | + self.status_code = status_code; |
| 111 | + self.status_description = http::StatusCode::from_u16(status_code) |
| 112 | + .map(|s| s.to_string()) |
| 113 | + .ok(); |
| 114 | + self |
| 115 | + } |
| 116 | + |
| 117 | + pub fn header(mut self, key: HeaderName, value: HeaderValue) -> Self { |
| 118 | + self.headers.insert(key, value.into()); |
| 119 | + self |
| 120 | + } |
| 121 | + |
| 122 | + pub fn body(mut self, body: Bytes) -> Self { |
| 123 | + self.body = Some(Base64Data(body.into())); |
| 124 | + self |
| 125 | + } |
| 126 | + |
| 127 | + pub fn build(self) -> LambdaResponse { |
| 128 | + LambdaResponse { |
| 129 | + status_code: self.status_code, |
| 130 | + status_description: self.status_description, |
| 131 | + headers: self.headers, |
| 132 | + body: self.body, |
| 133 | + is_base64_encoded: true, |
| 134 | + } |
| 135 | + } |
| 136 | +} |
| 137 | + |
| 138 | +impl From<Error> for LambdaResponse { |
| 139 | + fn from(err: Error) -> Self { |
| 140 | + LambdaResponse::from_message(err.status_code(), err.to_string()) |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +/// A [`Stream`] that is immediately closed. |
| 145 | +/// |
| 146 | +/// This is used as a placeholder body for buffered responses in the Lambda integration, |
| 147 | +/// where the entire response is sent at once and no streaming body is needed. |
| 148 | +#[doc(hidden)] |
| 149 | +pub struct ClosedStream; |
| 150 | +impl Stream for ClosedStream { |
| 151 | + type Item = Result<Bytes, Infallible>; |
| 152 | + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 153 | + Poll::Ready(None) |
| 154 | + } |
| 155 | +} |
| 156 | + |
| 157 | +/// Wraps an [`Endpoint`] to implement the `lambda_runtime::Service` trait for AWS Lambda. |
| 158 | +/// |
| 159 | +/// This adapter allows a Restate endpoint to be deployed as an AWS Lambda function. |
| 160 | +/// It handles the conversion between Lambda's request/response format and the |
| 161 | +/// internal representation used by the SDK. |
| 162 | +#[derive(Clone)] |
| 163 | +pub struct LambdaEndpoint(Endpoint); |
| 164 | + |
| 165 | +impl LambdaEndpoint { |
| 166 | + pub fn new(endpoint: Endpoint) -> Self { |
| 167 | + Self(endpoint) |
| 168 | + } |
| 169 | + |
| 170 | + /// Runs the Lambda service. |
| 171 | + /// |
| 172 | + /// This function starts the `lambda_runtime` and begins processing incoming |
| 173 | + /// Lambda events, passing them to the wrapped [`Endpoint`]. |
| 174 | + pub fn run(self) -> impl Future<Output = Result<(), lambda_runtime::Error>> { |
| 175 | + let svc = service_fn(handle); |
| 176 | + let svc = svc.map_request(move |req| { |
| 177 | + let endpoint = self.0.clone(); |
| 178 | + LambdaEventWithEndpoint { |
| 179 | + inner: req, |
| 180 | + endpoint, |
| 181 | + } |
| 182 | + }); |
| 183 | + |
| 184 | + lambda_runtime::run(svc) |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +struct LambdaEventWithEndpoint { |
| 189 | + inner: LambdaEvent<LambdaRequest>, |
| 190 | + endpoint: Endpoint, |
| 191 | +} |
| 192 | + |
| 193 | +async fn handle(req: LambdaEventWithEndpoint) -> Result<LambdaResponse, Infallible> { |
| 194 | + let (request, _) = req.inner.into_parts(); |
| 195 | + |
| 196 | + let mut http_request = Request::builder() |
| 197 | + .method(request.http_method) |
| 198 | + .uri(request.path) |
| 199 | + .body(request.body.map(|b| Full::from(b.0)).unwrap_or_default()) |
| 200 | + .expect("to build"); |
| 201 | + |
| 202 | + http_request.headers_mut().extend(request.headers); |
| 203 | + |
| 204 | + let response = match req.endpoint.handle_with_options( |
| 205 | + http_request, |
| 206 | + HandleOptions { |
| 207 | + protocol_mode: ProtocolMode::RequestResponse, |
| 208 | + }, |
| 209 | + ) { |
| 210 | + Ok(res) => res, |
| 211 | + Err(err) => { |
| 212 | + debug!("Error when trying to handle incoming request: {err}"); |
| 213 | + return Ok(err.into()); |
| 214 | + } |
| 215 | + }; |
| 216 | + |
| 217 | + let (parts, body) = response.into_parts(); |
| 218 | + // collect the response |
| 219 | + let body = match body.collect().await { |
| 220 | + Ok(body) => body.to_bytes(), |
| 221 | + Err(err) => { |
| 222 | + debug!("Error when trying to collect response body: {err}"); |
| 223 | + return Ok(LambdaResponse::from_message(500, err)); |
| 224 | + } |
| 225 | + }; |
| 226 | + |
| 227 | + let mut builder = LambdaResponse::builder() |
| 228 | + .status_code(parts.status.as_u16()) |
| 229 | + .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE); |
| 230 | + |
| 231 | + builder.headers.extend(parts.headers); |
| 232 | + |
| 233 | + Ok(builder.body(body).build()) |
| 234 | +} |
0 commit comments