diff --git a/Cargo.toml b/Cargo.toml index 291345b5..01bc46b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,7 @@ [workspace] members = [ "lambda-http", - "lambda-runtime" + "lambda-runtime-api-client", + "lambda-runtime", + "lambda-extension" ] \ No newline at end of file diff --git a/README.md b/README.md index a9616f1a..96f24476 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,9 @@ This package makes it easy to run AWS Lambda Functions written in Rust. This wor - [![Docs](https://docs.rs/lambda_runtime/badge.svg)](https://docs.rs/lambda_runtime) **`lambda-runtime`** is a library that provides a Lambda runtime for applications written in Rust. - [![Docs](https://docs.rs/lambda_http/badge.svg)](https://docs.rs/lambda_http) **`lambda-http`** is a library that makes it easy to write API Gateway proxy event focused Lambda functions in Rust. +- [![Docs](https://docs.rs/lambda_extension/badge.svg)](https://docs.rs/lambda_extension) **`lambda-extension`** is a library that makes it easy to write Lambda Runtime Extensions in Rust. +- [![Docs](https://docs.rs/lambda_runtime_api_client/badge.svg)](https://docs.rs/lambda_runtime_api_client) **`lambda-runtime-api-client`** is a shared library between the lambda runtime and lambda extension libraries that includes a common API client to talk with the AWS Lambda Runtime API. + ## Example function diff --git a/lambda-extension/Cargo.toml b/lambda-extension/Cargo.toml new file mode 100644 index 00000000..05f1d6f7 --- /dev/null +++ b/lambda-extension/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "lambda_extension" +version = "0.1.0" +edition = "2021" +authors = ["David Calavera "] +description = "AWS Lambda Extension API" +license = "Apache-2.0" +repository = "https://github.com/awslabs/aws-lambda-rust-runtime" +categories = ["web-programming::http-server"] +keywords = ["AWS", "Lambda", "API"] +readme = "README.md" + +[dependencies] +tokio = { version = "1.0", features = ["macros", "io-util", "sync", "rt-multi-thread"] } +hyper = { version = "0.14", features = ["http1", "client", "server", "stream", "runtime"] } +serde = { version = "1", features = ["derive"] } +serde_json = "^1" +bytes = "1.0" +http = "0.2" +async-stream = "0.3" +tracing = { version = "0.1", features = ["log"] } +tower-service = "0.3" +tokio-stream = "0.1.2" +lambda_runtime_api_client = { version = "0.4", path = "../lambda-runtime-api-client" } + +[dev-dependencies] +tracing-subscriber = "0.3" +once_cell = "1.4.0" +simple_logger = "1.6.0" +log = "^0.4" +simple-error = "0.2" diff --git a/lambda-extension/README.md b/lambda-extension/README.md new file mode 100644 index 00000000..4982779f --- /dev/null +++ b/lambda-extension/README.md @@ -0,0 +1,58 @@ +# Runtime Extensions for AWS Lambda in Rust + +[![Docs](https://docs.rs/lambda_extension/badge.svg)](https://docs.rs/lambda_extension) + +**`lambda-extension`** is a library that makes it easy to write [AWS Lambda Runtime Extensions](https://docs.aws.amazon.com/lambda/latest/dg/using-extensions.html) in Rust. + +## Example extension + +The code below creates a simple extension that's registered to every `INVOKE` and `SHUTDOWN` events, and logs them in CloudWatch. + +```rust,no_run +use lambda_extension::{extension_fn, Error, NextEvent}; +use log::LevelFilter; +use simple_logger::SimpleLogger; +use tracing::info; + +async fn log_extension(event: NextEvent) -> Result<(), Error> { + match event { + NextEvent::Shutdown(event) => { + info!("{}", event); + } + NextEvent::Invoke(event) => { + info!("{}", event); + } + } + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); + + let func = extension_fn(log_extension); + lambda_extension::run(func).await +} +``` + +## Deployment + +Lambda extensions can be added to your functions either using [Lambda layers](https://docs.aws.amazon.com/lambda/latest/dg/using-extensions.html#using-extensions-config), or adding them to [containers images](https://docs.aws.amazon.com/lambda/latest/dg/using-extensions.html#invocation-extensions-images). + +Regardless of how you deploy them, the extensions MUST be compiled against the same architecture that your lambda functions runs on. + +### Building extensions + +Once you've decided which target you'll use, you can install it by running the next `rustup` command: + +```bash +$ rustup target add x86_64-unknown-linux-musl +``` + +Then, you can compile the extension against that target: + +```bash +$ cargo build -p lambda_extension --example basic --release --target x86_64-unknown-linux-musl +``` + +This previous command will generate a binary file in `target/x86_64-unknown-linux-musl/release/examples` called `basic`. When the extension is registered with the [Runtime Extensions API](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-extensions-api.html#runtimes-extensions-api-reg), that's the name that the extension will be registered with. If you want to register the extension with a different name, you only have to rename this binary file and deploy it with the new name. \ No newline at end of file diff --git a/lambda-extension/examples/basic.rs b/lambda-extension/examples/basic.rs new file mode 100644 index 00000000..573b3281 --- /dev/null +++ b/lambda-extension/examples/basic.rs @@ -0,0 +1,25 @@ +use lambda_extension::{extension_fn, Error, NextEvent}; +use log::LevelFilter; +use simple_logger::SimpleLogger; + +async fn my_extension(event: NextEvent) -> Result<(), Error> { + match event { + NextEvent::Shutdown(_e) => { + // do something with the shutdown event + } + NextEvent::Invoke(_e) => { + // do something with the invoke event + } + } + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + // can be replaced with any other method of initializing `log` + SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); + + let func = extension_fn(my_extension); + lambda_extension::run(func).await +} diff --git a/lambda-extension/examples/custom_events.rs b/lambda-extension/examples/custom_events.rs new file mode 100644 index 00000000..88f040aa --- /dev/null +++ b/lambda-extension/examples/custom_events.rs @@ -0,0 +1,30 @@ +use lambda_extension::{extension_fn, Error, NextEvent, Runtime}; +use log::LevelFilter; +use simple_logger::SimpleLogger; + +async fn my_extension(event: NextEvent) -> Result<(), Error> { + match event { + NextEvent::Shutdown(_e) => { + // do something with the shutdown event + } + _ => { + // ignore any other event + // because we've registered the extension + // only to receive SHUTDOWN events + } + } + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + // can be replaced with any other method of initializing `log` + SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); + + let func = extension_fn(my_extension); + + let runtime = Runtime::builder().with_events(&["SHUTDOWN"]).register().await?; + + runtime.run(func).await +} diff --git a/lambda-extension/examples/custom_trait_implementation.rs b/lambda-extension/examples/custom_trait_implementation.rs new file mode 100644 index 00000000..caef7730 --- /dev/null +++ b/lambda-extension/examples/custom_trait_implementation.rs @@ -0,0 +1,36 @@ +use lambda_extension::{run, Error, Extension, InvokeEvent, NextEvent}; +use log::LevelFilter; +use simple_logger::SimpleLogger; +use std::{ + future::{ready, Future}, + pin::Pin, +}; + +#[derive(Default)] +struct MyExtension { + data: Vec, +} + +impl Extension for MyExtension { + type Fut = Pin>>>; + fn call(&mut self, event: NextEvent) -> Self::Fut { + match event { + NextEvent::Shutdown(_e) => { + self.data.clear(); + } + NextEvent::Invoke(e) => { + self.data.push(e); + } + } + Box::pin(ready(Ok(()))) + } +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + // required to enable CloudWatch error logging by the runtime + // can be replaced with any other method of initializing `log` + SimpleLogger::new().with_level(LevelFilter::Info).init().unwrap(); + + run(MyExtension::default()).await +} diff --git a/lambda-extension/src/lib.rs b/lambda-extension/src/lib.rs new file mode 100644 index 00000000..41f56890 --- /dev/null +++ b/lambda-extension/src/lib.rs @@ -0,0 +1,257 @@ +#![deny(clippy::all, clippy::cargo)] +#![warn(missing_docs, nonstandard_style, rust_2018_idioms)] + +//! This module includes utilities to create Lambda Runtime Extensions. +//! +//! Create a type that conforms to the [`Extension`] trait. This type can then be passed +//! to the the `lambda_extension::run` function, which launches and runs the Lambda runtime extension. +use hyper::client::{connect::Connection, HttpConnector}; +use lambda_runtime_api_client::Client; +use serde::Deserialize; +use std::future::Future; +use std::path::PathBuf; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_stream::StreamExt; +use tower_service::Service; +use tracing::trace; + +/// Include several request builders to interact with the Extension API. +pub mod requests; + +/// Error type that extensions may result in +pub type Error = lambda_runtime_api_client::Error; + +/// Simple error that encapsulates human readable descriptions +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ExtensionError { + err: String, +} + +impl ExtensionError { + fn boxed>(str: T) -> Box { + Box::new(ExtensionError { err: str.into() }) + } +} + +impl std::fmt::Display for ExtensionError { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.err.fmt(f) + } +} + +impl std::error::Error for ExtensionError {} + +/// Request tracing information +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Tracing { + /// The type of tracing exposed to the extension + pub r#type: String, + /// The span value + pub value: String, +} + +/// Event received when there is a new Lambda invocation. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InvokeEvent { + /// The time that the function times out + pub deadline_ms: u64, + /// The ID assigned to the Lambda request + pub request_id: String, + /// The function's Amazon Resource Name + pub invoked_function_arn: String, + /// The request tracing information + pub tracing: Tracing, +} + +/// Event received when a Lambda function shuts down. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ShutdownEvent { + /// The reason why the function terminates + /// It can be SPINDOWN, TIMEOUT, or FAILURE + pub shutdown_reason: String, + /// The time that the function times out + pub deadline_ms: u64, +} + +/// Event that the extension receives in +/// either the INVOKE or SHUTDOWN phase +#[derive(Debug, Deserialize)] +#[serde(rename_all = "UPPERCASE", tag = "eventType")] +pub enum NextEvent { + /// Payload when the event happens in the INVOKE phase + Invoke(InvokeEvent), + /// Payload when the event happens in the SHUTDOWN phase + Shutdown(ShutdownEvent), +} + +impl NextEvent { + fn is_invoke(&self) -> bool { + matches!(self, NextEvent::Invoke(_)) + } +} + +/// A trait describing an asynchronous extension. +pub trait Extension { + /// Response of this Extension. + type Fut: Future>; + /// Handle the incoming event. + fn call(&mut self, event: NextEvent) -> Self::Fut; +} + +/// Returns a new [`ExtensionFn`] with the given closure. +/// +/// [`ExtensionFn`]: struct.ExtensionFn.html +pub fn extension_fn(f: F) -> ExtensionFn { + ExtensionFn { f } +} + +/// An [`Extension`] implemented by a closure. +/// +/// [`Extension`]: trait.Extension.html +#[derive(Clone, Debug)] +pub struct ExtensionFn { + f: F, +} + +impl Extension for ExtensionFn +where + F: Fn(NextEvent) -> Fut, + Fut: Future>, +{ + type Fut = Fut; + fn call(&mut self, event: NextEvent) -> Self::Fut { + (self.f)(event) + } +} + +/// The Runtime handles all the incoming extension requests +pub struct Runtime = HttpConnector> { + extension_id: String, + client: Client, +} + +impl Runtime { + /// Create a [`RuntimeBuilder`] to initialize the [`Runtime`] + pub fn builder<'a>() -> RuntimeBuilder<'a> { + RuntimeBuilder::default() + } +} + +impl Runtime +where + C: Service + Clone + Send + Sync + Unpin + 'static, + >::Future: Unpin + Send, + >::Error: Into>, + >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, +{ + /// Execute the given extension. + /// Register the extension with the Extensions API and wait for incoming events. + pub async fn run(&self, mut extension: impl Extension) -> Result<(), Error> { + let client = &self.client; + + let incoming = async_stream::stream! { + loop { + trace!("Waiting for next event (incoming loop)"); + let req = requests::next_event_request(&self.extension_id)?; + let res = client.call(req).await; + yield res; + } + }; + + tokio::pin!(incoming); + while let Some(event) = incoming.next().await { + trace!("New event arrived (run loop)"); + let event = event?; + let (_parts, body) = event.into_parts(); + + let body = hyper::body::to_bytes(body).await?; + trace!("{}", std::str::from_utf8(&body)?); // this may be very verbose + let event: NextEvent = serde_json::from_slice(&body)?; + let is_invoke = event.is_invoke(); + + let res = extension.call(event).await; + if let Err(error) = res { + let req = if is_invoke { + requests::init_error(&self.extension_id, &error.to_string(), None)? + } else { + requests::exit_error(&self.extension_id, &error.to_string(), None)? + }; + + self.client.call(req).await?; + return Err(error); + } + } + + Ok(()) + } +} + +/// Builder to construct a new extension [`Runtime`] +#[derive(Default)] +pub struct RuntimeBuilder<'a> { + extension_name: Option<&'a str>, + events: Option<&'a [&'a str]>, +} + +impl<'a> RuntimeBuilder<'a> { + /// Create a new [`RuntimeBuilder`] with a given extension name + pub fn with_extension_name(self, extension_name: &'a str) -> Self { + RuntimeBuilder { + extension_name: Some(extension_name), + ..self + } + } + + /// Create a new [`RuntimeBuilder`] with a list of given events. + /// The only accepted events are `INVOKE` and `SHUTDOWN`. + pub fn with_events(self, events: &'a [&'a str]) -> Self { + RuntimeBuilder { + events: Some(events), + ..self + } + } + + /// Initialize and register the extension in the Extensions API + pub async fn register(&self) -> Result { + let name = match self.extension_name { + Some(name) => name.into(), + None => { + let args: Vec = std::env::args().collect(); + PathBuf::from(args[0].clone()) + .file_name() + .expect("unexpected executable name") + .to_str() + .expect("unexpect executable name") + .to_string() + } + }; + + let events = self.events.unwrap_or(&["INVOKE", "SHUTDOWN"]); + + let client = Client::builder().build()?; + + let req = requests::register_request(&name, events)?; + let res = client.call(req).await?; + if res.status() != http::StatusCode::OK { + return Err(ExtensionError::boxed("unable to register the extension")); + } + + let extension_id = res.headers().get(requests::EXTENSION_ID_HEADER).unwrap().to_str()?; + Ok(Runtime { + extension_id: extension_id.into(), + client, + }) + } +} + +/// Execute the given extension +pub async fn run(extension: Ex) -> Result<(), Error> +where + Ex: Extension, +{ + Runtime::builder().register().await?.run(extension).await +} diff --git a/lambda-extension/src/requests.rs b/lambda-extension/src/requests.rs new file mode 100644 index 00000000..2fdbf2a6 --- /dev/null +++ b/lambda-extension/src/requests.rs @@ -0,0 +1,82 @@ +use crate::Error; +use http::{Method, Request}; +use hyper::Body; +use lambda_runtime_api_client::build_request; +use serde::Serialize; + +const EXTENSION_NAME_HEADER: &str = "Lambda-Extension-Name"; +pub(crate) const EXTENSION_ID_HEADER: &str = "Lambda-Extension-Identifier"; +const EXTENSION_ERROR_TYPE_HEADER: &str = "Lambda-Extension-Function-Error-Type"; + +pub(crate) fn next_event_request(extension_id: &str) -> Result, Error> { + let req = build_request() + .method(Method::GET) + .header(EXTENSION_ID_HEADER, extension_id) + .uri("/2020-01-01/extension/event/next") + .body(Body::empty())?; + Ok(req) +} + +pub(crate) fn register_request(extension_name: &str, events: &[&str]) -> Result, Error> { + let events = serde_json::json!({ "events": events }); + + let req = build_request() + .method(Method::POST) + .uri("/2020-01-01/extension/register") + .header(EXTENSION_NAME_HEADER, extension_name) + .body(Body::from(serde_json::to_string(&events)?))?; + + Ok(req) +} + +/// Payload to send error information to the Extensions API. +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ErrorRequest<'a> { + /// Human readable error description + pub error_message: &'a str, + /// The type of error to categorize + pub error_type: &'a str, + /// The error backtrace + pub stack_trace: Vec<&'a str>, +} + +/// Create a new init error request to send to the Extensions API +pub fn init_error<'a>( + extension_id: &str, + error_type: &str, + request: Option>, +) -> Result, Error> { + error_request("init", extension_id, error_type, request) +} + +/// Create a new exit error request to send to the Extensions API +pub fn exit_error<'a>( + extension_id: &str, + error_type: &str, + request: Option>, +) -> Result, Error> { + error_request("exit", extension_id, error_type, request) +} + +fn error_request<'a>( + error_type: &str, + extension_id: &str, + error_str: &str, + request: Option>, +) -> Result, Error> { + let uri = format!("/2020-01-01/extension/{}/error", error_type); + + let body = match request { + None => Body::empty(), + Some(err) => Body::from(serde_json::to_string(&err)?), + }; + + let req = build_request() + .method(Method::POST) + .uri(uri) + .header(EXTENSION_ID_HEADER, extension_id) + .header(EXTENSION_ERROR_TYPE_HEADER, error_str) + .body(body)?; + Ok(req) +} diff --git a/lambda-runtime-api-client/Cargo.toml b/lambda-runtime-api-client/Cargo.toml new file mode 100644 index 00000000..48188a91 --- /dev/null +++ b/lambda-runtime-api-client/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "lambda_runtime_api_client" +version = "0.4.1" +edition = "2021" +authors = ["David Calavera "] +description = "AWS Lambda Runtime interaction API" +license = "Apache-2.0" +repository = "https://github.com/awslabs/aws-lambda-rust-runtime" +categories = ["web-programming::http-server"] +keywords = ["AWS", "Lambda", "API"] +readme = "README.md" + +[dependencies] +http = "0.2" +hyper = { version = "0.14", features = ["http1", "client", "server", "stream", "runtime"] } +tower-service = "0.3" +tokio = { version = "1.0", features = ["io-util"] } + +[dev-dependencies] +serde_json = "^1" +async-stream = "0.3" +tokio-stream = "0.1.2" \ No newline at end of file diff --git a/lambda-runtime-api-client/README.md b/lambda-runtime-api-client/README.md new file mode 100644 index 00000000..530fefdd --- /dev/null +++ b/lambda-runtime-api-client/README.md @@ -0,0 +1,35 @@ +# AWS Lambda Runtime API Client + +[![Docs](https://docs.rs/lambda_runtime_api_client/badge.svg)](https://docs.rs/lambda_runtime_api_client) + +**`lambda-runtime-api-client`** is a library to interact with the AWS Lambda Runtime API. + +This crate provides simple building blocks to send REST request to this API. You probably don't need to use this crate directly, look at [lambda_runtime](https://docs.rs/lambda_runtime) and [lambda_extension](https://docs.rs/lambda_extension) instead. + +## Example + +```rust,no_run +use http::{Method, Request}; +use hyper::Body; +use lambda_runtime_api_client::{build_request, Client, Error}; + +fn register_request(extension_name: &str, events: &[&str]) -> Result, Error> { + let events = serde_json::json!({ "events": events }); + + let req = build_request() + .method(Method::POST) + .uri("/2020-01-01/extension/register") + .header("Lambda-Extension-Name", extension_name) + .body(Body::from(serde_json::to_string(&events)?))?; + + Ok(req) +} + +#[tokio::main] +async fn main() -> Result<(), Error> { + let client = Client::builder().build()?; + let request = register_request("my_extension", &["INVOKE"])?; + + client.call(request).await +} +``` diff --git a/lambda-runtime-api-client/src/lib.rs b/lambda-runtime-api-client/src/lib.rs new file mode 100644 index 00000000..e585944e --- /dev/null +++ b/lambda-runtime-api-client/src/lib.rs @@ -0,0 +1,134 @@ +#![deny(clippy::all, clippy::cargo)] +#![warn(missing_docs, nonstandard_style, rust_2018_idioms)] + +//! This crate includes a base HTTP client to interact with +//! the AWS Lambda Runtime API. +use http::{uri::Scheme, Request, Response, Uri}; +use hyper::{ + client::{connect::Connection, HttpConnector}, + Body, +}; +use std::fmt::Debug; +use tokio::io::{AsyncRead, AsyncWrite}; +use tower_service::Service; + +const USER_AGENT_HEADER: &str = "User-Agent"; +const USER_AGENT: &str = concat!("aws-lambda-rust/", env!("CARGO_PKG_VERSION")); + +/// Error type that lambdas may result in +pub type Error = Box; + +/// API client to interact with the AWS Lambda Runtime API. +#[derive(Debug)] +pub struct Client { + /// The runtime API URI + pub base: Uri, + /// The client that manages the API connections + pub client: hyper::Client, +} + +impl Client { + /// Create a builder struct to configure the client. + pub fn builder() -> ClientBuilder { + ClientBuilder { + connector: HttpConnector::new(), + uri: None, + } + } +} + +impl Client +where + C: hyper::client::connect::Connect + Sync + Send + Clone + 'static, +{ + /// Send a given request to the Runtime API. + /// Use the client's base URI to ensure the API endpoint is correct. + pub async fn call(&self, req: Request) -> Result, Error> { + let req = self.set_origin(req)?; + let response = self.client.request(req).await?; + Ok(response) + } + + /// Create a new client with a given base URI and HTTP connector. + pub fn with(base: Uri, connector: C) -> Self { + let client = hyper::Client::builder().build(connector); + Self { base, client } + } + + fn set_origin(&self, req: Request) -> Result, Error> { + let (mut parts, body) = req.into_parts(); + let (scheme, authority) = { + let scheme = self.base.scheme().unwrap_or(&Scheme::HTTP); + let authority = self.base.authority().expect("Authority not found"); + (scheme, authority) + }; + let path = parts.uri.path_and_query().expect("PathAndQuery not found"); + + let uri = Uri::builder() + .scheme(scheme.clone()) + .authority(authority.clone()) + .path_and_query(path.clone()) + .build(); + + match uri { + Ok(u) => { + parts.uri = u; + Ok(Request::from_parts(parts, body)) + } + Err(e) => Err(Box::new(e)), + } + } +} + +/// Builder implementation to construct any Runtime API clients. +pub struct ClientBuilder = hyper::client::HttpConnector> { + connector: C, + uri: Option, +} + +impl ClientBuilder +where + C: Service + Clone + Send + Sync + Unpin + 'static, + >::Future: Unpin + Send, + >::Error: Into>, + >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, +{ + /// Create a new builder with a given HTTP connector. + pub fn with_connector(self, connector: C2) -> ClientBuilder + where + C2: Service + Clone + Send + Sync + Unpin + 'static, + >::Future: Unpin + Send, + >::Error: Into>, + >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + { + ClientBuilder { + connector, + uri: self.uri, + } + } + + /// Create a new builder with a given base URI. + /// Inherits all other attributes from the existent builder. + pub fn with_endpoint(self, uri: http::Uri) -> Self { + Self { uri: Some(uri), ..self } + } + + /// Create the new client to interact with the Runtime API. + pub fn build(self) -> Result, Error> { + let uri = match self.uri { + Some(uri) => uri, + None => { + let uri = std::env::var("AWS_LAMBDA_RUNTIME_API").expect("Missing AWS_LAMBDA_RUNTIME_API env var"); + uri.try_into().expect("Unable to convert to URL") + } + }; + Ok(Client::with(uri, self.connector)) + } +} + +/// Create a request builder. +/// This builder uses `aws-lambda-rust/CRATE_VERSION` as +/// the default User-Agent. +pub fn build_request() -> http::request::Builder { + http::Request::builder().header(USER_AGENT_HEADER, USER_AGENT) +} diff --git a/lambda-runtime/Cargo.toml b/lambda-runtime/Cargo.toml index 4d0de675..25bc26ec 100644 --- a/lambda-runtime/Cargo.toml +++ b/lambda-runtime/Cargo.toml @@ -27,6 +27,7 @@ tracing-error = "0.2" tracing = { version = "0.1", features = ["log"] } tower-service = "0.3" tokio-stream = "0.1.2" +lambda_runtime_api_client = { version = "0.4", path = "../lambda-runtime-api-client" } [dev-dependencies] tracing-subscriber = "0.3" diff --git a/lambda-runtime/src/client.rs b/lambda-runtime/src/client.rs deleted file mode 100644 index 5e39e300..00000000 --- a/lambda-runtime/src/client.rs +++ /dev/null @@ -1,308 +0,0 @@ -use crate::Error; -use http::{uri::Scheme, Request, Response, Uri}; -use hyper::{client::HttpConnector, Body}; -use std::fmt::Debug; - -#[derive(Debug)] -pub(crate) struct Client { - pub(crate) base: Uri, - pub(crate) client: hyper::Client, -} - -impl Client -where - C: hyper::client::connect::Connect + Sync + Send + Clone + 'static, -{ - pub fn with(base: Uri, connector: C) -> Self { - let client = hyper::Client::builder().build(connector); - Self { base, client } - } - - fn set_origin(&self, req: Request) -> Result, Error> { - let (mut parts, body) = req.into_parts(); - let (scheme, authority) = { - let scheme = self.base.scheme().unwrap_or(&Scheme::HTTP); - let authority = self.base.authority().expect("Authority not found"); - (scheme, authority) - }; - let path = parts.uri.path_and_query().expect("PathAndQuery not found"); - - let uri = Uri::builder() - .scheme(scheme.clone()) - .authority(authority.clone()) - .path_and_query(path.clone()) - .build(); - - match uri { - Ok(u) => { - parts.uri = u; - Ok(Request::from_parts(parts, body)) - } - Err(e) => Err(Box::new(e)), - } - } - - pub(crate) async fn call(&self, req: Request) -> Result, Error> { - let req = self.set_origin(req)?; - let response = self.client.request(req).await?; - Ok(response) - } -} - -#[cfg(test)] -mod endpoint_tests { - use crate::{ - client::Client, - incoming, - requests::{ - EventCompletionRequest, EventErrorRequest, IntoRequest, IntoResponse, NextEventRequest, NextEventResponse, - }, - simulated, - types::Diagnostic, - Error, Runtime, - }; - use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri}; - use hyper::{server::conn::Http, service::service_fn, Body}; - use serde_json::json; - use simulated::DuplexStreamWrapper; - use std::{convert::TryFrom, env}; - use tokio::{ - io::{self, AsyncRead, AsyncWrite}, - select, - sync::{self, oneshot}, - }; - use tokio_stream::StreamExt; - - #[cfg(test)] - async fn next_event(req: &Request) -> Result, Error> { - let path = "/2018-06-01/runtime/invocation/next"; - assert_eq!(req.method(), Method::GET); - assert_eq!(req.uri().path_and_query().unwrap(), &PathAndQuery::from_static(path)); - let body = json!({"message": "hello"}); - - let rsp = NextEventResponse { - request_id: "8476a536-e9f4-11e8-9739-2dfe598c3fcd", - deadline: 1_542_409_706_888, - arn: "arn:aws:lambda:us-east-2:123456789012:function:custom-runtime", - trace_id: "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419", - body: serde_json::to_vec(&body)?, - }; - rsp.into_rsp() - } - - #[cfg(test)] - async fn complete_event(req: &Request, id: &str) -> Result, Error> { - assert_eq!(Method::POST, req.method()); - let rsp = Response::builder() - .status(StatusCode::ACCEPTED) - .body(Body::empty()) - .expect("Unable to construct response"); - - let expected = format!("/2018-06-01/runtime/invocation/{}/response", id); - assert_eq!(expected, req.uri().path()); - - Ok(rsp) - } - - #[cfg(test)] - async fn event_err(req: &Request, id: &str) -> Result, Error> { - let expected = format!("/2018-06-01/runtime/invocation/{}/error", id); - assert_eq!(expected, req.uri().path()); - - assert_eq!(req.method(), Method::POST); - let header = "lambda-runtime-function-error-type"; - let expected = "unhandled"; - assert_eq!(req.headers()[header], HeaderValue::try_from(expected)?); - - let rsp = Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?; - Ok(rsp) - } - - #[cfg(test)] - async fn handle_incoming(req: Request) -> Result, Error> { - let path: Vec<&str> = req - .uri() - .path_and_query() - .expect("PathAndQuery not found") - .as_str() - .split('/') - .collect::>(); - match path[1..] { - ["2018-06-01", "runtime", "invocation", "next"] => next_event(&req).await, - ["2018-06-01", "runtime", "invocation", id, "response"] => complete_event(&req, id).await, - ["2018-06-01", "runtime", "invocation", id, "error"] => event_err(&req, id).await, - ["2018-06-01", "runtime", "init", "error"] => unimplemented!(), - _ => unimplemented!(), - } - } - - #[cfg(test)] - async fn handle(io: I, rx: oneshot::Receiver<()>) -> Result<(), hyper::Error> - where - I: AsyncRead + AsyncWrite + Unpin + 'static, - { - let conn = Http::new().serve_connection(io, service_fn(handle_incoming)); - select! { - _ = rx => { - Ok(()) - } - res = conn => { - match res { - Ok(()) => Ok(()), - Err(e) => { - Err(e) - } - } - } - } - } - - #[tokio::test] - async fn test_next_event() -> Result<(), Error> { - let base = Uri::from_static("http://localhost:9001"); - let (client, server) = io::duplex(64); - - let (tx, rx) = sync::oneshot::channel(); - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); - }); - - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); - - let req = NextEventRequest.into_req()?; - let rsp = client.call(req).await.expect("Unable to send request"); - - assert_eq!(rsp.status(), StatusCode::OK); - let header = "lambda-runtime-deadline-ms"; - assert_eq!(rsp.headers()[header], &HeaderValue::try_from("1542409706888")?); - - // shutdown server... - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } - } - - #[tokio::test] - async fn test_ok_response() -> Result<(), Error> { - let (client, server) = io::duplex(64); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); - }); - - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); - - let req = EventCompletionRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", - body: "done", - }; - let req = req.into_req()?; - - let rsp = client.call(req).await?; - assert_eq!(rsp.status(), StatusCode::ACCEPTED); - - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } - } - - #[tokio::test] - async fn test_error_response() -> Result<(), Error> { - let (client, server) = io::duplex(200); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); - }); - - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - let client = Client::with(base, conn); - - let req = EventErrorRequest { - request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", - diagnostic: Diagnostic { - error_type: "InvalidEventDataError".to_string(), - error_message: "Error parsing event data".to_string(), - }, - }; - let req = req.into_req()?; - let rsp = client.call(req).await?; - assert_eq!(rsp.status(), StatusCode::ACCEPTED); - - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } - } - - #[tokio::test] - async fn successful_end_to_end_run() -> Result<(), Error> { - let (client, server) = io::duplex(64); - let (tx, rx) = sync::oneshot::channel(); - let base = Uri::from_static("http://localhost:9001"); - - let server = tokio::spawn(async { - handle(server, rx).await.expect("Unable to handle request"); - }); - let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; - - let runtime = Runtime::builder() - .with_endpoint(base) - .with_connector(conn) - .build() - .expect("Unable to build runtime"); - - async fn func(event: serde_json::Value, _: crate::Context) -> Result { - Ok(event) - } - let f = crate::handler_fn(func); - - // set env vars needed to init Config if they are not already set in the environment - if env::var("AWS_LAMBDA_RUNTIME_API").is_err() { - env::set_var("AWS_LAMBDA_RUNTIME_API", "http://localhost:9001"); - } - if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn"); - } - if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"); - } - if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() { - env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1"); - } - if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() { - env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream"); - } - if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() { - env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log"); - } - let config = crate::Config::from_env().expect("Failed to read env vars"); - - let client = &runtime.client; - let incoming = incoming(client).take(1); - runtime.run(incoming, f, &config).await?; - - // shutdown server - tx.send(()).expect("Receiver has been dropped"); - match server.await { - Ok(_) => Ok(()), - Err(e) if e.is_panic() => Err::<(), Error>(e.into()), - Err(_) => unreachable!("This branch shouldn't be reachable"), - } - } -} diff --git a/lambda-runtime/src/lib.rs b/lambda-runtime/src/lib.rs index 30220e45..85403ea3 100644 --- a/lambda-runtime/src/lib.rs +++ b/lambda-runtime/src/lib.rs @@ -6,21 +6,15 @@ //! Create a type that conforms to the [`Handler`] trait. This type can then be passed //! to the the `lambda_runtime::run` function, which launches and runs the Lambda runtime. pub use crate::types::Context; -use client::Client; use hyper::client::{connect::Connection, HttpConnector}; +use lambda_runtime_api_client::Client; use serde::{Deserialize, Serialize}; -use std::{ - convert::{TryFrom, TryInto}, - env, fmt, - future::Future, - panic, -}; +use std::{convert::TryFrom, env, fmt, future::Future, panic}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::{Stream, StreamExt}; use tower_service::Service; use tracing::{error, trace}; -mod client; mod requests; #[cfg(test)] mod simulated; @@ -31,13 +25,11 @@ use requests::{EventCompletionRequest, EventErrorRequest, IntoRequest, NextEvent use types::Diagnostic; /// Error type that lambdas may result in -pub type Error = Box; +pub type Error = lambda_runtime_api_client::Error; /// Configuration derived from environment variables. #[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)] pub struct Config { - /// The host and port of the [runtime API](https://docs.aws.amazon.com/lambda/latest/dg/runtimes-api.html). - pub endpoint: String, /// The name of the function. pub function_name: String, /// The amount of memory available to the function in MB. @@ -54,7 +46,6 @@ impl Config { /// Attempts to read configuration from environment variables. pub fn from_env() -> Result { let conf = Config { - endpoint: env::var("AWS_LAMBDA_RUNTIME_API").expect("Missing AWS_LAMBDA_RUNTIME_API env var"), function_name: env::var("AWS_LAMBDA_FUNCTION_NAME").expect("Missing AWS_LAMBDA_FUNCTION_NAME env var"), memory: env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") .expect("Missing AWS_LAMBDA_FUNCTION_MEMORY_SIZE env var") @@ -106,25 +97,10 @@ where } } -#[non_exhaustive] -#[derive(Debug, PartialEq)] -enum BuilderError { - UnsetUri, -} - struct Runtime = HttpConnector> { client: Client, } -impl Runtime { - pub fn builder() -> RuntimeBuilder { - RuntimeBuilder { - connector: HttpConnector::new(), - uri: None, - } - } -} - impl Runtime where C: Service + Clone + Send + Sync + Unpin + 'static, @@ -209,56 +185,6 @@ where } } -struct RuntimeBuilder = hyper::client::HttpConnector> { - connector: C, - uri: Option, -} - -impl RuntimeBuilder -where - C: Service + Clone + Send + Sync + Unpin + 'static, - >::Future: Unpin + Send, - >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, -{ - pub fn with_connector(self, connector: C2) -> RuntimeBuilder - where - C2: Service + Clone + Send + Sync + Unpin + 'static, - >::Future: Unpin + Send, - >::Error: Into>, - >::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, - { - RuntimeBuilder { - connector, - uri: self.uri, - } - } - - pub fn with_endpoint(self, uri: http::Uri) -> Self { - Self { uri: Some(uri), ..self } - } - - pub fn build(self) -> Result, BuilderError> { - let uri = match self.uri { - Some(uri) => uri, - None => return Err(BuilderError::UnsetUri), - }; - let client = Client::with(uri, self.connector); - - Ok(Runtime { client }) - } -} - -#[test] -fn test_builder() { - let runtime = Runtime::builder() - .with_connector(HttpConnector::new()) - .with_endpoint(http::Uri::from_static("http://nomatter.com")) - .build(); - - runtime.unwrap(); -} - fn incoming(client: &Client) -> impl Stream, Error>> + Send + '_ where C: Service + Clone + Send + Sync + Unpin + 'static, @@ -305,12 +231,8 @@ where { trace!("Loading config from env"); let config = Config::from_env()?; - let uri = config.endpoint.clone().try_into().expect("Unable to convert to URL"); - let runtime = Runtime::builder() - .with_connector(HttpConnector::new()) - .with_endpoint(uri) - .build() - .expect("Unable to create a runtime"); + let client = Client::builder().build().expect("Unable to create a runtime client"); + let runtime = Runtime { client }; let client = &runtime.client; let incoming = incoming(client); @@ -320,3 +242,262 @@ where fn type_name_of_val(_: T) -> &'static str { std::any::type_name::() } + +#[cfg(test)] +mod endpoint_tests { + use crate::{ + incoming, + requests::{ + EventCompletionRequest, EventErrorRequest, IntoRequest, IntoResponse, NextEventRequest, NextEventResponse, + }, + simulated, + types::Diagnostic, + Error, Runtime, + }; + use http::{uri::PathAndQuery, HeaderValue, Method, Request, Response, StatusCode, Uri}; + use hyper::{server::conn::Http, service::service_fn, Body}; + use lambda_runtime_api_client::Client; + use serde_json::json; + use simulated::DuplexStreamWrapper; + use std::{convert::TryFrom, env}; + use tokio::{ + io::{self, AsyncRead, AsyncWrite}, + select, + sync::{self, oneshot}, + }; + use tokio_stream::StreamExt; + + #[cfg(test)] + async fn next_event(req: &Request) -> Result, Error> { + let path = "/2018-06-01/runtime/invocation/next"; + assert_eq!(req.method(), Method::GET); + assert_eq!(req.uri().path_and_query().unwrap(), &PathAndQuery::from_static(path)); + let body = json!({"message": "hello"}); + + let rsp = NextEventResponse { + request_id: "8476a536-e9f4-11e8-9739-2dfe598c3fcd", + deadline: 1_542_409_706_888, + arn: "arn:aws:lambda:us-east-2:123456789012:function:custom-runtime", + trace_id: "Root=1-5bef4de7-ad49b0e87f6ef6c87fc2e700;Parent=9a9197af755a6419", + body: serde_json::to_vec(&body)?, + }; + rsp.into_rsp() + } + + #[cfg(test)] + async fn complete_event(req: &Request, id: &str) -> Result, Error> { + assert_eq!(Method::POST, req.method()); + let rsp = Response::builder() + .status(StatusCode::ACCEPTED) + .body(Body::empty()) + .expect("Unable to construct response"); + + let expected = format!("/2018-06-01/runtime/invocation/{}/response", id); + assert_eq!(expected, req.uri().path()); + + Ok(rsp) + } + + #[cfg(test)] + async fn event_err(req: &Request, id: &str) -> Result, Error> { + let expected = format!("/2018-06-01/runtime/invocation/{}/error", id); + assert_eq!(expected, req.uri().path()); + + assert_eq!(req.method(), Method::POST); + let header = "lambda-runtime-function-error-type"; + let expected = "unhandled"; + assert_eq!(req.headers()[header], HeaderValue::try_from(expected)?); + + let rsp = Response::builder().status(StatusCode::ACCEPTED).body(Body::empty())?; + Ok(rsp) + } + + #[cfg(test)] + async fn handle_incoming(req: Request) -> Result, Error> { + let path: Vec<&str> = req + .uri() + .path_and_query() + .expect("PathAndQuery not found") + .as_str() + .split('/') + .collect::>(); + match path[1..] { + ["2018-06-01", "runtime", "invocation", "next"] => next_event(&req).await, + ["2018-06-01", "runtime", "invocation", id, "response"] => complete_event(&req, id).await, + ["2018-06-01", "runtime", "invocation", id, "error"] => event_err(&req, id).await, + ["2018-06-01", "runtime", "init", "error"] => unimplemented!(), + _ => unimplemented!(), + } + } + + #[cfg(test)] + async fn handle(io: I, rx: oneshot::Receiver<()>) -> Result<(), hyper::Error> + where + I: AsyncRead + AsyncWrite + Unpin + 'static, + { + let conn = Http::new().serve_connection(io, service_fn(handle_incoming)); + select! { + _ = rx => { + Ok(()) + } + res = conn => { + match res { + Ok(()) => Ok(()), + Err(e) => { + Err(e) + } + } + } + } + } + + #[tokio::test] + async fn test_next_event() -> Result<(), Error> { + let base = Uri::from_static("http://localhost:9001"); + let (client, server) = io::duplex(64); + + let (tx, rx) = sync::oneshot::channel(); + let server = tokio::spawn(async { + handle(server, rx).await.expect("Unable to handle request"); + }); + + let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; + let client = Client::with(base, conn); + + let req = NextEventRequest.into_req()?; + let rsp = client.call(req).await.expect("Unable to send request"); + + assert_eq!(rsp.status(), StatusCode::OK); + let header = "lambda-runtime-deadline-ms"; + assert_eq!(rsp.headers()[header], &HeaderValue::try_from("1542409706888")?); + + // shutdown server... + tx.send(()).expect("Receiver has been dropped"); + match server.await { + Ok(_) => Ok(()), + Err(e) if e.is_panic() => Err::<(), Error>(e.into()), + Err(_) => unreachable!("This branch shouldn't be reachable"), + } + } + + #[tokio::test] + async fn test_ok_response() -> Result<(), Error> { + let (client, server) = io::duplex(64); + let (tx, rx) = sync::oneshot::channel(); + let base = Uri::from_static("http://localhost:9001"); + + let server = tokio::spawn(async { + handle(server, rx).await.expect("Unable to handle request"); + }); + + let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; + let client = Client::with(base, conn); + + let req = EventCompletionRequest { + request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", + body: "done", + }; + let req = req.into_req()?; + + let rsp = client.call(req).await?; + assert_eq!(rsp.status(), StatusCode::ACCEPTED); + + // shutdown server + tx.send(()).expect("Receiver has been dropped"); + match server.await { + Ok(_) => Ok(()), + Err(e) if e.is_panic() => Err::<(), Error>(e.into()), + Err(_) => unreachable!("This branch shouldn't be reachable"), + } + } + + #[tokio::test] + async fn test_error_response() -> Result<(), Error> { + let (client, server) = io::duplex(200); + let (tx, rx) = sync::oneshot::channel(); + let base = Uri::from_static("http://localhost:9001"); + + let server = tokio::spawn(async { + handle(server, rx).await.expect("Unable to handle request"); + }); + + let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; + let client = Client::with(base, conn); + + let req = EventErrorRequest { + request_id: "156cb537-e2d4-11e8-9b34-d36013741fb9", + diagnostic: Diagnostic { + error_type: "InvalidEventDataError".to_string(), + error_message: "Error parsing event data".to_string(), + }, + }; + let req = req.into_req()?; + let rsp = client.call(req).await?; + assert_eq!(rsp.status(), StatusCode::ACCEPTED); + + // shutdown server + tx.send(()).expect("Receiver has been dropped"); + match server.await { + Ok(_) => Ok(()), + Err(e) if e.is_panic() => Err::<(), Error>(e.into()), + Err(_) => unreachable!("This branch shouldn't be reachable"), + } + } + + #[tokio::test] + async fn successful_end_to_end_run() -> Result<(), Error> { + let (client, server) = io::duplex(64); + let (tx, rx) = sync::oneshot::channel(); + let base = Uri::from_static("http://localhost:9001"); + + let server = tokio::spawn(async { + handle(server, rx).await.expect("Unable to handle request"); + }); + let conn = simulated::Connector::with(base.clone(), DuplexStreamWrapper::new(client))?; + + let client = Client::builder() + .with_endpoint(base) + .with_connector(conn) + .build() + .expect("Unable to build client"); + + async fn func(event: serde_json::Value, _: crate::Context) -> Result { + Ok(event) + } + let f = crate::handler_fn(func); + + // set env vars needed to init Config if they are not already set in the environment + if env::var("AWS_LAMBDA_RUNTIME_API").is_err() { + env::set_var("AWS_LAMBDA_RUNTIME_API", "http://localhost:9001"); + } + if env::var("AWS_LAMBDA_FUNCTION_NAME").is_err() { + env::set_var("AWS_LAMBDA_FUNCTION_NAME", "test_fn"); + } + if env::var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE").is_err() { + env::set_var("AWS_LAMBDA_FUNCTION_MEMORY_SIZE", "128"); + } + if env::var("AWS_LAMBDA_FUNCTION_VERSION").is_err() { + env::set_var("AWS_LAMBDA_FUNCTION_VERSION", "1"); + } + if env::var("AWS_LAMBDA_LOG_STREAM_NAME").is_err() { + env::set_var("AWS_LAMBDA_LOG_STREAM_NAME", "test_stream"); + } + if env::var("AWS_LAMBDA_LOG_GROUP_NAME").is_err() { + env::set_var("AWS_LAMBDA_LOG_GROUP_NAME", "test_log"); + } + let config = crate::Config::from_env().expect("Failed to read env vars"); + + let runtime = Runtime { client }; + let client = &runtime.client; + let incoming = incoming(client).take(1); + runtime.run(incoming, f, &config).await?; + + // shutdown server + tx.send(()).expect("Receiver has been dropped"); + match server.await { + Ok(_) => Ok(()), + Err(e) if e.is_panic() => Err::<(), Error>(e.into()), + Err(_) => unreachable!("This branch shouldn't be reachable"), + } + } +} diff --git a/lambda-runtime/src/requests.rs b/lambda-runtime/src/requests.rs index 8aa2edbe..4d033614 100644 --- a/lambda-runtime/src/requests.rs +++ b/lambda-runtime/src/requests.rs @@ -1,11 +1,10 @@ use crate::{types::Diagnostic, Error}; use http::{Method, Request, Response, Uri}; use hyper::Body; +use lambda_runtime_api_client::build_request; use serde::Serialize; use std::str::FromStr; -const USER_AGENT: &str = concat!("aws-lambda-rust/", env!("CARGO_PKG_VERSION")); - pub(crate) trait IntoRequest { fn into_req(self) -> Result, Error>; } @@ -20,9 +19,8 @@ pub(crate) struct NextEventRequest; impl IntoRequest for NextEventRequest { fn into_req(self) -> Result, Error> { - let req = Request::builder() + let req = build_request() .method(Method::GET) - .header("User-Agent", USER_AGENT) .uri(Uri::from_static("/2018-06-01/runtime/invocation/next")) .body(Body::empty())?; Ok(req) @@ -82,11 +80,7 @@ where let body = serde_json::to_vec(&self.body)?; let body = Body::from(body); - let req = Request::builder() - .header("User-Agent", USER_AGENT) - .method(Method::POST) - .uri(uri) - .body(body)?; + let req = build_request().method(Method::POST).uri(uri).body(body)?; Ok(req) } } @@ -120,10 +114,9 @@ impl<'a> IntoRequest for EventErrorRequest<'a> { let body = serde_json::to_vec(&self.diagnostic)?; let body = Body::from(body); - let req = Request::builder() + let req = build_request() .method(Method::POST) .uri(uri) - .header("User-Agent", USER_AGENT) .header("lambda-runtime-function-error-type", "unhandled") .body(body)?; Ok(req) @@ -157,10 +150,9 @@ impl IntoRequest for InitErrorRequest { let uri = "/2018-06-01/runtime/init/error".to_string(); let uri = Uri::from_str(&uri)?; - let req = Request::builder() + let req = build_request() .method(Method::POST) .uri(uri) - .header("User-Agent", USER_AGENT) .header("lambda-runtime-function-error-type", "unhandled") .body(Body::empty())?; Ok(req)