From 8148161e7902a362eacc4341dd95de11c9d5bb98 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 05:56:53 +0000 Subject: [PATCH 01/31] chore: Add serial_test dependency and update config tests to run serially --- Cargo.lock | 41 +++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 1 + src/config_tests.rs | 6 ++++++ 3 files changed, 48 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 313fd47..8159ddb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1229,6 +1229,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "serial_test", "tempfile", "tokio", "tokio-stream", @@ -1698,6 +1699,15 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +[[package]] +name = "scc" +version = "2.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1723,6 +1733,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sdd" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49c1eeaf4b6a87c7479688c6d52b9f1153cedd3c489300564f932b065c6eab95" + [[package]] name = "security-framework" version = "2.3.1" @@ -1818,6 +1834,31 @@ dependencies = [ "unsafe-libyaml", ] +[[package]] +name = "serial_test" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "sha2" version = "0.10.8" diff --git a/Cargo.toml b/Cargo.toml index cc1bdc4..84dd787 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ http-serde = "2.1.1" [dev-dependencies] tempfile = "3.8.1" +serial_test = "3" [[bin]] name = "lambda-web-gateway" diff --git a/src/config_tests.rs b/src/config_tests.rs index 02bfc6e..373da6f 100644 --- a/src/config_tests.rs +++ b/src/config_tests.rs @@ -3,6 +3,7 @@ use std::collections::HashSet; use std::env; use tempfile::NamedTempFile; use std::io::Write; +use serial_test::serial; #[test] fn test_auth_mode_from_str() { @@ -40,6 +41,7 @@ fn test_config_panic_on_empty_lambda_function_name() { } #[test] +#[serial] fn test_config_apply_env_overrides() { env::set_var("LAMBDA_FUNCTION_NAME", "test-function"); env::set_var("LAMBDA_INVOKE_MODE", "responsestream"); @@ -89,6 +91,7 @@ addr: 127.0.0.1:3000 } #[test] +#[serial] fn test_config_load_with_env_override() { let config_content = r#" lambda_function_name: file-function @@ -133,6 +136,7 @@ addr: 0.0.0.0:8000 } #[test] +#[serial] fn test_config_load_invalid_file() { env::set_var("LAMBDA_FUNCTION_NAME", "env-function"); env::set_var("AUTH_MODE", "apikey"); @@ -153,6 +157,7 @@ fn test_config_load_invalid_file() { } #[test] +#[serial] fn test_config_load_invalid_yaml() { let config_content = "invalid: yaml: content"; @@ -178,6 +183,7 @@ fn test_config_load_invalid_yaml() { } #[test] +#[serial] fn test_config_load_empty_api_keys() { env::set_var("API_KEYS", ""); env::set_var("LAMBDA_FUNCTION_NAME", "test-function"); // Add this line From 1c78815ec54781e86b6861f5c548afed08be1ce8 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 05:57:48 +0000 Subject: [PATCH 02/31] chore: Simplify default values for Config struct and enums --- src/config.rs | 36 +++++++++--------------------------- 1 file changed, 9 insertions(+), 27 deletions(-) diff --git a/src/config.rs b/src/config.rs index 41f05c9..9340cd8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,17 +1,17 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; -use std::str::FromStr; use std::fs; use std::path::Path; +use std::str::FromStr; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Config { pub lambda_function_name: String, - #[serde(default = "default_lambda_invoke_mode")] + #[serde(default)] pub lambda_invoke_mode: LambdaInvokeMode, #[serde(default)] pub api_keys: HashSet, - #[serde(default = "default_auth_mode")] + #[serde(default)] pub auth_mode: AuthMode, #[serde(default = "default_addr")] pub addr: String, @@ -21,9 +21,9 @@ impl Default for Config { fn default() -> Self { Self { lambda_function_name: String::new(), - lambda_invoke_mode: default_lambda_invoke_mode(), + lambda_invoke_mode: Default::default(), api_keys: HashSet::new(), - auth_mode: default_auth_mode(), + auth_mode: Default::default(), addr: default_addr(), } } @@ -76,42 +76,24 @@ mod tests { include!("config_tests.rs"); } -fn default_auth_mode() -> AuthMode { - AuthMode::Open -} - -fn default_lambda_invoke_mode() -> LambdaInvokeMode { - LambdaInvokeMode::Buffered -} - fn default_addr() -> String { "0.0.0.0:8000".to_string() } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)] pub enum AuthMode { + #[default] Open, ApiKey, } -impl Default for AuthMode { - fn default() -> Self { - AuthMode::Open - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)] pub enum LambdaInvokeMode { + #[default] Buffered, ResponseStream, } -impl Default for LambdaInvokeMode { - fn default() -> Self { - LambdaInvokeMode::Buffered - } -} - impl FromStr for AuthMode { type Err = String; From 96a63683929911195e67ae7e5ac737942c672358 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 06:01:25 +0000 Subject: [PATCH 03/31] chore: Format code --- Cargo.toml | 8 +++---- src/config_tests.rs | 54 +++++++++++++++++++++++++++++++++++---------- src/lib_tests.rs | 9 ++------ 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 84dd787..c5ebae2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,14 +17,14 @@ serde_yaml = "0.9" clap = { version = "4.0", features = ["derive"] } serde_json = "1" url = "2.5.0" -axum ={ version = "0.7.5"} +axum = { version = "0.7.5" } aws-config = { version = "1.5.5" } aws-sdk-lambda = { version = "1.42.0" } -aws-smithy-types = { version="1.2.2", features = ["serde-serialize"] } +aws-smithy-types = { version = "1.2.2", features = ["serde-serialize"] } tokio = { version = "1.39.3", features = ["full"] } tower-http = { version = "0.5.2", features = ["trace"] } -tracing-subscriber = { version= "0.3.18", features = ["json"]} -tracing ={ version = "0.1.40"} +tracing-subscriber = { version = "0.3.18", features = ["json"] } +tracing = { version = "0.1.40" } tokio-stream = "0.1.15" futures-util = "0.3.30" http-serde = "2.1.1" diff --git a/src/config_tests.rs b/src/config_tests.rs index 373da6f..7d5b207 100644 --- a/src/config_tests.rs +++ b/src/config_tests.rs @@ -1,9 +1,9 @@ use super::*; +use serial_test::serial; use std::collections::HashSet; use std::env; -use tempfile::NamedTempFile; use std::io::Write; -use serial_test::serial; +use tempfile::NamedTempFile; #[test] fn test_auth_mode_from_str() { @@ -16,10 +16,22 @@ fn test_auth_mode_from_str() { #[test] fn test_lambda_invoke_mode_from_str() { - assert_eq!("buffered".parse::().unwrap(), LambdaInvokeMode::Buffered); - assert_eq!("responsestream".parse::().unwrap(), LambdaInvokeMode::ResponseStream); - assert_eq!("BUFFERED".parse::().unwrap(), LambdaInvokeMode::Buffered); - assert_eq!("RESPONSESTREAM".parse::().unwrap(), LambdaInvokeMode::ResponseStream); + assert_eq!( + "buffered".parse::().unwrap(), + LambdaInvokeMode::Buffered + ); + assert_eq!( + "responsestream".parse::().unwrap(), + LambdaInvokeMode::ResponseStream + ); + assert_eq!( + "BUFFERED".parse::().unwrap(), + LambdaInvokeMode::Buffered + ); + assert_eq!( + "RESPONSESTREAM".parse::().unwrap(), + LambdaInvokeMode::ResponseStream + ); assert!("invalid".parse::().is_err()); } @@ -54,7 +66,13 @@ fn test_config_apply_env_overrides() { assert_eq!(config.lambda_function_name, "test-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["key1", "key2"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["key1", "key2"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "127.0.0.1:3000"); @@ -85,7 +103,13 @@ addr: 127.0.0.1:3000 assert_eq!(config.lambda_function_name, "test-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["key1", "key2"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["key1", "key2"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "127.0.0.1:3000"); } @@ -114,7 +138,13 @@ addr: 0.0.0.0:8000 assert_eq!(config.lambda_function_name, "env-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["file-key"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["file-key"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "0.0.0.0:8000"); @@ -143,7 +173,7 @@ fn test_config_load_invalid_file() { env::set_var("LAMBDA_INVOKE_MODE", "responsestream"); let config = Config::load("non_existent_file.yaml"); - + assert_eq!(config.lambda_function_name, "env-function"); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); @@ -187,9 +217,9 @@ fn test_config_load_invalid_yaml() { fn test_config_load_empty_api_keys() { env::set_var("API_KEYS", ""); env::set_var("LAMBDA_FUNCTION_NAME", "test-function"); // Add this line - + let config = Config::load("non_existent_file.yaml"); - + assert!(config.api_keys.is_empty()); env::remove_var("API_KEYS"); diff --git a/src/lib_tests.rs b/src/lib_tests.rs index b86ce9d..ca070f0 100644 --- a/src/lib_tests.rs +++ b/src/lib_tests.rs @@ -32,9 +32,7 @@ async fn test_handle_buffered_response() { status_code: 200, status_description: Some("OK".to_string()), is_base64_encoded: Some(false), - headers: Some(HashMap::from([ - ("Content-Type".to_string(), "text/plain".to_string()), - ])), + headers: Some(HashMap::from([("Content-Type".to_string(), "text/plain".to_string())])), body: "Hello, World!".to_string(), }; @@ -47,10 +45,7 @@ async fn test_handle_buffered_response() { let response = handle_buffered_response(invoke_output).await; assert_eq!(response.status(), StatusCode::OK); - assert_eq!( - response.headers().get("Content-Type").unwrap(), - "text/plain" - ); + assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); assert_eq!(body, "Hello, World!"); } From 35d84530346ca14655725b619decff5211d81964 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 06:03:07 +0000 Subject: [PATCH 04/31] chore: simplify code following clippy --- src/lib.rs | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index fc907d4..636b16d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,7 +78,7 @@ async fn handler( let content_type = headers .get("content-type") - .and_then(|v| v.to_str().ok().map(Some).flatten()) + .and_then(|v| v.to_str().ok()) .unwrap_or_default(); let is_base64_encoded = match content_type { @@ -132,7 +132,7 @@ async fn handler( }) .to_string(); - let resp = match config.lambda_invoke_mode { + match config.lambda_invoke_mode { LambdaInvokeMode::Buffered => { let resp = client .invoke() @@ -154,9 +154,7 @@ async fn handler( .unwrap(); handle_streaming_response(resp).await } - }; - - resp + } } fn to_string_map(headers: &HeaderMap) -> HashMap { @@ -275,7 +273,7 @@ async fn handle_streaming_response( } else { Ok(Bytes::default()) } - }, + } InvokeComplete(_) => Ok(Bytes::default()), _ => Ok(Bytes::default()), // Handle other event types } @@ -307,13 +305,11 @@ async fn handle_streaming_response( async fn detect_metadata( resp: &mut aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, ) -> (bool, Option>) { - if let Ok(Some(event)) = resp.event_stream.recv().await { - if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); - let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; - return (has_metadata, Some(bytes)); - } + if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { + if let Some(data) = chunk.payload() { + let bytes = data.clone().into_inner(); + let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; + return (has_metadata, Some(bytes)); } } (false, None) From 8e560e2ff9b2b50088050f94c23aea0772e60757 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 06:10:16 +0000 Subject: [PATCH 05/31] chore: reorganize test module structure and prevent `include!` macro --- src/config.rs | 8 +++----- src/{config_tests.rs => config/tests.rs} | 0 src/lib.rs | 8 +++----- src/{lib_tests.rs => tests.rs} | 0 4 files changed, 6 insertions(+), 10 deletions(-) rename src/{config_tests.rs => config/tests.rs} (100%) rename src/{lib_tests.rs => tests.rs} (100%) diff --git a/src/config.rs b/src/config.rs index 9340cd8..f2f8f1a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::fs; @@ -71,11 +74,6 @@ impl Config { } } -#[cfg(test)] -mod tests { - include!("config_tests.rs"); -} - fn default_addr() -> String { "0.0.0.0:8000".to_string() } diff --git a/src/config_tests.rs b/src/config/tests.rs similarity index 100% rename from src/config_tests.rs rename to src/config/tests.rs diff --git a/src/lib.rs b/src/lib.rs index 636b16d..2e22e6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,5 @@ -pub mod config; - #[cfg(test)] -mod tests { - include!("lib_tests.rs"); -} +mod tests; use crate::config::{Config, LambdaInvokeMode}; use aws_config::BehaviorVersion; @@ -30,6 +26,8 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; +pub mod config; + #[derive(Clone)] pub struct ApplicationState { client: Client, diff --git a/src/lib_tests.rs b/src/tests.rs similarity index 100% rename from src/lib_tests.rs rename to src/tests.rs From 4a7132122bf1a1816a7871d6fa03d198709ad53c Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 06:35:04 +0000 Subject: [PATCH 06/31] perf: prevent unnecessary clone --- src/lib.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2e22e6c..27b3341 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ use futures_util::stream::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; +use std::net::SocketAddr; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; @@ -42,15 +43,15 @@ pub async fn run_app() { let client = Client::new(&aws_config); let app_state = ApplicationState { client, config }; + let addr = app_state.config.addr.parse::().unwrap(); let app = Router::new() .route("/healthz", get(health)) .route("/", any(handler)) .route("/*path", any(handler)) .layer(TraceLayer::new_for_http()) - .with_state(app_state.clone()); + .with_state(app_state); - let addr = &app_state.config.addr; let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); tracing::info!("Listening on {}", addr); axum::serve(listener, app).await.unwrap(); From 1f14c03407a4934962b43bd986c6d37ccaafc0b0 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 06:56:33 +0000 Subject: [PATCH 07/31] perf: use Arc for Config in ApplicationState for cheaper clone --- src/lib.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 27b3341..04bd224 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,6 +23,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::Arc; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; @@ -32,13 +33,13 @@ pub mod config; #[derive(Clone)] pub struct ApplicationState { client: Client, - config: Config, + config: Arc, } pub async fn run_app() { tracing_subscriber::fmt::init(); - let config = Config::load("config.yaml"); + let config = Arc::new(Config::load("config.yaml")); let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; let client = Client::new(&aws_config); From 25cce5aac0f08f0f5425491474c6a2ede6e41052 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 31 Oct 2024 07:19:40 +0000 Subject: [PATCH 08/31] chore: add serial attribute to test_config_panic_on_empty_lambda_function_name --- src/config/tests.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/config/tests.rs b/src/config/tests.rs index 7d5b207..1f1a5f6 100644 --- a/src/config/tests.rs +++ b/src/config/tests.rs @@ -46,6 +46,7 @@ fn test_config_default() { } #[test] +#[serial] #[should_panic(expected = "No lambda_function_name provided")] fn test_config_panic_on_empty_lambda_function_name() { let mut config = Config::default(); From e779f11741b59a59351ca07f5242f6c5d1a32aa2 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 2 Jan 2025 07:00:15 +0000 Subject: [PATCH 09/31] chore: move tests to the end of files --- src/config.rs | 6 +++--- src/lib.rs | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/config.rs b/src/config.rs index f2f8f1a..cf67e55 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,3 @@ -#[cfg(test)] -mod tests; - use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::fs; @@ -115,3 +112,6 @@ impl FromStr for LambdaInvokeMode { } } } + +#[cfg(test)] +mod tests; diff --git a/src/lib.rs b/src/lib.rs index 04bd224..5344df1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ -#[cfg(test)] -mod tests; +pub mod config; use crate::config::{Config, LambdaInvokeMode}; use aws_config::BehaviorVersion; @@ -28,8 +27,6 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; -pub mod config; - #[derive(Clone)] pub struct ApplicationState { client: Client, @@ -365,3 +362,6 @@ fn process_buffer(buffer: &[u8]) -> (Option, Vec) { } (None, Vec::new()) } + +#[cfg(test)] +mod tests; From a87ec8b3ec29eadbfb21048b5c79530555f8ed15 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 2 Jan 2025 07:11:55 +0000 Subject: [PATCH 10/31] chore: optimize import structure, simplify typing --- src/config.rs | 15 ++++++--------- src/lib.rs | 53 ++++++++++++++++++++++++--------------------------- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/src/config.rs b/src/config.rs index cf67e55..0e8382d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,8 +1,5 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashSet; -use std::fs; -use std::path::Path; -use std::str::FromStr; +use std::{collections::HashSet, env, fs, path::Path, str::FromStr}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Config { @@ -40,26 +37,26 @@ impl Config { } fn apply_env_overrides(&mut self) { - if let Ok(val) = std::env::var("LAMBDA_FUNCTION_NAME") { + if let Ok(val) = env::var("LAMBDA_FUNCTION_NAME") { self.lambda_function_name = val; } if self.lambda_function_name.is_empty() { panic!("No lambda_function_name provided. Please set it in the config file or LAMBDA_FUNCTION_NAME environment variable."); } - if let Ok(val) = std::env::var("LAMBDA_INVOKE_MODE") { + if let Ok(val) = env::var("LAMBDA_INVOKE_MODE") { if let Ok(mode) = val.parse() { self.lambda_invoke_mode = mode; } } - if let Ok(val) = std::env::var("API_KEYS") { + if let Ok(val) = env::var("API_KEYS") { self.api_keys = val.split(',').filter(|s| !s.is_empty()).map(String::from).collect(); } - if let Ok(val) = std::env::var("AUTH_MODE") { + if let Ok(val) = env::var("AUTH_MODE") { if let Ok(mode) = val.parse() { self.auth_mode = mode; } } - if let Ok(val) = std::env::var("ADDR") { + if let Ok(val) = env::var("ADDR") { self.addr = val; } } diff --git a/src/lib.rs b/src/lib.rs index 5344df1..615bb9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,28 +2,31 @@ pub mod config; use crate::config::{Config, LambdaInvokeMode}; use aws_config::BehaviorVersion; -use aws_sdk_lambda::types::InvokeWithResponseStreamResponseEvent::{InvokeComplete, PayloadChunk}; -use aws_sdk_lambda::types::{InvokeResponseStreamUpdate, ResponseStreamingInvocationType}; -use aws_sdk_lambda::Client; +use aws_sdk_lambda::{ + operation::{invoke::InvokeOutput, invoke_with_response_stream::InvokeWithResponseStreamOutput}, + types::{ + InvokeResponseStreamUpdate, + InvokeWithResponseStreamResponseEvent::{InvokeComplete, PayloadChunk}, + ResponseStreamingInvocationType, + }, + Client, +}; use aws_smithy_types::Blob; -use axum::body::Body; use axum::{ - body::Bytes, + body::{Body, Bytes}, extract::{Path, Query, State}, http::{HeaderMap, Method, StatusCode}, response::{IntoResponse, Response}, - routing::any, - routing::get, + routing::{any, get}, Router, }; -use base64::Engine; +use base64::{prelude::BASE64_STANDARD, Engine}; +use config::AuthMode; use futures_util::stream::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use tokio::sync::mpsc; +use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc}; +use tokio::{net::TcpListener, sync::mpsc}; use tokio_stream::wrappers::ReceiverStream; use tower_http::trace::TraceLayer; @@ -50,7 +53,7 @@ pub async fn run_app() { .layer(TraceLayer::new_for_http()) .with_state(app_state); - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + let listener = TcpListener::bind(addr).await.unwrap(); tracing::info!("Listening on {}", addr); axum::serve(listener, app).await.unwrap(); } @@ -87,14 +90,14 @@ async fn handler( }; let body = if is_base64_encoded { - base64::engine::general_purpose::STANDARD.encode(body) + BASE64_STANDARD.encode(body) } else { String::from_utf8_lossy(&body).to_string() }; match config.auth_mode { - config::AuthMode::Open => {} - config::AuthMode::ApiKey => { + AuthMode::Open => {} + AuthMode::ApiKey => { let api_key = headers .get("x-api-key") .and_then(|v| v.to_str().ok()) @@ -189,7 +192,7 @@ struct MetadataPrelude { pub cookies: Vec, } -async fn handle_buffered_response(resp: aws_sdk_lambda::operation::invoke::InvokeOutput) -> Response { +async fn handle_buffered_response(resp: InvokeOutput) -> Response { // Parse the InvokeOutput payload to extract the LambdaResponse let payload = resp.payload().unwrap().as_ref().to_vec(); let lambda_response: LambdaResponse = serde_json::from_slice(&payload).unwrap(); @@ -204,18 +207,14 @@ async fn handle_buffered_response(resp: aws_sdk_lambda::operation::invoke::Invok } let body = if lambda_response.is_base64_encoded.unwrap_or(false) { - base64::engine::general_purpose::STANDARD - .decode(lambda_response.body) - .unwrap() + BASE64_STANDARD.decode(lambda_response.body).unwrap() } else { lambda_response.body.into_bytes() }; resp_builder.body(Body::from(body)).unwrap() } -async fn handle_streaming_response( - mut resp: aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, -) -> Response { +async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { let (tx, rx) = mpsc::channel(1); let mut metadata_buffer = Vec::new(); let mut metadata_prelude: Option = None; @@ -266,7 +265,7 @@ async fn handle_streaming_response( PayloadChunk(chunk) => { if let Some(data) = chunk.payload() { let bytes = data.clone().into_inner(); - Ok::<_, std::convert::Infallible>(Bytes::from(bytes)) + Ok::<_, Infallible>(Bytes::from(bytes)) } else { Ok(Bytes::default()) } @@ -299,9 +298,7 @@ async fn handle_streaming_response( resp_builder.body(Body::from_stream(stream)).unwrap() } -async fn detect_metadata( - resp: &mut aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, -) -> (bool, Option>) { +async fn detect_metadata(resp: &mut InvokeWithResponseStreamOutput) -> (bool, Option>) { if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { if let Some(data) = chunk.payload() { let bytes = data.clone().into_inner(); @@ -313,7 +310,7 @@ async fn detect_metadata( } async fn collect_metadata( - resp: &mut aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, + resp: &mut InvokeWithResponseStreamOutput, metadata_buffer: &mut Vec, ) -> (Option, Vec) { let mut metadata_prelude = None; From 34860f4d24cbe47f307e0f640c2afc3ef0e381ed Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 2 Jan 2025 07:46:06 +0000 Subject: [PATCH 11/31] perf: better performance and error handling --- src/lib.rs | 111 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 63 insertions(+), 48 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 615bb9c..0f0375d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,21 @@ async fn health() -> impl IntoResponse { StatusCode::OK } +macro_rules! handle_err { + ($name:expr, $result:expr) => {{ + match $result { + Ok(v) => v, + Err(e) => { + tracing::error!("{}: {:?}", $name, e); + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap(); + } + } + }}; +} + async fn handler( path: Option>, Query(query_string_parameters): Query>, @@ -134,24 +149,28 @@ async fn handler( match config.lambda_invoke_mode { LambdaInvokeMode::Buffered => { - let resp = client - .invoke() - .function_name(config.lambda_function_name.as_str()) - .payload(Blob::new(lambda_request_body)) - .send() - .await - .unwrap(); + let resp = handle_err!( + "Invoking lambda", + client + .invoke() + .function_name(config.lambda_function_name.as_str()) + .payload(Blob::new(lambda_request_body)) + .send() + .await + ); handle_buffered_response(resp).await } LambdaInvokeMode::ResponseStream => { - let resp = client - .invoke_with_response_stream() - .function_name(config.lambda_function_name.as_str()) - .invocation_type(ResponseStreamingInvocationType::RequestResponse) - .payload(Blob::new(lambda_request_body)) - .send() - .await - .unwrap(); + let resp = handle_err!( + "Invoking lambda", + client + .invoke_with_response_stream() + .function_name(config.lambda_function_name.as_str()) + .invocation_type(ResponseStreamingInvocationType::RequestResponse) + .payload(Blob::new(lambda_request_body)) + .send() + .await + ); handle_streaming_response(resp).await } } @@ -194,8 +213,11 @@ struct MetadataPrelude { async fn handle_buffered_response(resp: InvokeOutput) -> Response { // Parse the InvokeOutput payload to extract the LambdaResponse - let payload = resp.payload().unwrap().as_ref().to_vec(); - let lambda_response: LambdaResponse = serde_json::from_slice(&payload).unwrap(); + let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); + let lambda_response = handle_err!( + "Deserializing lambda response", + serde_json::from_slice::(payload) + ); // Build the response using the extracted information let mut resp_builder = Response::builder().status(StatusCode::from_u16(lambda_response.status_code).unwrap()); @@ -207,11 +229,14 @@ async fn handle_buffered_response(resp: InvokeOutput) -> Response { } let body = if lambda_response.is_base64_encoded.unwrap_or(false) { - BASE64_STANDARD.decode(lambda_response.body).unwrap() + handle_err!( + "Decode base64 lambda response body", + BASE64_STANDARD.decode(lambda_response.body) + ) } else { lambda_response.body.into_bytes() }; - resp_builder.body(Body::from(body)).unwrap() + handle_err!("Building response", resp_builder.body(Body::from(body))) } async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { @@ -244,34 +269,24 @@ async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> let _ = tx.send(PayloadChunk(stream_update)).await; } - while let Some(event) = resp.event_stream.recv().await.unwrap() { - match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload() { - let stream_update = InvokeResponseStreamUpdate::builder().payload(data.clone()).build(); - let _ = tx.send(PayloadChunk(stream_update)).await; - } - } - InvokeComplete(_) => { - let _ = tx.send(event).await; - } - _ => {} - } + while let Ok(Some(event)) = resp.event_stream.recv().await { + tx.send(event).await.ok(); } }); - let stream = ReceiverStream::new(rx).map(|event| { - match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); - Ok::<_, Infallible>(Bytes::from(bytes)) - } else { - Ok(Bytes::default()) - } + let stream = ReceiverStream::new(rx).map(|event| match event { + PayloadChunk(chunk) => { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + Ok::<_, Infallible>(Bytes::from(bytes)) + } else { + Ok(Bytes::default()) } - InvokeComplete(_) => Ok(Bytes::default()), - _ => Ok(Bytes::default()), // Handle other event types + } + InvokeComplete(_) => Ok(Bytes::default()), + _ => { + tracing::warn!("Unhandled event type: {:?}", event); + Ok(Bytes::default()) } }); @@ -295,13 +310,13 @@ async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> resp_builder = resp_builder.header("content-type", "application/octet-stream"); } - resp_builder.body(Body::from_stream(stream)).unwrap() + handle_err!("Building response", resp_builder.body(Body::from_stream(stream))) } async fn detect_metadata(resp: &mut InvokeWithResponseStreamOutput) -> (bool, Option>) { if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; return (has_metadata, Some(bytes)); } @@ -325,8 +340,8 @@ async fn collect_metadata( // If metadata is not complete, continue processing the stream while let Ok(Some(event)) = resp.event_stream.recv().await { if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); metadata_buffer.extend_from_slice(&bytes); let (prelude, remaining) = process_buffer(metadata_buffer); if let Some(p) = prelude { From 55825dc07e7a8bce36dc516de710cef6ed855291 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 2 Jan 2025 07:57:35 +0000 Subject: [PATCH 12/31] chore: split mods --- src/lib.rs | 186 ++--------------------------------------------- src/streaming.rs | 166 ++++++++++++++++++++++++++++++++++++++++++ src/utils.rs | 15 ++++ 3 files changed, 189 insertions(+), 178 deletions(-) create mode 100644 src/streaming.rs create mode 100644 src/utils.rs diff --git a/src/lib.rs b/src/lib.rs index 0f0375d..de4ab51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,16 +1,10 @@ -pub mod config; +mod config; +mod streaming; +mod utils; use crate::config::{Config, LambdaInvokeMode}; use aws_config::BehaviorVersion; -use aws_sdk_lambda::{ - operation::{invoke::InvokeOutput, invoke_with_response_stream::InvokeWithResponseStreamOutput}, - types::{ - InvokeResponseStreamUpdate, - InvokeWithResponseStreamResponseEvent::{InvokeComplete, PayloadChunk}, - ResponseStreamingInvocationType, - }, - Client, -}; +use aws_sdk_lambda::{operation::invoke::InvokeOutput, types::ResponseStreamingInvocationType, Client}; use aws_smithy_types::Blob; use axum::{ body::{Body, Bytes}, @@ -22,13 +16,13 @@ use axum::{ }; use base64::{prelude::BASE64_STANDARD, Engine}; use config::AuthMode; -use futures_util::stream::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc}; -use tokio::{net::TcpListener, sync::mpsc}; -use tokio_stream::wrappers::ReceiverStream; +use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use streaming::handle_streaming_response; +use tokio::net::TcpListener; use tower_http::trace::TraceLayer; +use utils::handle_err; #[derive(Clone)] pub struct ApplicationState { @@ -62,21 +56,6 @@ async fn health() -> impl IntoResponse { StatusCode::OK } -macro_rules! handle_err { - ($name:expr, $result:expr) => {{ - match $result { - Ok(v) => v, - Err(e) => { - tracing::error!("{}: {:?}", $name, e); - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap(); - } - } - }}; -} - async fn handler( path: Option>, Query(query_string_parameters): Query>, @@ -198,19 +177,6 @@ struct LambdaResponse { body: String, } -#[derive(Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct MetadataPrelude { - #[serde(with = "http_serde::status_code")] - /// The HTTP status code. - pub status_code: StatusCode, - #[serde(with = "http_serde::header_map")] - /// The HTTP headers. - pub headers: HeaderMap, - /// The HTTP cookies. - pub cookies: Vec, -} - async fn handle_buffered_response(resp: InvokeOutput) -> Response { // Parse the InvokeOutput payload to extract the LambdaResponse let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); @@ -239,141 +205,5 @@ async fn handle_buffered_response(resp: InvokeOutput) -> Response { handle_err!("Building response", resp_builder.body(Body::from(body))) } -async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { - let (tx, rx) = mpsc::channel(1); - let mut metadata_buffer = Vec::new(); - let mut metadata_prelude: Option = None; - let mut remaining_data = Vec::new(); - - // Step 1: Detect if metadata exists and get the first chunk - let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; - - // Step 2: Process the first chunk - if let Some(chunk) = first_chunk { - if has_metadata { - metadata_buffer.extend_from_slice(&chunk); - (metadata_prelude, remaining_data) = collect_metadata(&mut resp, &mut metadata_buffer).await; - } else { - // No metadata prelude, treat first chunk as payload - remaining_data = chunk; - } - } - - // Spawn task to handle remaining stream - tokio::spawn(async move { - // Send remaining data after metadata first - if !remaining_data.is_empty() { - let stream_update = InvokeResponseStreamUpdate::builder() - .payload(Blob::new(remaining_data)) - .build(); - let _ = tx.send(PayloadChunk(stream_update)).await; - } - - while let Ok(Some(event)) = resp.event_stream.recv().await { - tx.send(event).await.ok(); - } - }); - - let stream = ReceiverStream::new(rx).map(|event| match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - Ok::<_, Infallible>(Bytes::from(bytes)) - } else { - Ok(Bytes::default()) - } - } - InvokeComplete(_) => Ok(Bytes::default()), - _ => { - tracing::warn!("Unhandled event type: {:?}", event); - Ok(Bytes::default()) - } - }); - - let mut resp_builder = Response::builder(); - - if let Some(metadata_prelude) = metadata_prelude { - resp_builder = resp_builder.status(metadata_prelude.status_code); - - for (k, v) in metadata_prelude.headers.iter() { - if k != "content-length" { - resp_builder = resp_builder.header(k, v); - } - } - - for cookie in &metadata_prelude.cookies { - resp_builder = resp_builder.header("set-cookie", cookie); - } - } else { - // Default response if no metadata - resp_builder = resp_builder.status(StatusCode::OK); - resp_builder = resp_builder.header("content-type", "application/octet-stream"); - } - - handle_err!("Building response", resp_builder.body(Body::from_stream(stream))) -} - -async fn detect_metadata(resp: &mut InvokeWithResponseStreamOutput) -> (bool, Option>) { - if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; - return (has_metadata, Some(bytes)); - } - } - (false, None) -} - -async fn collect_metadata( - resp: &mut InvokeWithResponseStreamOutput, - metadata_buffer: &mut Vec, -) -> (Option, Vec) { - let mut metadata_prelude = None; - let mut remaining_data = Vec::new(); - - // Process the metadata_buffer first - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - return (Some(p), remaining); - } - - // If metadata is not complete, continue processing the stream - while let Ok(Some(event)) = resp.event_stream.recv().await { - if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - metadata_buffer.extend_from_slice(&bytes); - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - metadata_prelude = Some(p); - remaining_data = remaining; - break; - } - } - } - } - (metadata_prelude, remaining_data) -} - -fn process_buffer(buffer: &[u8]) -> (Option, Vec) { - let mut null_count = 0; - for (i, &byte) in buffer.iter().enumerate() { - if byte == 0 { - null_count += 1; - if null_count == 8 { - let metadata_str = String::from_utf8_lossy(&buffer[..i]); - let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); - tracing::debug!(metadata_prelude=?metadata_prelude); - // Save remaining data after metadata - let remaining_data = buffer[i + 1..].to_vec(); - return (Some(metadata_prelude), remaining_data); - } - } else { - null_count = 0; - } - } - (None, Vec::new()) -} - #[cfg(test)] mod tests; diff --git a/src/streaming.rs b/src/streaming.rs new file mode 100644 index 0000000..0178d8d --- /dev/null +++ b/src/streaming.rs @@ -0,0 +1,166 @@ +use crate::utils::handle_err; +use aws_sdk_lambda::{ + operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, + types::{InvokeResponseStreamUpdate, InvokeWithResponseStreamResponseEvent}, +}; +use aws_smithy_types::Blob; +use axum::{ + body::Body, + http::{HeaderMap, StatusCode}, + response::Response, +}; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::convert::Infallible; +use tokio::sync::mpsc; +use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use InvokeWithResponseStreamResponseEvent::*; + +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct MetadataPrelude { + #[serde(with = "http_serde::status_code")] + /// The HTTP status code. + pub status_code: StatusCode, + #[serde(with = "http_serde::header_map")] + /// The HTTP headers. + pub headers: HeaderMap, + /// The HTTP cookies. + pub cookies: Vec, +} + +pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { + let (tx, rx) = mpsc::channel(1); + let mut metadata_buffer = Vec::new(); + let mut metadata_prelude: Option = None; + let mut remaining_data = Vec::new(); + + // Step 1: Detect if metadata exists and get the first chunk + let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; + + // Step 2: Process the first chunk + if let Some(chunk) = first_chunk { + if has_metadata { + metadata_buffer.extend_from_slice(&chunk); + (metadata_prelude, remaining_data) = collect_metadata(&mut resp, &mut metadata_buffer).await; + } else { + // No metadata prelude, treat first chunk as payload + remaining_data = chunk; + } + } + + // Spawn task to handle remaining stream + tokio::spawn(async move { + // Send remaining data after metadata first + if !remaining_data.is_empty() { + let stream_update = InvokeResponseStreamUpdate::builder() + .payload(Blob::new(remaining_data)) + .build(); + let _ = tx.send(PayloadChunk(stream_update)).await; + } + + while let Ok(Some(event)) = resp.event_stream.recv().await { + tx.send(event).await.ok(); + } + }); + + let stream = ReceiverStream::new(rx).map(|event| match event { + PayloadChunk(chunk) => { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + Ok::<_, Infallible>(Bytes::from(bytes)) + } else { + Ok(Bytes::default()) + } + } + InvokeComplete(_) => Ok(Bytes::default()), + _ => { + tracing::warn!("Unhandled event type: {:?}", event); + Ok(Bytes::default()) + } + }); + + let mut resp_builder = Response::builder(); + + if let Some(metadata_prelude) = metadata_prelude { + resp_builder = resp_builder.status(metadata_prelude.status_code); + + for (k, v) in metadata_prelude.headers.iter() { + if k != "content-length" { + resp_builder = resp_builder.header(k, v); + } + } + + for cookie in &metadata_prelude.cookies { + resp_builder = resp_builder.header("set-cookie", cookie); + } + } else { + // Default response if no metadata + resp_builder = resp_builder.status(StatusCode::OK); + resp_builder = resp_builder.header("content-type", "application/octet-stream"); + } + + handle_err!("Building response", resp_builder.body(Body::from_stream(stream))) +} + +async fn detect_metadata(resp: &mut InvokeWithResponseStreamOutput) -> (bool, Option>) { + if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; + return (has_metadata, Some(bytes)); + } + } + (false, None) +} + +async fn collect_metadata( + resp: &mut InvokeWithResponseStreamOutput, + metadata_buffer: &mut Vec, +) -> (Option, Vec) { + let mut metadata_prelude = None; + let mut remaining_data = Vec::new(); + + // Process the metadata_buffer first + let (prelude, remaining) = process_buffer(metadata_buffer); + if let Some(p) = prelude { + return (Some(p), remaining); + } + + // If metadata is not complete, continue processing the stream + while let Ok(Some(event)) = resp.event_stream.recv().await { + if let PayloadChunk(chunk) = event { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + metadata_buffer.extend_from_slice(&bytes); + let (prelude, remaining) = process_buffer(metadata_buffer); + if let Some(p) = prelude { + metadata_prelude = Some(p); + remaining_data = remaining; + break; + } + } + } + } + (metadata_prelude, remaining_data) +} + +fn process_buffer(buffer: &[u8]) -> (Option, Vec) { + let mut null_count = 0; + for (i, &byte) in buffer.iter().enumerate() { + if byte == 0 { + null_count += 1; + if null_count == 8 { + let metadata_str = String::from_utf8_lossy(&buffer[..i]); + let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); + tracing::debug!(metadata_prelude=?metadata_prelude); + // Save remaining data after metadata + let remaining_data = buffer[i + 1..].to_vec(); + return (Some(metadata_prelude), remaining_data); + } + } else { + null_count = 0; + } + } + (None, Vec::new()) +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..43beec8 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,15 @@ +macro_rules! handle_err { + ($name:expr, $result:expr) => {{ + match $result { + Ok(v) => v, + Err(e) => { + tracing::error!("{}: {:?}", $name, e); + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap(); + } + } + }}; +} +pub(crate) use handle_err; From 493f9420848f76a421ace3601085f884499e43dd Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Thu, 2 Jan 2025 08:54:52 +0000 Subject: [PATCH 13/31] chore: simplify code with aws_lambda_events crate --- Cargo.lock | 301 ++++++++++++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 1 + src/lib.rs | 99 ++++++++---------- 3 files changed, 338 insertions(+), 63 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8159ddb..cf49f68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +32,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.15" @@ -455,6 +476,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws_lambda_events" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ddb91585253ccc85be3f2e0d5635529efdeadaf8a1da3230b433d3bbe43648" +dependencies = [ + "base64 0.22.0", + "bytes", + "chrono", + "flate2", + "http 1.1.0", + "http-body 1.0.0", + "http-serde", + "query_map", + "serde", + "serde_dynamo", + "serde_json", + "serde_with", +] + [[package]] name = "axum" version = "0.7.5" @@ -520,7 +561,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.2", "object", "rustc-demangle", ] @@ -591,11 +632,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "bytes" version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +dependencies = [ + "serde", +] [[package]] name = "bytes-utils" @@ -633,6 +683,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets 0.52.6", +] + [[package]] name = "clang-sys" version = "1.7.0" @@ -743,6 +806,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.60", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.60", +] + [[package]] name = "deranged" version = "0.3.11" @@ -750,6 +848,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -797,6 +896,16 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -952,13 +1061,19 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap", + "indexmap 2.2.6", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.3" @@ -1154,6 +1269,35 @@ dependencies = [ "tokio", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1164,6 +1308,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -1171,7 +1326,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1210,6 +1366,16 @@ dependencies = [ "libc", ] +[[package]] +name = "js-sys" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lambda-web-gateway" version = "0.1.0" @@ -1217,6 +1383,7 @@ dependencies = [ "aws-config", "aws-sdk-lambda", "aws-smithy-types", + "aws_lambda_events", "axum", "base64 0.22.0", "bytes", @@ -1322,6 +1489,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "1.0.2" @@ -1509,6 +1685,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "query_map" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eab6b8b1074ef3359a863758dae650c7c0c6027927a085b7af911c8e0bf3a15" +dependencies = [ + "form_urlencoded", + "serde", + "serde_derive", +] + [[package]] name = "quote" version = "1.0.36" @@ -1788,6 +1975,16 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "serde_dynamo" +version = "4.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e36c1b1792cfd9de29eb373ee6a4b74650369c096f55db7198ceb7b8921d1f7f" +dependencies = [ + "base64 0.21.7", + "serde", +] + [[package]] name = "serde_json" version = "1.0.116" @@ -1821,13 +2018,43 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.0", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.6", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap", + "indexmap 2.2.6", "itoa 1.0.11", "ryu", "serde", @@ -1998,6 +2225,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", + "itoa 1.0.11", "num-conv", "powerfmt", "serde", @@ -2329,6 +2557,60 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn 2.0.60", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" + [[package]] name = "which" version = "4.4.2" @@ -2363,6 +2645,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index c5ebae2..e776b6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ tracing = { version = "0.1.40" } tokio-stream = "0.1.15" futures-util = "0.3.30" http-serde = "2.1.1" +aws_lambda_events = "0.16.0" [dev-dependencies] tempfile = "3.8.1" diff --git a/src/lib.rs b/src/lib.rs index de4ab51..4fd423e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,10 @@ mod utils; use crate::config::{Config, LambdaInvokeMode}; use aws_config::BehaviorVersion; +use aws_lambda_events::{ + alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext, AlbTargetGroupResponse, ElbContext}, + query_map::QueryMap, +}; use aws_sdk_lambda::{operation::invoke::InvokeOutput, types::ResponseStreamingInvocationType, Client}; use aws_smithy_types::Blob; use axum::{ @@ -16,9 +20,7 @@ use axum::{ }; use base64::{prelude::BASE64_STANDARD, Engine}; use config::AuthMode; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::{collections::HashMap, net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc}; use streaming::handle_streaming_response; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; @@ -58,9 +60,9 @@ async fn health() -> impl IntoResponse { async fn handler( path: Option>, - Query(query_string_parameters): Query>, + Query(query_string_parameters): Query, State(state): State, - method: Method, + http_method: Method, headers: HeaderMap, body: Bytes, ) -> Response { @@ -68,8 +70,6 @@ async fn handler( let config = &state.config; let path = "/".to_string() + path.map(|p| p.0).unwrap_or_default().as_str(); - let http_method = method.to_string(); - let content_type = headers .get("content-type") .and_then(|v| v.to_str().ok()) @@ -111,20 +111,24 @@ async fn handler( } } - let lambda_request_body = json!({ - "httpMethod": http_method, - "headers": to_string_map(&headers), - "path": path, - "queryStringParameters": query_string_parameters, - "isBase64Encoded": is_base64_encoded, - "body": body, - "requestContext": { - "elb": { - "targetGroupArn": "", + let lambda_request_body = handle_err!( + "Building lambda request", + serde_json::to_string(&AlbTargetGroupRequest { + http_method, + headers, + path: path.into(), + query_string_parameters, + body: body.into(), + is_base64_encoded, + request_context: AlbTargetGroupRequestContext { + elb: ElbContext { target_group_arn: None }, }, - }, - }) - .to_string(); + // TODO: remove these? + // https://github.com/awslabs/aws-lambda-rust-runtime/issues/953 + multi_value_headers: Default::default(), + multi_value_query_string_parameters: Default::default(), + }) + ); match config.lambda_invoke_mode { LambdaInvokeMode::Buffered => { @@ -155,53 +159,32 @@ async fn handler( } } -fn to_string_map(headers: &HeaderMap) -> HashMap { - headers - .iter() - .map(|(k, v)| { - ( - k.as_str().to_owned(), - String::from_utf8_lossy(v.as_bytes()).into_owned(), - ) - }) - .collect() -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "camelCase")] -struct LambdaResponse { - status_code: u16, - status_description: Option, - is_base64_encoded: Option, - headers: Option>, - body: String, -} - async fn handle_buffered_response(resp: InvokeOutput) -> Response { // Parse the InvokeOutput payload to extract the LambdaResponse let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); let lambda_response = handle_err!( "Deserializing lambda response", - serde_json::from_slice::(payload) + serde_json::from_slice::(payload) ); // Build the response using the extracted information - let mut resp_builder = Response::builder().status(StatusCode::from_u16(lambda_response.status_code).unwrap()); - - if let Some(headers) = lambda_response.headers { - for (key, value) in headers { - resp_builder = resp_builder.header(key, value); - } + let mut resp_builder = Response::builder().status(handle_err!( + "Parse response status code", + StatusCode::from_u16(handle_err!( + "Parse response status code", + lambda_response.status_code.try_into() + )) + )); + + *handle_err!( + "Setting response headers", + resp_builder.headers_mut().ok_or("Errors in builder") + ) = lambda_response.headers; + + let mut body = lambda_response.body.map_or(vec![], |b| b.to_vec()); + if lambda_response.is_base64_encoded { + body = handle_err!("Decoding base64 body", BASE64_STANDARD.decode(body)); } - - let body = if lambda_response.is_base64_encoded.unwrap_or(false) { - handle_err!( - "Decode base64 lambda response body", - BASE64_STANDARD.decode(lambda_response.body) - ) - } else { - lambda_response.body.into_bytes() - }; handle_err!("Building response", resp_builder.body(Body::from(body))) } From 5616d4c7505864ec9c1120174461b9dc770971fa Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 03:09:34 +0000 Subject: [PATCH 14/31] chore: update comments --- src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4fd423e..9223cab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,8 +123,7 @@ async fn handler( request_context: AlbTargetGroupRequestContext { elb: ElbContext { target_group_arn: None }, }, - // TODO: remove these? - // https://github.com/awslabs/aws-lambda-rust-runtime/issues/953 + // TODO: support multi-value-header mode? multi_value_headers: Default::default(), multi_value_query_string_parameters: Default::default(), }) From 2900971d57ebbf99b0b8db81112cdf06d7e17d37 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 03:16:21 +0000 Subject: [PATCH 15/31] chore: remove unnecessary statements --- src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 9223cab..6af15ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,7 @@ use aws_lambda_events::{ alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext, AlbTargetGroupResponse, ElbContext}, query_map::QueryMap, }; -use aws_sdk_lambda::{operation::invoke::InvokeOutput, types::ResponseStreamingInvocationType, Client}; +use aws_sdk_lambda::{operation::invoke::InvokeOutput, Client}; use aws_smithy_types::Blob; use axum::{ body::{Body, Bytes}, @@ -148,7 +148,6 @@ async fn handler( client .invoke_with_response_stream() .function_name(config.lambda_function_name.as_str()) - .invocation_type(ResponseStreamingInvocationType::RequestResponse) .payload(Blob::new(lambda_request_body)) .send() .await From f8eb4ba48073b5564bfe89e1cc603ac10132da8c Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 06:10:26 +0000 Subject: [PATCH 16/31] chore: extract mods and functions --- src/auth.rs | 21 +++++++ src/buffered.rs | 34 +++++++++++ src/lib.rs | 150 ++++++++++-------------------------------------- src/request.rs | 27 +++++++++ src/utils.rs | 27 +++++++++ 5 files changed, 140 insertions(+), 119 deletions(-) create mode 100644 src/auth.rs create mode 100644 src/buffered.rs create mode 100644 src/request.rs diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..ea3e5aa --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,21 @@ +use crate::config::{AuthMode, Config}; +use axum::http::HeaderMap; + +pub(super) fn is_authorized(headers: &HeaderMap, config: &Config) -> bool { + match config.auth_mode { + AuthMode::Open => true, + AuthMode::ApiKey => { + let api_key = headers + .get("x-api-key") + .and_then(|v| v.to_str().ok()) + .or_else(|| { + headers + .get("authorization") + .and_then(|v| v.to_str().ok().and_then(|s| s.strip_prefix("Bearer "))) + }) + .unwrap_or_default(); + + config.api_keys.contains(api_key) + } + } +} diff --git a/src/buffered.rs b/src/buffered.rs new file mode 100644 index 0000000..7d05122 --- /dev/null +++ b/src/buffered.rs @@ -0,0 +1,34 @@ +use crate::utils::handle_err; +use aws_lambda_events::alb::AlbTargetGroupResponse; +use aws_sdk_lambda::operation::invoke::InvokeOutput; +use axum::{body::Body, http::StatusCode, response::Response}; +use base64::{prelude::BASE64_STANDARD, Engine}; + +pub(super) async fn handle_buffered_response(resp: InvokeOutput) -> Response { + // Parse the InvokeOutput payload to extract the LambdaResponse + let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); + let lambda_response = handle_err!( + "Deserializing lambda response", + serde_json::from_slice::(payload) + ); + + // Build the response using the extracted information + let mut resp_builder = Response::builder().status(handle_err!( + "Parse response status code", + StatusCode::from_u16(handle_err!( + "Parse response status code", + lambda_response.status_code.try_into() + )) + )); + + *handle_err!( + "Setting response headers", + resp_builder.headers_mut().ok_or("Errors in builder") + ) = lambda_response.headers; + + let mut body = lambda_response.body.map_or(vec![], |b| b.to_vec()); + if lambda_response.is_base64_encoded { + body = handle_err!("Decoding base64 body", BASE64_STANDARD.decode(body)); + } + handle_err!("Building response", resp_builder.body(Body::from(body))) +} diff --git a/src/lib.rs b/src/lib.rs index 6af15ff..84a3657 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,30 +1,31 @@ +mod auth; +mod buffered; mod config; +mod request; mod streaming; mod utils; use crate::config::{Config, LambdaInvokeMode}; +use auth::is_authorized; use aws_config::BehaviorVersion; -use aws_lambda_events::{ - alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext, AlbTargetGroupResponse, ElbContext}, - query_map::QueryMap, -}; -use aws_sdk_lambda::{operation::invoke::InvokeOutput, Client}; +use aws_lambda_events::query_map::QueryMap; +use aws_sdk_lambda::Client; use aws_smithy_types::Blob; use axum::{ body::{Body, Bytes}, - extract::{Path, Query, State}, - http::{HeaderMap, Method, StatusCode}, + extract::{Query, State}, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, routing::{any, get}, Router, }; -use base64::{prelude::BASE64_STANDARD, Engine}; -use config::AuthMode; +use buffered::handle_buffered_response; +use request::build_alb_request_body; use std::{net::SocketAddr, sync::Arc}; use streaming::handle_streaming_response; use tokio::net::TcpListener; use tower_http::trace::TraceLayer; -use utils::handle_err; +use utils::{handle_err, transform_body, whether_base64_encoded}; #[derive(Clone)] pub struct ApplicationState { @@ -59,131 +60,42 @@ async fn health() -> impl IntoResponse { } async fn handler( - path: Option>, - Query(query_string_parameters): Query, State(state): State, - http_method: Method, - headers: HeaderMap, + Query(query): Query, + parts: Parts, body: Bytes, ) -> Response { - let client = &state.client; - let config = &state.config; - let path = "/".to_string() + path.map(|p| p.0).unwrap_or_default().as_str(); - - let content_type = headers - .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or_default(); - - let is_base64_encoded = match content_type { - "application/json" => false, - "application/xml" => false, - "application/javascript" => false, - _ if content_type.starts_with("text/") => false, - _ => true, - }; - - let body = if is_base64_encoded { - BASE64_STANDARD.encode(body) - } else { - String::from_utf8_lossy(&body).to_string() - }; - - match config.auth_mode { - AuthMode::Open => {} - AuthMode::ApiKey => { - let api_key = headers - .get("x-api-key") - .and_then(|v| v.to_str().ok()) - .or_else(|| { - headers - .get("authorization") - .and_then(|v| v.to_str().ok().and_then(|s| s.strip_prefix("Bearer "))) - }) - .unwrap_or_default(); - - if !config.api_keys.contains(api_key) { - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::empty()) - .unwrap(); - } - } + if !is_authorized(&parts.headers, &state.config) { + return StatusCode::UNAUTHORIZED.into_response(); } + let is_base64_encoded = whether_base64_encoded(&parts.headers); + let body = transform_body(is_base64_encoded, body); + let lambda_request_body = handle_err!( "Building lambda request", - serde_json::to_string(&AlbTargetGroupRequest { - http_method, - headers, - path: path.into(), - query_string_parameters, - body: body.into(), - is_base64_encoded, - request_context: AlbTargetGroupRequestContext { - elb: ElbContext { target_group_arn: None }, - }, - // TODO: support multi-value-header mode? - multi_value_headers: Default::default(), - multi_value_query_string_parameters: Default::default(), - }) + build_alb_request_body(is_base64_encoded, query, parts, body) ); - match config.lambda_invoke_mode { - LambdaInvokeMode::Buffered => { - let resp = handle_err!( - "Invoking lambda", - client - .invoke() - .function_name(config.lambda_function_name.as_str()) - .payload(Blob::new(lambda_request_body)) - .send() - .await - ); - handle_buffered_response(resp).await - } - LambdaInvokeMode::ResponseStream => { - let resp = handle_err!( + macro_rules! call_lambda { + ($action:ident) => { + handle_err!( "Invoking lambda", - client - .invoke_with_response_stream() - .function_name(config.lambda_function_name.as_str()) + state + .client + .$action() + .function_name(state.config.lambda_function_name.as_str()) .payload(Blob::new(lambda_request_body)) .send() .await - ); - handle_streaming_response(resp).await - } + ) + }; } -} - -async fn handle_buffered_response(resp: InvokeOutput) -> Response { - // Parse the InvokeOutput payload to extract the LambdaResponse - let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); - let lambda_response = handle_err!( - "Deserializing lambda response", - serde_json::from_slice::(payload) - ); - - // Build the response using the extracted information - let mut resp_builder = Response::builder().status(handle_err!( - "Parse response status code", - StatusCode::from_u16(handle_err!( - "Parse response status code", - lambda_response.status_code.try_into() - )) - )); - - *handle_err!( - "Setting response headers", - resp_builder.headers_mut().ok_or("Errors in builder") - ) = lambda_response.headers; - let mut body = lambda_response.body.map_or(vec![], |b| b.to_vec()); - if lambda_response.is_base64_encoded { - body = handle_err!("Decoding base64 body", BASE64_STANDARD.decode(body)); + match state.config.lambda_invoke_mode { + LambdaInvokeMode::Buffered => handle_buffered_response(call_lambda!(invoke)).await, + LambdaInvokeMode::ResponseStream => handle_streaming_response(call_lambda!(invoke_with_response_stream)).await, } - handle_err!("Building response", resp_builder.body(Body::from(body))) } #[cfg(test)] diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 0000000..35675d1 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,27 @@ +use aws_lambda_events::{ + alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext, ElbContext}, + query_map::QueryMap, +}; +use axum::http::request::Parts; + +pub(super) fn build_alb_request_body( + is_base64_encoded: bool, + query_string_parameters: QueryMap, + parts: Parts, + body: String, +) -> Result { + serde_json::to_string(&AlbTargetGroupRequest { + http_method: parts.method, + headers: parts.headers, + path: parts.uri.path().to_string().into(), + query_string_parameters, + body: body.into(), + is_base64_encoded, + request_context: AlbTargetGroupRequestContext { + elb: ElbContext { target_group_arn: None }, + }, + // TODO: support multi-value-header mode? + multi_value_headers: Default::default(), + multi_value_query_string_parameters: Default::default(), + }) +} diff --git a/src/utils.rs b/src/utils.rs index 43beec8..5eb5842 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,7 @@ +use axum::http::HeaderMap; +use base64::{prelude::BASE64_STANDARD, Engine}; +use bytes::Bytes; + macro_rules! handle_err { ($name:expr, $result:expr) => {{ match $result { @@ -13,3 +17,26 @@ macro_rules! handle_err { }}; } pub(crate) use handle_err; + +pub(super) fn whether_base64_encoded(headers: &HeaderMap) -> bool { + let content_type = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or_default(); + + match content_type { + "application/json" => false, + "application/xml" => false, + "application/javascript" => false, + _ if content_type.starts_with("text/") => false, + _ => true, + } +} + +pub(super) fn transform_body(is_base64_encoded: bool, body: Bytes) -> String { + if is_base64_encoded { + BASE64_STANDARD.encode(body) + } else { + String::from_utf8_lossy(&body).to_string() + } +} From c5d5650844d859ac3eca31b9b8bed2d215ad9184 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 06:24:39 +0000 Subject: [PATCH 17/31] chore: inline `run_app` to `main` --- src/lib.rs | 40 +++++++--------------------------------- src/main.rs | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 84a3657..7335b74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,9 @@ mod request; mod streaming; mod utils; -use crate::config::{Config, LambdaInvokeMode}; +pub use config::*; + use auth::is_authorized; -use aws_config::BehaviorVersion; use aws_lambda_events::query_map::QueryMap; use aws_sdk_lambda::Client; use aws_smithy_types::Blob; @@ -16,50 +16,24 @@ use axum::{ extract::{Query, State}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, - routing::{any, get}, - Router, }; use buffered::handle_buffered_response; use request::build_alb_request_body; -use std::{net::SocketAddr, sync::Arc}; +use std::sync::Arc; use streaming::handle_streaming_response; -use tokio::net::TcpListener; -use tower_http::trace::TraceLayer; use utils::{handle_err, transform_body, whether_base64_encoded}; #[derive(Clone)] pub struct ApplicationState { - client: Client, - config: Arc, -} - -pub async fn run_app() { - tracing_subscriber::fmt::init(); - - let config = Arc::new(Config::load("config.yaml")); - let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; - let client = Client::new(&aws_config); - - let app_state = ApplicationState { client, config }; - let addr = app_state.config.addr.parse::().unwrap(); - - let app = Router::new() - .route("/healthz", get(health)) - .route("/", any(handler)) - .route("/*path", any(handler)) - .layer(TraceLayer::new_for_http()) - .with_state(app_state); - - let listener = TcpListener::bind(addr).await.unwrap(); - tracing::info!("Listening on {}", addr); - axum::serve(listener, app).await.unwrap(); + pub client: Client, + pub config: Arc, } -async fn health() -> impl IntoResponse { +pub async fn health() -> impl IntoResponse { StatusCode::OK } -async fn handler( +pub async fn invoke_lambda( State(state): State, Query(query): Query, parts: Parts, diff --git a/src/main.rs b/src/main.rs index 4233464..5c57aa2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,34 @@ -use lambda_web_gateway::run_app; +use aws_config::BehaviorVersion; +use aws_sdk_lambda::Client; +use axum::{ + routing::{any, get}, + Router, +}; +use lambda_web_gateway::{health, invoke_lambda, ApplicationState, Config}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::net::TcpListener; +use tower_http::trace::TraceLayer; #[tokio::main] async fn main() { - run_app().await; + tracing_subscriber::fmt::init(); + + let state = { + let config = Arc::new(Config::load("config.yaml")); + let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; + let client = Client::new(&aws_config); + ApplicationState { client, config } + }; + + let addr = state.config.addr.parse::().unwrap(); + let listener = TcpListener::bind(addr).await.unwrap(); + tracing::info!("Listening on {}", addr); + + let app = Router::new() + .route("/healthz", get(health)) + .route("/", any(invoke_lambda)) + .route("/*path", any(invoke_lambda)) + .layer(TraceLayer::new_for_http()) + .with_state(state); + axum::serve(listener, app).await.unwrap(); } From 0b2e52871206126c30ef6bb59f8e3257de1f0adc Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 06:49:28 +0000 Subject: [PATCH 18/31] chore: rename `whether_base64_encoded` to `whether_should_base64_encode` --- src/lib.rs | 8 ++++---- src/utils.rs | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 7335b74..889e226 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ use buffered::handle_buffered_response; use request::build_alb_request_body; use std::sync::Arc; use streaming::handle_streaming_response; -use utils::{handle_err, transform_body, whether_base64_encoded}; +use utils::{handle_err, transform_body, whether_should_base64_encode}; #[derive(Clone)] pub struct ApplicationState { @@ -43,12 +43,12 @@ pub async fn invoke_lambda( return StatusCode::UNAUTHORIZED.into_response(); } - let is_base64_encoded = whether_base64_encoded(&parts.headers); - let body = transform_body(is_base64_encoded, body); + let should_base64_encode = whether_should_base64_encode(&parts.headers); + let body = transform_body(should_base64_encode, body); let lambda_request_body = handle_err!( "Building lambda request", - build_alb_request_body(is_base64_encoded, query, parts, body) + build_alb_request_body(should_base64_encode, query, parts, body) ); macro_rules! call_lambda { diff --git a/src/utils.rs b/src/utils.rs index 5eb5842..87967ce 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -18,7 +18,7 @@ macro_rules! handle_err { } pub(crate) use handle_err; -pub(super) fn whether_base64_encoded(headers: &HeaderMap) -> bool { +pub(super) fn whether_should_base64_encode(headers: &HeaderMap) -> bool { let content_type = headers .get("content-type") .and_then(|v| v.to_str().ok()) @@ -33,8 +33,8 @@ pub(super) fn whether_base64_encoded(headers: &HeaderMap) -> bool { } } -pub(super) fn transform_body(is_base64_encoded: bool, body: Bytes) -> String { - if is_base64_encoded { +pub(super) fn transform_body(should_base64_encode: bool, body: Bytes) -> String { + if should_base64_encode { BASE64_STANDARD.encode(body) } else { String::from_utf8_lossy(&body).to_string() From a69fbb99193b51c35712580520c5bca7211c220b Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 07:31:02 +0000 Subject: [PATCH 19/31] fix: `handle_buffered_response` should be sync --- src/buffered.rs | 2 +- src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/buffered.rs b/src/buffered.rs index 7d05122..6057ade 100644 --- a/src/buffered.rs +++ b/src/buffered.rs @@ -4,7 +4,7 @@ use aws_sdk_lambda::operation::invoke::InvokeOutput; use axum::{body::Body, http::StatusCode, response::Response}; use base64::{prelude::BASE64_STANDARD, Engine}; -pub(super) async fn handle_buffered_response(resp: InvokeOutput) -> Response { +pub(super) fn handle_buffered_response(resp: InvokeOutput) -> Response { // Parse the InvokeOutput payload to extract the LambdaResponse let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); let lambda_response = handle_err!( diff --git a/src/lib.rs b/src/lib.rs index 889e226..94ae55d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,7 +67,7 @@ pub async fn invoke_lambda( } match state.config.lambda_invoke_mode { - LambdaInvokeMode::Buffered => handle_buffered_response(call_lambda!(invoke)).await, + LambdaInvokeMode::Buffered => handle_buffered_response(call_lambda!(invoke)), LambdaInvokeMode::ResponseStream => handle_streaming_response(call_lambda!(invoke_with_response_stream)).await, } } From 08e3aca73e3fcce25f61d1fa19dc64a963337722 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Fri, 3 Jan 2025 07:39:44 +0000 Subject: [PATCH 20/31] tests: add unit tests --- src/auth.rs | 57 +++++++++++++++++++++ src/buffered.rs | 57 +++++++++++++++++++++ src/lib.rs | 10 +++- src/request.rs | 44 ++++++++++++++++ src/tests.rs | 133 ------------------------------------------------ src/utils.rs | 33 ++++++++++++ 6 files changed, 200 insertions(+), 134 deletions(-) delete mode 100644 src/tests.rs diff --git a/src/auth.rs b/src/auth.rs index ea3e5aa..f0a2310 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -19,3 +19,60 @@ pub(super) fn is_authorized(headers: &HeaderMap, config: &Config) -> bool { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashSet; + + #[test] + fn test_open_auth() { + let headers = HeaderMap::new(); + let config = Config { + auth_mode: AuthMode::Open, + ..Default::default() + }; + assert!(is_authorized(&headers, &config)); + } + + #[test] + fn test_api_key_auth() { + let config = Config { + auth_mode: AuthMode::ApiKey, + api_keys: HashSet::from(["test".to_string()]), + ..Default::default() + }; + + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", "test".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("x-api-key", "invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + + headers.insert("authorization", "Bearer test".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("authorization", "Bearer invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + } + + #[test] + fn test_multi_api_keys() { + let config = Config { + auth_mode: AuthMode::ApiKey, + api_keys: HashSet::from(["test1".to_string(), "test2".to_string()]), + ..Default::default() + }; + + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", "test1".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("x-api-key", "test2".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("x-api-key", "invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + } +} diff --git a/src/buffered.rs b/src/buffered.rs index 6057ade..0df8526 100644 --- a/src/buffered.rs +++ b/src/buffered.rs @@ -32,3 +32,60 @@ pub(super) fn handle_buffered_response(resp: InvokeOutput) -> Response { } handle_err!("Building response", resp_builder.body(Body::from(body))) } + +#[cfg(test)] +mod tests { + use super::*; + use aws_smithy_types::Blob; + use axum::http::HeaderMap; + + #[tokio::test] + async fn test_handle_buffered_response() { + let lambda_response = AlbTargetGroupResponse { + status_code: 200, + status_description: None, + is_base64_encoded: false, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/plain".parse().unwrap()); + headers + }, + body: Some("Hello, world!".into()), + ..Default::default() + }; + let payload = serde_json::to_vec(&lambda_response).unwrap(); + let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build(); + + let response = handle_buffered_response(invoke_output); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!(body, "Hello, world!"); + } + + #[tokio::test] + async fn test_handle_buffered_response_base64() { + let lambda_response = AlbTargetGroupResponse { + status_code: 200, + status_description: None, + is_base64_encoded: true, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/plain".parse().unwrap()); + headers + }, + body: Some("SGVsbG8sIHdvcmxkIQ==".into()), + ..Default::default() + }; + let payload = serde_json::to_vec(&lambda_response).unwrap(); + let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build(); + + let response = handle_buffered_response(invoke_output); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!(body, "Hello, world!"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 94ae55d..e4b4d83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,4 +73,12 @@ pub async fn invoke_lambda( } #[cfg(test)] -mod tests; +mod tests { + use super::*; + + #[tokio::test] + async fn test_health() { + let response = health().await.into_response(); + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/src/request.rs b/src/request.rs index 35675d1..c8385f4 100644 --- a/src/request.rs +++ b/src/request.rs @@ -25,3 +25,47 @@ pub(super) fn build_alb_request_body( multi_value_query_string_parameters: Default::default(), }) } + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::{request::Builder, Method}; + use base64::{prelude::BASE64_STANDARD, Engine}; + use std::collections::HashMap; + + // TODO:update aws_lambda_events to make these tests pass + // https://github.com/awslabs/aws-lambda-rust-runtime/issues/954 + + #[test] + fn test_alb_body() { + let (parts, body) = Builder::new() + .method(Method::GET) + .uri("https://example.com/?k=v") + .header("key", "value") + .body("Hello, world!") + .unwrap() + .into_parts(); + let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); + + let expected = "{\"httpMethod\":\"GET\",\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"multiValueQueryStringParameters\":{},\"headers\":{\"key\":\"value\"},\"multiValueHeaders\":{},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}},\"isBase64Encoded\":false,\"body\":\"Hello, world!\"}"; + assert_eq!( + build_alb_request_body(false, query, parts, body.into()).unwrap(), + expected + ); + } + + #[test] + fn test_alb_body_base64() { + let (parts, body) = Builder::new() + .method(Method::GET) + .uri("https://example.com/?k=v") + .header("key", "value") + .body(BASE64_STANDARD.encode("Hello, world!")) + .unwrap() + .into_parts(); + let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); + + let expected = "{\"httpMethod\":\"GET\",\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"multiValueQueryStringParameters\":{},\"headers\":{\"key\":\"value\"},\"multiValueHeaders\":{},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}},\"isBase64Encoded\":true,\"body\":\"SGVsbG8sIHdvcmxkIQ==\"}"; + assert_eq!(build_alb_request_body(true, query, parts, body).unwrap(), expected); + } +} diff --git a/src/tests.rs b/src/tests.rs deleted file mode 100644 index ca070f0..0000000 --- a/src/tests.rs +++ /dev/null @@ -1,133 +0,0 @@ -use super::*; -// use axum::http::StatusCode; -// use aws_smithy_types::Blob; -// use std::collections::HashMap; -// use aws_sdk_lambda::types::InvokeWithResponseStreamResponseEvent; -// use aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput; -// use aws_sdk_lambda::primitives::event_stream::EventReceiver; -// use aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamError; - -#[tokio::test] -async fn test_health() { - let response = health().await.into_response(); - assert_eq!(response.status(), StatusCode::OK); -} - -#[tokio::test] -async fn test_to_string_map() { - let mut headers = HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse().unwrap()); - headers.insert("X-Custom-Header", "test-value".parse().unwrap()); - - let result = to_string_map(&headers); - - assert_eq!(result.len(), 2); - assert_eq!(result.get("content-type"), Some(&"application/json".to_string())); - assert_eq!(result.get("x-custom-header"), Some(&"test-value".to_string())); -} - -#[tokio::test] -async fn test_handle_buffered_response() { - let lambda_response = LambdaResponse { - status_code: 200, - status_description: Some("OK".to_string()), - is_base64_encoded: Some(false), - headers: Some(HashMap::from([("Content-Type".to_string(), "text/plain".to_string())])), - body: "Hello, World!".to_string(), - }; - - let payload = serde_json::to_vec(&lambda_response).unwrap(); - let invoke_output = aws_sdk_lambda::operation::invoke::InvokeOutput::builder() - .payload(Blob::new(payload)) - .status_code(200) - .build(); - - let response = handle_buffered_response(invoke_output).await; - - assert_eq!(response.status(), StatusCode::OK); - assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); - assert_eq!(body, "Hello, World!"); -} - -// #[tokio::test] -// async fn test_detect_metadata() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let full_payload = payload.as_bytes().to_vec(); -// let chunk = InvokeWithResponseStreamResponseEvent::PayloadChunk( -// aws_sdk_lambda::types::InvokeResponseStreamUpdate::builder() -// .payload(Blob::new(full_payload.clone())) -// .build(), -// ); - -// let event_receiver = EventReceiver { -// inner: vec![chunk], -// }; - -// let mut resp = InvokeWithResponseStreamOutput::builder() -// .event_stream(event_receiver) -// .build() -// .unwrap(); - -// // Type annotation to help the compiler -// let resp: InvokeWithResponseStreamOutput = resp; - -// let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; - -// assert!(has_metadata); -// assert_eq!(first_chunk.unwrap(), full_payload); -// } - -// #[tokio::test] -// async fn test_collect_metadata() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let null_padding = vec![0u8; 8]; -// let remaining_data = b"Remaining data"; - -// let mut full_payload = payload.as_bytes().to_vec(); -// full_payload.extend_from_slice(&null_padding); -// full_payload.extend_from_slice(remaining_data); - -// let chunk = InvokeWithResponseStreamResponseEvent::PayloadChunk( -// aws_sdk_lambda::types::InvokeResponseStreamUpdate::builder() -// .payload(Blob::new(full_payload)) -// .build(), -// ); - -// let event_receiver =EventReceiver { -// inner: vec![chunk], -// }; - -// let mut resp = InvokeWithResponseStreamOutput::builder() -// .event_stream(event_receiver) -// .build() -// .unwrap(); - -// let mut metadata_buffer = Vec::new(); -// let (metadata_prelude, remaining) = collect_metadata(&mut resp, &mut metadata_buffer).await; - -// assert!(metadata_prelude.is_some()); -// let prelude = metadata_prelude.unwrap(); -// assert_eq!(prelude.status_code, StatusCode::OK); -// assert_eq!(prelude.headers.get("content-type").unwrap(), "text/plain"); -// assert_eq!(remaining, remaining_data); -// } - -// #[tokio::test] -// async fn test_process_buffer() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let null_padding = vec![0u8; 8]; -// let remaining_data = b"Remaining data"; -// -// let mut buffer = payload.as_bytes().to_vec(); -// buffer.extend_from_slice(&null_padding); -// buffer.extend_from_slice(remaining_data); -// -// let (metadata_prelude, remaining) = process_buffer(&buffer); -// -// assert!(metadata_prelude.is_some()); -// let prelude = metadata_prelude.unwrap(); -// assert_eq!(prelude.status_code, StatusCode::OK); -// assert_eq!(prelude.headers.get("content-type").unwrap(), "text/plain"); -// assert_eq!(remaining, remaining_data); -// } diff --git a/src/utils.rs b/src/utils.rs index 87967ce..0355c22 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -40,3 +40,36 @@ pub(super) fn transform_body(should_base64_encode: bool, body: Bytes) -> String String::from_utf8_lossy(&body).to_string() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_whether_base64_encoded() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "application/xml".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "application/javascript".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "text/html".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "image/png".parse().unwrap()); + assert!(whether_should_base64_encode(&headers)); + } + + #[test] + fn test_transform_body() { + let body = Bytes::from("Hello, world!"); + assert_eq!(transform_body(false, body.clone()), "Hello, world!"); + + let base64_body = Bytes::from(BASE64_STANDARD.encode(&body)); + assert_eq!(transform_body(true, base64_body), body); + } +} From f22ee044da52375939016956ea71480e1b854488 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 03:43:26 +0000 Subject: [PATCH 21/31] chore: optimize code for streaming mod, prevent unnecessary clone --- src/streaming.rs | 142 ++++++++++++++++++++++++----------------------- 1 file changed, 74 insertions(+), 68 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 0178d8d..97cc645 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -16,14 +16,15 @@ use tokio::sync::mpsc; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use InvokeWithResponseStreamResponseEvent::*; +// TODO: contribute to `lambda_runtime` crate to make this struct derive Deserialize #[derive(Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct MetadataPrelude { - #[serde(with = "http_serde::status_code")] /// The HTTP status code. + #[serde(with = "http_serde::status_code")] pub status_code: StatusCode, - #[serde(with = "http_serde::header_map")] /// The HTTP headers. + #[serde(with = "http_serde::header_map")] pub headers: HeaderMap, /// The HTTP cookies. pub cookies: Vec, @@ -31,32 +32,37 @@ struct MetadataPrelude { pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { let (tx, rx) = mpsc::channel(1); - let mut metadata_buffer = Vec::new(); - let mut metadata_prelude: Option = None; - let mut remaining_data = Vec::new(); - - // Step 1: Detect if metadata exists and get the first chunk - let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; - - // Step 2: Process the first chunk - if let Some(chunk) = first_chunk { - if has_metadata { - metadata_buffer.extend_from_slice(&chunk); - (metadata_prelude, remaining_data) = collect_metadata(&mut resp, &mut metadata_buffer).await; + + let Some(PayloadChunk(InvokeResponseStreamUpdate { + payload: Some(first_chunk), + .. + })) = handle_err!("Receiving response stream", resp.event_stream.recv().await) + else { + // TODO: correct the response if there is no chunk + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap(); + }; + let mut buffer = first_chunk.into_inner(); + + // Detect and collect metadata prelude + let (metadata_prelude, buffer) = if detect_metadata(&buffer) { + if let Some((metadata_prelude, rest)) = collect_metadata(&mut resp, &mut buffer).await { + (Some(metadata_prelude), rest) } else { - // No metadata prelude, treat first chunk as payload - remaining_data = chunk; + (None, buffer) } - } + } else { + (None, buffer) + }; // Spawn task to handle remaining stream tokio::spawn(async move { // Send remaining data after metadata first - if !remaining_data.is_empty() { - let stream_update = InvokeResponseStreamUpdate::builder() - .payload(Blob::new(remaining_data)) - .build(); - let _ = tx.send(PayloadChunk(stream_update)).await; + if !buffer.is_empty() { + let stream_update = InvokeResponseStreamUpdate::builder().payload(Blob::new(buffer)).build(); + tx.send(PayloadChunk(stream_update)).await.ok(); } while let Ok(Some(event)) = resp.event_stream.recv().await { @@ -85,10 +91,10 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream if let Some(metadata_prelude) = metadata_prelude { resp_builder = resp_builder.status(metadata_prelude.status_code); - for (k, v) in metadata_prelude.headers.iter() { - if k != "content-length" { - resp_builder = resp_builder.header(k, v); - } + { + let headers = resp_builder.headers_mut().unwrap(); + *headers = metadata_prelude.headers; + headers.remove("content-length"); } for cookie in &metadata_prelude.cookies { @@ -103,64 +109,64 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream handle_err!("Building response", resp_builder.body(Body::from_stream(stream))) } -async fn detect_metadata(resp: &mut InvokeWithResponseStreamOutput) -> (bool, Option>) { - if let Ok(Some(PayloadChunk(chunk))) = resp.event_stream.recv().await { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; - return (has_metadata, Some(bytes)); - } - } - (false, None) +fn detect_metadata(bytes: &[u8]) -> bool { + bytes.get(0) == Some(&b'{') } +/// Return metadata prelude and remaining data if metadata is complete. +/// Return [`None`] if the stream is exhausted without complete metadata. async fn collect_metadata( resp: &mut InvokeWithResponseStreamOutput, metadata_buffer: &mut Vec, -) -> (Option, Vec) { - let mut metadata_prelude = None; - let mut remaining_data = Vec::new(); - +) -> Option<(MetadataPrelude, Vec)> { // Process the metadata_buffer first - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - return (Some(p), remaining); + if let Some((prelude, remaining)) = try_parse_metadata(metadata_buffer) { + return Some((prelude, remaining.into())); } // If metadata is not complete, continue processing the stream - while let Ok(Some(event)) = resp.event_stream.recv().await { - if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - metadata_buffer.extend_from_slice(&bytes); - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - metadata_prelude = Some(p); - remaining_data = remaining; - break; - } - } + // TODO: handle error + while let Ok(Some(PayloadChunk(InvokeResponseStreamUpdate { + payload: Some(data), .. + }))) = resp.event_stream.recv().await + { + let bytes = data.into_inner(); + metadata_buffer.extend_from_slice(&bytes); + if let Some((prelude, remaining)) = try_parse_metadata(metadata_buffer) { + return Some((prelude, remaining.into())); } } - (metadata_prelude, remaining_data) + None } -fn process_buffer(buffer: &[u8]) -> (Option, Vec) { +/// If metadata prelude is found, return the metadata prelude and the remaining data. +fn try_parse_metadata(buffer: &[u8]) -> Option<(MetadataPrelude, &[u8])> { let mut null_count = 0; for (i, &byte) in buffer.iter().enumerate() { - if byte == 0 { - null_count += 1; - if null_count == 8 { - let metadata_str = String::from_utf8_lossy(&buffer[..i]); - let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); - tracing::debug!(metadata_prelude=?metadata_prelude); - // Save remaining data after metadata - let remaining_data = buffer[i + 1..].to_vec(); - return (Some(metadata_prelude), remaining_data); - } - } else { + if byte != 0 { null_count = 0; + continue; } + null_count += 1; + if null_count == 8 { + // now we have 8 continuous null bytes + let metadata_str = String::from_utf8_lossy(&buffer[..i - 7]); + // TODO: handle invalid metadata prelude + let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); + tracing::debug!(metadata_prelude=?metadata_prelude); + return Some((metadata_prelude, &buffer[i + 1..])); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_metadata() { + assert!(detect_metadata(b"{\"statusCode\":200,\"headers\":{}}")); + assert!(!detect_metadata(b"Hello, world!")); } - (None, Vec::new()) } From 1b589848e2104e67c00ef78651ac078a8036aace Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 03:52:29 +0000 Subject: [PATCH 22/31] chore: optimize stream handling --- src/streaming.rs | 69 +++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 97cc645..78b7767 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -3,7 +3,6 @@ use aws_sdk_lambda::{ operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, types::{InvokeResponseStreamUpdate, InvokeWithResponseStreamResponseEvent}, }; -use aws_smithy_types::Blob; use axum::{ body::Body, http::{HeaderMap, StatusCode}, @@ -31,8 +30,6 @@ struct MetadataPrelude { } pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { - let (tx, rx) = mpsc::channel(1); - let Some(PayloadChunk(InvokeResponseStreamUpdate { payload: Some(first_chunk), .. @@ -57,35 +54,6 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream (None, buffer) }; - // Spawn task to handle remaining stream - tokio::spawn(async move { - // Send remaining data after metadata first - if !buffer.is_empty() { - let stream_update = InvokeResponseStreamUpdate::builder().payload(Blob::new(buffer)).build(); - tx.send(PayloadChunk(stream_update)).await.ok(); - } - - while let Ok(Some(event)) = resp.event_stream.recv().await { - tx.send(event).await.ok(); - } - }); - - let stream = ReceiverStream::new(rx).map(|event| match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload { - let bytes = data.into_inner(); - Ok::<_, Infallible>(Bytes::from(bytes)) - } else { - Ok(Bytes::default()) - } - } - InvokeComplete(_) => Ok(Bytes::default()), - _ => { - tracing::warn!("Unhandled event type: {:?}", event); - Ok(Bytes::default()) - } - }); - let mut resp_builder = Response::builder(); if let Some(metadata_prelude) = metadata_prelude { @@ -106,7 +74,42 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream resp_builder = resp_builder.header("content-type", "application/octet-stream"); } - handle_err!("Building response", resp_builder.body(Body::from_stream(stream))) + // Spawn task to handle remaining stream + let (tx, rx) = mpsc::channel(1); + tokio::spawn(async move { + // Send remaining data after metadata first + if !buffer.is_empty() { + tx.send(buffer).await.ok(); + } + + // TODO: handle error + while let Ok(Some(event)) = resp.event_stream.recv().await { + match event { + PayloadChunk(chunk) => { + if let Some(data) = chunk.payload { + tx.send(data.into_inner()).await.ok(); + } + // else, no data in the chunk, just ignore + } + InvokeComplete(_) => { + break; + } + _ => { + tracing::warn!("Unhandled event type: {:?}", event); + } + } + } + }); + + handle_err!( + "Building response", + resp_builder.body(Body::from_stream(ReceiverStream::new(rx).map(|bytes| Ok::< + _, + Infallible, + >( + Bytes::from(bytes) + )))) + ) } fn detect_metadata(bytes: &[u8]) -> bool { From 515542a5078492b564b8216a9e7183b547052610 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 03:57:23 +0000 Subject: [PATCH 23/31] tests: add unit tests for `try_parse_metadata` --- src/streaming.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/streaming.rs b/src/streaming.rs index 78b7767..3a4dcf8 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -172,4 +172,18 @@ mod tests { assert!(detect_metadata(b"{\"statusCode\":200,\"headers\":{}}")); assert!(!detect_metadata(b"Hello, world!")); } + + #[test] + fn test_try_parse_metadata() { + // incomplete + assert!(try_parse_metadata(b"{\"statusCod").is_none()); + assert!(try_parse_metadata(b"{\"statusCode\":200,\"headers\":{}}\0\0\0").is_none()); + + // complete + let (metadata_prelude, remaining) = + try_parse_metadata(b"{\"statusCode\":200,\"headers\":{}}\0\0\0\0\0\0\0\0Hello, world!").unwrap(); + assert_eq!(metadata_prelude.status_code, StatusCode::OK); + assert_eq!(metadata_prelude.headers.len(), 0); + assert_eq!(remaining, b"Hello, world!"); + } } From d5c1fcf6e88f797ea49cf5de2cd8c18b20aa617f Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 05:44:24 +0000 Subject: [PATCH 24/31] chore: refactor streaming to simplify code --- src/streaming.rs | 80 ++++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 51 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 3a4dcf8..4fe7cb5 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -30,42 +30,45 @@ struct MetadataPrelude { } pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { - let Some(PayloadChunk(InvokeResponseStreamUpdate { - payload: Some(first_chunk), - .. - })) = handle_err!("Receiving response stream", resp.event_stream.recv().await) - else { - // TODO: correct the response if there is no chunk - return Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::empty()) - .unwrap(); - }; - let mut buffer = first_chunk.into_inner(); - - // Detect and collect metadata prelude - let (metadata_prelude, buffer) = if detect_metadata(&buffer) { - if let Some((metadata_prelude, rest)) = collect_metadata(&mut resp, &mut buffer).await { - (Some(metadata_prelude), rest) - } else { - (None, buffer) + // collect metadata + let (metadata, buffer) = { + let mut buffer = vec![]; + loop { + let next = handle_err!("Receiving response stream", resp.event_stream.recv().await); + if let Some(PayloadChunk(InvokeResponseStreamUpdate { + payload: Some(data), .. + })) = next + { + buffer.extend_from_slice(&data.into_inner()); + + // actually this is only required for the first chunk + // but this is cheap, so we call it in the loop to simplify the flow + if !detect_metadata(&buffer) { + break (None, buffer); + } + + if let Some((prelude, remaining)) = try_parse_metadata(&mut buffer) { + break (Some(prelude), remaining.into()); + } + } else { + // no more chunks + break (None, buffer); + } } - } else { - (None, buffer) }; let mut resp_builder = Response::builder(); - if let Some(metadata_prelude) = metadata_prelude { - resp_builder = resp_builder.status(metadata_prelude.status_code); + if let Some(metadata) = metadata { + resp_builder = resp_builder.status(metadata.status_code); { let headers = resp_builder.headers_mut().unwrap(); - *headers = metadata_prelude.headers; + *headers = metadata.headers; headers.remove("content-length"); } - for cookie in &metadata_prelude.cookies { + for cookie in &metadata.cookies { resp_builder = resp_builder.header("set-cookie", cookie); } } else { @@ -112,36 +115,11 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream ) } +#[inline] fn detect_metadata(bytes: &[u8]) -> bool { bytes.get(0) == Some(&b'{') } -/// Return metadata prelude and remaining data if metadata is complete. -/// Return [`None`] if the stream is exhausted without complete metadata. -async fn collect_metadata( - resp: &mut InvokeWithResponseStreamOutput, - metadata_buffer: &mut Vec, -) -> Option<(MetadataPrelude, Vec)> { - // Process the metadata_buffer first - if let Some((prelude, remaining)) = try_parse_metadata(metadata_buffer) { - return Some((prelude, remaining.into())); - } - - // If metadata is not complete, continue processing the stream - // TODO: handle error - while let Ok(Some(PayloadChunk(InvokeResponseStreamUpdate { - payload: Some(data), .. - }))) = resp.event_stream.recv().await - { - let bytes = data.into_inner(); - metadata_buffer.extend_from_slice(&bytes); - if let Some((prelude, remaining)) = try_parse_metadata(metadata_buffer) { - return Some((prelude, remaining.into())); - } - } - None -} - /// If metadata prelude is found, return the metadata prelude and the remaining data. fn try_parse_metadata(buffer: &[u8]) -> Option<(MetadataPrelude, &[u8])> { let mut null_count = 0; From 34316587b0a1632ac38c441cb5a171ada4e4a672 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 05:49:48 +0000 Subject: [PATCH 25/31] chore: extract fn `create_response_builder` for streaming --- src/streaming.rs | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 4fe7cb5..b843695 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -5,7 +5,7 @@ use aws_sdk_lambda::{ }; use axum::{ body::Body, - http::{HeaderMap, StatusCode}, + http::{response::Builder, HeaderMap, StatusCode}, response::Response, }; use bytes::Bytes; @@ -57,25 +57,7 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream } }; - let mut resp_builder = Response::builder(); - - if let Some(metadata) = metadata { - resp_builder = resp_builder.status(metadata.status_code); - - { - let headers = resp_builder.headers_mut().unwrap(); - *headers = metadata.headers; - headers.remove("content-length"); - } - - for cookie in &metadata.cookies { - resp_builder = resp_builder.header("set-cookie", cookie); - } - } else { - // Default response if no metadata - resp_builder = resp_builder.status(StatusCode::OK); - resp_builder = resp_builder.header("content-type", "application/octet-stream"); - } + let resp_builder = create_response_builder(metadata); // Spawn task to handle remaining stream let (tx, rx) = mpsc::channel(1); @@ -141,6 +123,30 @@ fn try_parse_metadata(buffer: &[u8]) -> Option<(MetadataPrelude, &[u8])> { None } +fn create_response_builder(metadata: Option) -> Builder { + if let Some(metadata) = metadata { + let mut builder = Response::builder().status(metadata.status_code); + + // apply all headers except content-length + { + let headers = builder.headers_mut().unwrap(); + *headers = metadata.headers; + headers.remove("content-length"); + } + + for cookie in &metadata.cookies { + builder = builder.header("set-cookie", cookie); + } + + builder + } else { + // Default response if no metadata + Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/octet-stream") + } +} + #[cfg(test)] mod tests { use super::*; From 9a46da05a5d40a94b132d58db60b57bead1e0db6 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 05:54:18 +0000 Subject: [PATCH 26/31] tests: add unit tests for `create_response_builder` --- src/streaming.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/streaming.rs b/src/streaming.rs index b843695..21a0169 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -170,4 +170,36 @@ mod tests { assert_eq!(metadata_prelude.headers.len(), 0); assert_eq!(remaining, b"Hello, world!"); } + + #[test] + fn test_create_response_builder() { + let metadata = MetadataPrelude { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "text/plain".parse().unwrap()); + headers.insert("content-length", "0".parse().unwrap()); + headers + }, + cookies: vec!["cookie1".to_string(), "cookie2".to_string()], + }; + let builder = create_response_builder(Some(metadata)); + let response = builder.body(Body::empty()).unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(response.headers().get("content-type").unwrap(), "text/plain"); + assert_eq!(response.headers().get("content-length"), None); + assert_eq!( + response.headers().get_all("set-cookie").iter().collect::>(), + &["cookie1", "cookie2"] + ); + + let builder = create_response_builder(None); + let response = builder.body(Body::empty()).unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/octet-stream" + ); + assert_eq!(response.headers().get("content-length"), None); + } } From 4ddea29abc6626e7b6dd1458a1683a73586b22c5 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 6 Jan 2025 06:06:20 +0000 Subject: [PATCH 27/31] chore: better streaming error handling --- src/streaming.rs | 53 ++++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 21a0169..43cd914 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -10,7 +10,6 @@ use axum::{ }; use bytes::Bytes; use serde::{Deserialize, Serialize}; -use std::convert::Infallible; use tokio::sync::mpsc; use tokio_stream::{wrappers::ReceiverStream, StreamExt}; use InvokeWithResponseStreamResponseEvent::*; @@ -57,30 +56,43 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream } }; - let resp_builder = create_response_builder(metadata); + let builder = create_response_builder(metadata); // Spawn task to handle remaining stream let (tx, rx) = mpsc::channel(1); tokio::spawn(async move { // Send remaining data after metadata first if !buffer.is_empty() { - tx.send(buffer).await.ok(); + tx.send(Ok(buffer)).await.ok(); } - // TODO: handle error - while let Ok(Some(event)) = resp.event_stream.recv().await { - match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload { - tx.send(data.into_inner()).await.ok(); - } - // else, no data in the chunk, just ignore - } - InvokeComplete(_) => { - break; + loop { + match resp.event_stream.recv().await { + Err(e) => { + tx.send(Err(e)).await.ok(); } - _ => { - tracing::warn!("Unhandled event type: {:?}", event); + Ok(e) => { + if let Some(update) = e { + match update { + PayloadChunk(chunk) => { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + if !bytes.is_empty() { + tx.send(Ok(bytes)).await.ok(); + } + } + // else, no data in the chunk, just ignore + } + InvokeComplete(_) => { + break; + } + _ => { + tracing::warn!("Unhandled event type: {:?}", update); + } + } + } else { + break; + } } } } @@ -88,12 +100,9 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream handle_err!( "Building response", - resp_builder.body(Body::from_stream(ReceiverStream::new(rx).map(|bytes| Ok::< - _, - Infallible, - >( - Bytes::from(bytes) - )))) + builder.body(Body::from_stream( + ReceiverStream::new(rx).map(|res| res.map(|bytes| Bytes::from(bytes))) + )) ) } From 7b403b396c8aa9f08246cee85b87e403ea22ede2 Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Tue, 7 Jan 2025 12:02:26 +0000 Subject: [PATCH 28/31] tests: fix unit tests for auth --- src/auth.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/auth.rs b/src/auth.rs index f0a2310..1e398ff 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -47,12 +47,15 @@ mod tests { headers.insert("x-api-key", "test".parse().unwrap()); assert!(is_authorized(&headers, &config)); + let mut headers = HeaderMap::new(); headers.insert("x-api-key", "invalid".parse().unwrap()); assert!(!is_authorized(&headers, &config)); + let mut headers = HeaderMap::new(); headers.insert("authorization", "Bearer test".parse().unwrap()); assert!(is_authorized(&headers, &config)); + let mut headers = HeaderMap::new(); headers.insert("authorization", "Bearer invalid".parse().unwrap()); assert!(!is_authorized(&headers, &config)); } From d0c4a5fc1956dfbcb8035249effba2089ca8ed2e Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Tue, 7 Jan 2025 12:04:06 +0000 Subject: [PATCH 29/31] tests: fix unit tests for transform_body --- src/utils.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.rs b/src/utils.rs index 0355c22..b17cf31 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -70,6 +70,6 @@ mod tests { assert_eq!(transform_body(false, body.clone()), "Hello, world!"); let base64_body = Bytes::from(BASE64_STANDARD.encode(&body)); - assert_eq!(transform_body(true, base64_body), body); + assert_eq!(transform_body(true, body), base64_body); } } From 493a8b791845f0b2fd0c372f72110ae2bc77f25f Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 13 Jan 2025 05:54:29 +0000 Subject: [PATCH 30/31] fix: fix queryStringParameters in build_alb_request_body --- src/request.rs | 72 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/src/request.rs b/src/request.rs index c8385f4..51c3cf1 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,8 +1,6 @@ -use aws_lambda_events::{ - alb::{AlbTargetGroupRequest, AlbTargetGroupRequestContext, ElbContext}, - query_map::QueryMap, -}; -use axum::http::request::Parts; +use aws_lambda_events::query_map::QueryMap; +use axum::http::{request::Parts, HeaderMap}; +use std::collections::HashMap; pub(super) fn build_alb_request_body( is_base64_encoded: bool, @@ -10,20 +8,53 @@ pub(super) fn build_alb_request_body( parts: Parts, body: String, ) -> Result { - serde_json::to_string(&AlbTargetGroupRequest { - http_method: parts.method, - headers: parts.headers, - path: parts.uri.path().to_string().into(), - query_string_parameters, - body: body.into(), - is_base64_encoded, - request_context: AlbTargetGroupRequestContext { - elb: ElbContext { target_group_arn: None }, + Ok(serde_json::json!({ + "httpMethod": parts.method.to_string(), + "path": parts.uri.path(), + "queryStringParameters": query_map_to_hash_map(query_string_parameters), + "multiValueQueryStringParameters": {}, + "headers": header_map_to_hash_map(parts.headers), + "multiValueHeaders": {}, + "requestContext": { + "elb": { + "targetGroupArn": Option::::None + } }, - // TODO: support multi-value-header mode? - multi_value_headers: Default::default(), - multi_value_query_string_parameters: Default::default(), + "isBase64Encoded": is_base64_encoded, + "body": body, }) + .to_string()) + // serde_json::to_string(&AlbTargetGroupRequest { + // http_method: parts.method, + // headers: parts.headers, + // path: parts.uri.path().to_string().into(), + // query_string_parameters, + // body: body.into(), + // is_base64_encoded, + // request_context: AlbTargetGroupRequestContext { + // elb: ElbContext { target_group_arn: None }, + // }, + // // TODO: support multi-value-header mode? + // multi_value_headers: Default::default(), + // multi_value_query_string_parameters: Default::default(), + // }) +} + +// TODO: remove this after https://github.com/awslabs/aws-lambda-rust-runtime/pull/955 is merged +fn query_map_to_hash_map(map: QueryMap) -> HashMap { + map.iter() + .map(|(k, _)| { + let values = map.all(k).unwrap(); + (k.to_string(), values.iter().last().unwrap().to_string()) + }) + .collect() +} + +// TODO: remove this after https://github.com/awslabs/aws-lambda-rust-runtime/pull/955 is merged +fn header_map_to_hash_map(map: HeaderMap) -> HashMap { + map.iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) + .collect() } #[cfg(test)] @@ -33,9 +64,6 @@ mod tests { use base64::{prelude::BASE64_STANDARD, Engine}; use std::collections::HashMap; - // TODO:update aws_lambda_events to make these tests pass - // https://github.com/awslabs/aws-lambda-rust-runtime/issues/954 - #[test] fn test_alb_body() { let (parts, body) = Builder::new() @@ -47,7 +75,7 @@ mod tests { .into_parts(); let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); - let expected = "{\"httpMethod\":\"GET\",\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"multiValueQueryStringParameters\":{},\"headers\":{\"key\":\"value\"},\"multiValueHeaders\":{},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}},\"isBase64Encoded\":false,\"body\":\"Hello, world!\"}"; + let expected = "{\"body\":\"Hello, world!\",\"headers\":{\"key\":\"value\"},\"httpMethod\":\"GET\",\"isBase64Encoded\":false,\"multiValueHeaders\":{},\"multiValueQueryStringParameters\":{},\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}}}"; assert_eq!( build_alb_request_body(false, query, parts, body.into()).unwrap(), expected @@ -65,7 +93,7 @@ mod tests { .into_parts(); let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); - let expected = "{\"httpMethod\":\"GET\",\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"multiValueQueryStringParameters\":{},\"headers\":{\"key\":\"value\"},\"multiValueHeaders\":{},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}},\"isBase64Encoded\":true,\"body\":\"SGVsbG8sIHdvcmxkIQ==\"}"; + let expected = "{\"body\":\"SGVsbG8sIHdvcmxkIQ==\",\"headers\":{\"key\":\"value\"},\"httpMethod\":\"GET\",\"isBase64Encoded\":true,\"multiValueHeaders\":{},\"multiValueQueryStringParameters\":{},\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}}}"; assert_eq!(build_alb_request_body(true, query, parts, body).unwrap(), expected); } } From 297032e504827fd1d946850c9dea4b54a9f6b7fc Mon Sep 17 00:00:00 2001 From: DiscreteTom Date: Mon, 13 Jan 2025 05:57:33 +0000 Subject: [PATCH 31/31] chore: optimize code following clippy --- src/streaming.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/streaming.rs b/src/streaming.rs index 43cd914..f157d84 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -46,7 +46,7 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream break (None, buffer); } - if let Some((prelude, remaining)) = try_parse_metadata(&mut buffer) { + if let Some((prelude, remaining)) = try_parse_metadata(&buffer) { break (Some(prelude), remaining.into()); } } else { @@ -101,14 +101,14 @@ pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStream handle_err!( "Building response", builder.body(Body::from_stream( - ReceiverStream::new(rx).map(|res| res.map(|bytes| Bytes::from(bytes))) + ReceiverStream::new(rx).map(|res| res.map(Bytes::from)) )) ) } #[inline] fn detect_metadata(bytes: &[u8]) -> bool { - bytes.get(0) == Some(&b'{') + bytes.first() == Some(&b'{') } /// If metadata prelude is found, return the metadata prelude and the remaining data.