diff --git a/Cargo.lock b/Cargo.lock index 313fd47..cf49f68 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -17,6 +17,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +32,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.15" @@ -455,6 +476,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws_lambda_events" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ddb91585253ccc85be3f2e0d5635529efdeadaf8a1da3230b433d3bbe43648" +dependencies = [ + "base64 0.22.0", + "bytes", + "chrono", + "flate2", + "http 1.1.0", + "http-body 1.0.0", + "http-serde", + "query_map", + "serde", + "serde_dynamo", + "serde_json", + "serde_with", +] + [[package]] name = "axum" version = "0.7.5" @@ -520,7 +561,7 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.7.2", "object", "rustc-demangle", ] @@ -591,11 +632,20 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "bytes" version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +dependencies = [ + "serde", +] [[package]] name = "bytes-utils" @@ -633,6 +683,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "serde", + "windows-targets 0.52.6", +] + [[package]] name = "clang-sys" version = "1.7.0" @@ -743,6 +806,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.60", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.60", +] + [[package]] name = "deranged" version = "0.3.11" @@ -750,6 +848,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", + "serde", ] [[package]] @@ -797,6 +896,16 @@ version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide 0.8.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -952,13 +1061,19 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap", + "indexmap 2.2.6", "slab", "tokio", "tokio-util", "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.14.3" @@ -1154,6 +1269,35 @@ dependencies = [ "tokio", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1164,6 +1308,17 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + [[package]] name = "indexmap" version = "2.2.6" @@ -1171,7 +1326,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1210,6 +1366,16 @@ dependencies = [ "libc", ] +[[package]] +name = "js-sys" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lambda-web-gateway" version = "0.1.0" @@ -1217,6 +1383,7 @@ dependencies = [ "aws-config", "aws-sdk-lambda", "aws-smithy-types", + "aws_lambda_events", "axum", "base64 0.22.0", "bytes", @@ -1229,6 +1396,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "serial_test", "tempfile", "tokio", "tokio-stream", @@ -1321,6 +1489,15 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +dependencies = [ + "adler2", +] + [[package]] name = "mio" version = "1.0.2" @@ -1508,6 +1685,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "query_map" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eab6b8b1074ef3359a863758dae650c7c0c6027927a085b7af911c8e0bf3a15" +dependencies = [ + "form_urlencoded", + "serde", + "serde_derive", +] + [[package]] name = "quote" version = "1.0.36" @@ -1698,6 +1886,15 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" +[[package]] +name = "scc" +version = "2.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8d25269dd3a12467afe2e510f69fb0b46b698e5afb296b59f2145259deaf8e8" +dependencies = [ + "sdd", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1723,6 +1920,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sdd" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49c1eeaf4b6a87c7479688c6d52b9f1153cedd3c489300564f932b065c6eab95" + [[package]] name = "security-framework" version = "2.3.1" @@ -1772,6 +1975,16 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "serde_dynamo" +version = "4.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e36c1b1792cfd9de29eb373ee6a4b74650369c096f55db7198ceb7b8921d1f7f" +dependencies = [ + "base64 0.21.7", + "serde", +] + [[package]] name = "serde_json" version = "1.0.116" @@ -1805,19 +2018,74 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_with" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ad483d2ab0149d5a5ebcd9972a3852711e0153d863bf5a5d0391d28883c4a20" +dependencies = [ + "base64 0.22.0", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.6", + "serde", + "serde_derive", + "serde_json", + "serde_with_macros", + "time", +] + +[[package]] +name = "serde_with_macros" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65569b702f41443e8bc8bbb1c5779bd0450bbe723b56198980e80ec45780bce2" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap", + "indexmap 2.2.6", "itoa 1.0.11", "ryu", "serde", "unsafe-libyaml", ] +[[package]] +name = "serial_test" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b4b487fe2acf240a021cf57c6b2b4903b1e78ca0ecd862a71b71d2a51fed77d" +dependencies = [ + "futures", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82fe9db325bcef1fbcde82e078a5cc4efdf787e96b3b9cf45b50b529f2083d67" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1957,6 +2225,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", + "itoa 1.0.11", "num-conv", "powerfmt", "serde", @@ -2288,6 +2557,60 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +dependencies = [ + "bumpalo", + "log", + "proc-macro2", + "quote", + "syn 2.0.60", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.99" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" + [[package]] name = "which" version = "4.4.2" @@ -2322,6 +2645,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index cc1bdc4..e776b6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,20 +17,22 @@ serde_yaml = "0.9" clap = { version = "4.0", features = ["derive"] } serde_json = "1" url = "2.5.0" -axum ={ version = "0.7.5"} +axum = { version = "0.7.5" } aws-config = { version = "1.5.5" } aws-sdk-lambda = { version = "1.42.0" } -aws-smithy-types = { version="1.2.2", features = ["serde-serialize"] } +aws-smithy-types = { version = "1.2.2", features = ["serde-serialize"] } tokio = { version = "1.39.3", features = ["full"] } tower-http = { version = "0.5.2", features = ["trace"] } -tracing-subscriber = { version= "0.3.18", features = ["json"]} -tracing ={ version = "0.1.40"} +tracing-subscriber = { version = "0.3.18", features = ["json"] } +tracing = { version = "0.1.40" } tokio-stream = "0.1.15" futures-util = "0.3.30" http-serde = "2.1.1" +aws_lambda_events = "0.16.0" [dev-dependencies] tempfile = "3.8.1" +serial_test = "3" [[bin]] name = "lambda-web-gateway" diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..1e398ff --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,81 @@ +use crate::config::{AuthMode, Config}; +use axum::http::HeaderMap; + +pub(super) fn is_authorized(headers: &HeaderMap, config: &Config) -> bool { + match config.auth_mode { + AuthMode::Open => true, + AuthMode::ApiKey => { + let api_key = headers + .get("x-api-key") + .and_then(|v| v.to_str().ok()) + .or_else(|| { + headers + .get("authorization") + .and_then(|v| v.to_str().ok().and_then(|s| s.strip_prefix("Bearer "))) + }) + .unwrap_or_default(); + + config.api_keys.contains(api_key) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashSet; + + #[test] + fn test_open_auth() { + let headers = HeaderMap::new(); + let config = Config { + auth_mode: AuthMode::Open, + ..Default::default() + }; + assert!(is_authorized(&headers, &config)); + } + + #[test] + fn test_api_key_auth() { + let config = Config { + auth_mode: AuthMode::ApiKey, + api_keys: HashSet::from(["test".to_string()]), + ..Default::default() + }; + + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", "test".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", "invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + + let mut headers = HeaderMap::new(); + headers.insert("authorization", "Bearer test".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + let mut headers = HeaderMap::new(); + headers.insert("authorization", "Bearer invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + } + + #[test] + fn test_multi_api_keys() { + let config = Config { + auth_mode: AuthMode::ApiKey, + api_keys: HashSet::from(["test1".to_string(), "test2".to_string()]), + ..Default::default() + }; + + let mut headers = HeaderMap::new(); + headers.insert("x-api-key", "test1".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("x-api-key", "test2".parse().unwrap()); + assert!(is_authorized(&headers, &config)); + + headers.insert("x-api-key", "invalid".parse().unwrap()); + assert!(!is_authorized(&headers, &config)); + } +} diff --git a/src/buffered.rs b/src/buffered.rs new file mode 100644 index 0000000..0df8526 --- /dev/null +++ b/src/buffered.rs @@ -0,0 +1,91 @@ +use crate::utils::handle_err; +use aws_lambda_events::alb::AlbTargetGroupResponse; +use aws_sdk_lambda::operation::invoke::InvokeOutput; +use axum::{body::Body, http::StatusCode, response::Response}; +use base64::{prelude::BASE64_STANDARD, Engine}; + +pub(super) fn handle_buffered_response(resp: InvokeOutput) -> Response { + // Parse the InvokeOutput payload to extract the LambdaResponse + let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref()); + let lambda_response = handle_err!( + "Deserializing lambda response", + serde_json::from_slice::(payload) + ); + + // Build the response using the extracted information + let mut resp_builder = Response::builder().status(handle_err!( + "Parse response status code", + StatusCode::from_u16(handle_err!( + "Parse response status code", + lambda_response.status_code.try_into() + )) + )); + + *handle_err!( + "Setting response headers", + resp_builder.headers_mut().ok_or("Errors in builder") + ) = lambda_response.headers; + + let mut body = lambda_response.body.map_or(vec![], |b| b.to_vec()); + if lambda_response.is_base64_encoded { + body = handle_err!("Decoding base64 body", BASE64_STANDARD.decode(body)); + } + handle_err!("Building response", resp_builder.body(Body::from(body))) +} + +#[cfg(test)] +mod tests { + use super::*; + use aws_smithy_types::Blob; + use axum::http::HeaderMap; + + #[tokio::test] + async fn test_handle_buffered_response() { + let lambda_response = AlbTargetGroupResponse { + status_code: 200, + status_description: None, + is_base64_encoded: false, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/plain".parse().unwrap()); + headers + }, + body: Some("Hello, world!".into()), + ..Default::default() + }; + let payload = serde_json::to_vec(&lambda_response).unwrap(); + let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build(); + + let response = handle_buffered_response(invoke_output); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!(body, "Hello, world!"); + } + + #[tokio::test] + async fn test_handle_buffered_response_base64() { + let lambda_response = AlbTargetGroupResponse { + status_code: 200, + status_description: None, + is_base64_encoded: true, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", "text/plain".parse().unwrap()); + headers + }, + body: Some("SGVsbG8sIHdvcmxkIQ==".into()), + ..Default::default() + }; + let payload = serde_json::to_vec(&lambda_response).unwrap(); + let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build(); + + let response = handle_buffered_response(invoke_output); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain"); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + assert_eq!(body, "Hello, world!"); + } +} diff --git a/src/config.rs b/src/config.rs index 41f05c9..0e8382d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,17 +1,14 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashSet; -use std::str::FromStr; -use std::fs; -use std::path::Path; +use std::{collections::HashSet, env, fs, path::Path, str::FromStr}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Config { pub lambda_function_name: String, - #[serde(default = "default_lambda_invoke_mode")] + #[serde(default)] pub lambda_invoke_mode: LambdaInvokeMode, #[serde(default)] pub api_keys: HashSet, - #[serde(default = "default_auth_mode")] + #[serde(default)] pub auth_mode: AuthMode, #[serde(default = "default_addr")] pub addr: String, @@ -21,9 +18,9 @@ impl Default for Config { fn default() -> Self { Self { lambda_function_name: String::new(), - lambda_invoke_mode: default_lambda_invoke_mode(), + lambda_invoke_mode: Default::default(), api_keys: HashSet::new(), - auth_mode: default_auth_mode(), + auth_mode: Default::default(), addr: default_addr(), } } @@ -40,26 +37,26 @@ impl Config { } fn apply_env_overrides(&mut self) { - if let Ok(val) = std::env::var("LAMBDA_FUNCTION_NAME") { + if let Ok(val) = env::var("LAMBDA_FUNCTION_NAME") { self.lambda_function_name = val; } if self.lambda_function_name.is_empty() { panic!("No lambda_function_name provided. Please set it in the config file or LAMBDA_FUNCTION_NAME environment variable."); } - if let Ok(val) = std::env::var("LAMBDA_INVOKE_MODE") { + if let Ok(val) = env::var("LAMBDA_INVOKE_MODE") { if let Ok(mode) = val.parse() { self.lambda_invoke_mode = mode; } } - if let Ok(val) = std::env::var("API_KEYS") { + if let Ok(val) = env::var("API_KEYS") { self.api_keys = val.split(',').filter(|s| !s.is_empty()).map(String::from).collect(); } - if let Ok(val) = std::env::var("AUTH_MODE") { + if let Ok(val) = env::var("AUTH_MODE") { if let Ok(mode) = val.parse() { self.auth_mode = mode; } } - if let Ok(val) = std::env::var("ADDR") { + if let Ok(val) = env::var("ADDR") { self.addr = val; } } @@ -71,47 +68,24 @@ impl Config { } } -#[cfg(test)] -mod tests { - include!("config_tests.rs"); -} - -fn default_auth_mode() -> AuthMode { - AuthMode::Open -} - -fn default_lambda_invoke_mode() -> LambdaInvokeMode { - LambdaInvokeMode::Buffered -} - fn default_addr() -> String { "0.0.0.0:8000".to_string() } -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)] pub enum AuthMode { + #[default] Open, ApiKey, } -impl Default for AuthMode { - fn default() -> Self { - AuthMode::Open - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)] pub enum LambdaInvokeMode { + #[default] Buffered, ResponseStream, } -impl Default for LambdaInvokeMode { - fn default() -> Self { - LambdaInvokeMode::Buffered - } -} - impl FromStr for AuthMode { type Err = String; @@ -135,3 +109,6 @@ impl FromStr for LambdaInvokeMode { } } } + +#[cfg(test)] +mod tests; diff --git a/src/config_tests.rs b/src/config/tests.rs similarity index 84% rename from src/config_tests.rs rename to src/config/tests.rs index 02bfc6e..1f1a5f6 100644 --- a/src/config_tests.rs +++ b/src/config/tests.rs @@ -1,8 +1,9 @@ use super::*; +use serial_test::serial; use std::collections::HashSet; use std::env; -use tempfile::NamedTempFile; use std::io::Write; +use tempfile::NamedTempFile; #[test] fn test_auth_mode_from_str() { @@ -15,10 +16,22 @@ fn test_auth_mode_from_str() { #[test] fn test_lambda_invoke_mode_from_str() { - assert_eq!("buffered".parse::().unwrap(), LambdaInvokeMode::Buffered); - assert_eq!("responsestream".parse::().unwrap(), LambdaInvokeMode::ResponseStream); - assert_eq!("BUFFERED".parse::().unwrap(), LambdaInvokeMode::Buffered); - assert_eq!("RESPONSESTREAM".parse::().unwrap(), LambdaInvokeMode::ResponseStream); + assert_eq!( + "buffered".parse::().unwrap(), + LambdaInvokeMode::Buffered + ); + assert_eq!( + "responsestream".parse::().unwrap(), + LambdaInvokeMode::ResponseStream + ); + assert_eq!( + "BUFFERED".parse::().unwrap(), + LambdaInvokeMode::Buffered + ); + assert_eq!( + "RESPONSESTREAM".parse::().unwrap(), + LambdaInvokeMode::ResponseStream + ); assert!("invalid".parse::().is_err()); } @@ -33,6 +46,7 @@ fn test_config_default() { } #[test] +#[serial] #[should_panic(expected = "No lambda_function_name provided")] fn test_config_panic_on_empty_lambda_function_name() { let mut config = Config::default(); @@ -40,6 +54,7 @@ fn test_config_panic_on_empty_lambda_function_name() { } #[test] +#[serial] fn test_config_apply_env_overrides() { env::set_var("LAMBDA_FUNCTION_NAME", "test-function"); env::set_var("LAMBDA_INVOKE_MODE", "responsestream"); @@ -52,7 +67,13 @@ fn test_config_apply_env_overrides() { assert_eq!(config.lambda_function_name, "test-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["key1", "key2"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["key1", "key2"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "127.0.0.1:3000"); @@ -83,12 +104,19 @@ addr: 127.0.0.1:3000 assert_eq!(config.lambda_function_name, "test-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["key1", "key2"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["key1", "key2"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "127.0.0.1:3000"); } #[test] +#[serial] fn test_config_load_with_env_override() { let config_content = r#" lambda_function_name: file-function @@ -111,7 +139,13 @@ addr: 0.0.0.0:8000 assert_eq!(config.lambda_function_name, "env-function"); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); - assert_eq!(config.api_keys, vec!["file-key"].into_iter().map(String::from).collect::>()); + assert_eq!( + config.api_keys, + vec!["file-key"] + .into_iter() + .map(String::from) + .collect::>() + ); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.addr, "0.0.0.0:8000"); @@ -133,13 +167,14 @@ addr: 0.0.0.0:8000 } #[test] +#[serial] fn test_config_load_invalid_file() { env::set_var("LAMBDA_FUNCTION_NAME", "env-function"); env::set_var("AUTH_MODE", "apikey"); env::set_var("LAMBDA_INVOKE_MODE", "responsestream"); let config = Config::load("non_existent_file.yaml"); - + assert_eq!(config.lambda_function_name, "env-function"); assert_eq!(config.auth_mode, AuthMode::ApiKey); assert_eq!(config.lambda_invoke_mode, LambdaInvokeMode::ResponseStream); @@ -153,6 +188,7 @@ fn test_config_load_invalid_file() { } #[test] +#[serial] fn test_config_load_invalid_yaml() { let config_content = "invalid: yaml: content"; @@ -178,12 +214,13 @@ fn test_config_load_invalid_yaml() { } #[test] +#[serial] fn test_config_load_empty_api_keys() { env::set_var("API_KEYS", ""); env::set_var("LAMBDA_FUNCTION_NAME", "test-function"); // Add this line - + let config = Config::load("non_existent_file.yaml"); - + assert!(config.api_keys.is_empty()); env::remove_var("API_KEYS"); diff --git a/src/lib.rs b/src/lib.rs index fc907d4..e4b4d83 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,371 +1,84 @@ -pub mod config; +mod auth; +mod buffered; +mod config; +mod request; +mod streaming; +mod utils; -#[cfg(test)] -mod tests { - include!("lib_tests.rs"); -} +pub use config::*; -use crate::config::{Config, LambdaInvokeMode}; -use aws_config::BehaviorVersion; -use aws_sdk_lambda::types::InvokeWithResponseStreamResponseEvent::{InvokeComplete, PayloadChunk}; -use aws_sdk_lambda::types::{InvokeResponseStreamUpdate, ResponseStreamingInvocationType}; +use auth::is_authorized; +use aws_lambda_events::query_map::QueryMap; use aws_sdk_lambda::Client; use aws_smithy_types::Blob; -use axum::body::Body; use axum::{ - body::Bytes, - extract::{Path, Query, State}, - http::{HeaderMap, Method, StatusCode}, + body::{Body, Bytes}, + extract::{Query, State}, + http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, - routing::any, - routing::get, - Router, }; -use base64::Engine; -use futures_util::stream::StreamExt; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::collections::HashMap; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; -use tower_http::trace::TraceLayer; +use buffered::handle_buffered_response; +use request::build_alb_request_body; +use std::sync::Arc; +use streaming::handle_streaming_response; +use utils::{handle_err, transform_body, whether_should_base64_encode}; #[derive(Clone)] pub struct ApplicationState { - client: Client, - config: Config, + pub client: Client, + pub config: Arc, } -pub async fn run_app() { - tracing_subscriber::fmt::init(); - - let config = Config::load("config.yaml"); - let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; - let client = Client::new(&aws_config); - - let app_state = ApplicationState { client, config }; - - let app = Router::new() - .route("/healthz", get(health)) - .route("/", any(handler)) - .route("/*path", any(handler)) - .layer(TraceLayer::new_for_http()) - .with_state(app_state.clone()); - - let addr = &app_state.config.addr; - let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); - tracing::info!("Listening on {}", addr); - axum::serve(listener, app).await.unwrap(); -} - -async fn health() -> impl IntoResponse { +pub async fn health() -> impl IntoResponse { StatusCode::OK } -async fn handler( - path: Option>, - Query(query_string_parameters): Query>, +pub async fn invoke_lambda( State(state): State, - method: Method, - headers: HeaderMap, + Query(query): Query, + parts: Parts, body: Bytes, ) -> Response { - let client = &state.client; - let config = &state.config; - let path = "/".to_string() + path.map(|p| p.0).unwrap_or_default().as_str(); - - let http_method = method.to_string(); - - let content_type = headers - .get("content-type") - .and_then(|v| v.to_str().ok().map(Some).flatten()) - .unwrap_or_default(); - - let is_base64_encoded = match content_type { - "application/json" => false, - "application/xml" => false, - "application/javascript" => false, - _ if content_type.starts_with("text/") => false, - _ => true, - }; - - let body = if is_base64_encoded { - base64::engine::general_purpose::STANDARD.encode(body) - } else { - String::from_utf8_lossy(&body).to_string() - }; - - match config.auth_mode { - config::AuthMode::Open => {} - config::AuthMode::ApiKey => { - let api_key = headers - .get("x-api-key") - .and_then(|v| v.to_str().ok()) - .or_else(|| { - headers - .get("authorization") - .and_then(|v| v.to_str().ok().and_then(|s| s.strip_prefix("Bearer "))) - }) - .unwrap_or_default(); - - if !config.api_keys.contains(api_key) { - return Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::empty()) - .unwrap(); - } - } + if !is_authorized(&parts.headers, &state.config) { + return StatusCode::UNAUTHORIZED.into_response(); } - let lambda_request_body = json!({ - "httpMethod": http_method, - "headers": to_string_map(&headers), - "path": path, - "queryStringParameters": query_string_parameters, - "isBase64Encoded": is_base64_encoded, - "body": body, - "requestContext": { - "elb": { - "targetGroupArn": "", - }, - }, - }) - .to_string(); - - let resp = match config.lambda_invoke_mode { - LambdaInvokeMode::Buffered => { - let resp = client - .invoke() - .function_name(config.lambda_function_name.as_str()) - .payload(Blob::new(lambda_request_body)) - .send() - .await - .unwrap(); - handle_buffered_response(resp).await - } - LambdaInvokeMode::ResponseStream => { - let resp = client - .invoke_with_response_stream() - .function_name(config.lambda_function_name.as_str()) - .invocation_type(ResponseStreamingInvocationType::RequestResponse) - .payload(Blob::new(lambda_request_body)) - .send() - .await - .unwrap(); - handle_streaming_response(resp).await - } - }; - - resp -} - -fn to_string_map(headers: &HeaderMap) -> HashMap { - headers - .iter() - .map(|(k, v)| { - ( - k.as_str().to_owned(), - String::from_utf8_lossy(v.as_bytes()).into_owned(), + let should_base64_encode = whether_should_base64_encode(&parts.headers); + let body = transform_body(should_base64_encode, body); + + let lambda_request_body = handle_err!( + "Building lambda request", + build_alb_request_body(should_base64_encode, query, parts, body) + ); + + macro_rules! call_lambda { + ($action:ident) => { + handle_err!( + "Invoking lambda", + state + .client + .$action() + .function_name(state.config.lambda_function_name.as_str()) + .payload(Blob::new(lambda_request_body)) + .send() + .await ) - }) - .collect() -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(rename_all = "camelCase")] -struct LambdaResponse { - status_code: u16, - status_description: Option, - is_base64_encoded: Option, - headers: Option>, - body: String, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct MetadataPrelude { - #[serde(with = "http_serde::status_code")] - /// The HTTP status code. - pub status_code: StatusCode, - #[serde(with = "http_serde::header_map")] - /// The HTTP headers. - pub headers: HeaderMap, - /// The HTTP cookies. - pub cookies: Vec, -} - -async fn handle_buffered_response(resp: aws_sdk_lambda::operation::invoke::InvokeOutput) -> Response { - // Parse the InvokeOutput payload to extract the LambdaResponse - let payload = resp.payload().unwrap().as_ref().to_vec(); - let lambda_response: LambdaResponse = serde_json::from_slice(&payload).unwrap(); - - // Build the response using the extracted information - let mut resp_builder = Response::builder().status(StatusCode::from_u16(lambda_response.status_code).unwrap()); - - if let Some(headers) = lambda_response.headers { - for (key, value) in headers { - resp_builder = resp_builder.header(key, value); - } + }; } - let body = if lambda_response.is_base64_encoded.unwrap_or(false) { - base64::engine::general_purpose::STANDARD - .decode(lambda_response.body) - .unwrap() - } else { - lambda_response.body.into_bytes() - }; - resp_builder.body(Body::from(body)).unwrap() -} - -async fn handle_streaming_response( - mut resp: aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, -) -> Response { - let (tx, rx) = mpsc::channel(1); - let mut metadata_buffer = Vec::new(); - let mut metadata_prelude: Option = None; - let mut remaining_data = Vec::new(); - - // Step 1: Detect if metadata exists and get the first chunk - let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; - - // Step 2: Process the first chunk - if let Some(chunk) = first_chunk { - if has_metadata { - metadata_buffer.extend_from_slice(&chunk); - (metadata_prelude, remaining_data) = collect_metadata(&mut resp, &mut metadata_buffer).await; - } else { - // No metadata prelude, treat first chunk as payload - remaining_data = chunk; - } + match state.config.lambda_invoke_mode { + LambdaInvokeMode::Buffered => handle_buffered_response(call_lambda!(invoke)), + LambdaInvokeMode::ResponseStream => handle_streaming_response(call_lambda!(invoke_with_response_stream)).await, } - - // Spawn task to handle remaining stream - tokio::spawn(async move { - // Send remaining data after metadata first - if !remaining_data.is_empty() { - let stream_update = InvokeResponseStreamUpdate::builder() - .payload(Blob::new(remaining_data)) - .build(); - let _ = tx.send(PayloadChunk(stream_update)).await; - } - - while let Some(event) = resp.event_stream.recv().await.unwrap() { - match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload() { - let stream_update = InvokeResponseStreamUpdate::builder().payload(data.clone()).build(); - let _ = tx.send(PayloadChunk(stream_update)).await; - } - } - InvokeComplete(_) => { - let _ = tx.send(event).await; - } - _ => {} - } - } - }); - - let stream = ReceiverStream::new(rx).map(|event| { - match event { - PayloadChunk(chunk) => { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); - Ok::<_, std::convert::Infallible>(Bytes::from(bytes)) - } else { - Ok(Bytes::default()) - } - }, - InvokeComplete(_) => Ok(Bytes::default()), - _ => Ok(Bytes::default()), // Handle other event types - } - }); - - let mut resp_builder = Response::builder(); - - if let Some(metadata_prelude) = metadata_prelude { - resp_builder = resp_builder.status(metadata_prelude.status_code); - - for (k, v) in metadata_prelude.headers.iter() { - if k != "content-length" { - resp_builder = resp_builder.header(k, v); - } - } - - for cookie in &metadata_prelude.cookies { - resp_builder = resp_builder.header("set-cookie", cookie); - } - } else { - // Default response if no metadata - resp_builder = resp_builder.status(StatusCode::OK); - resp_builder = resp_builder.header("content-type", "application/octet-stream"); - } - - resp_builder.body(Body::from_stream(stream)).unwrap() -} - -async fn detect_metadata( - resp: &mut aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, -) -> (bool, Option>) { - if let Ok(Some(event)) = resp.event_stream.recv().await { - if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); - let has_metadata = !bytes.is_empty() && bytes[0] == b'{'; - return (has_metadata, Some(bytes)); - } - } - } - (false, None) } -async fn collect_metadata( - resp: &mut aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, - metadata_buffer: &mut Vec, -) -> (Option, Vec) { - let mut metadata_prelude = None; - let mut remaining_data = Vec::new(); - - // Process the metadata_buffer first - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - return (Some(p), remaining); - } - - // If metadata is not complete, continue processing the stream - while let Ok(Some(event)) = resp.event_stream.recv().await { - if let PayloadChunk(chunk) = event { - if let Some(data) = chunk.payload() { - let bytes = data.clone().into_inner(); - metadata_buffer.extend_from_slice(&bytes); - let (prelude, remaining) = process_buffer(metadata_buffer); - if let Some(p) = prelude { - metadata_prelude = Some(p); - remaining_data = remaining; - break; - } - } - } - } - (metadata_prelude, remaining_data) -} +#[cfg(test)] +mod tests { + use super::*; -fn process_buffer(buffer: &[u8]) -> (Option, Vec) { - let mut null_count = 0; - for (i, &byte) in buffer.iter().enumerate() { - if byte == 0 { - null_count += 1; - if null_count == 8 { - let metadata_str = String::from_utf8_lossy(&buffer[..i]); - let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); - tracing::debug!(metadata_prelude=?metadata_prelude); - // Save remaining data after metadata - let remaining_data = buffer[i + 1..].to_vec(); - return (Some(metadata_prelude), remaining_data); - } - } else { - null_count = 0; - } + #[tokio::test] + async fn test_health() { + let response = health().await.into_response(); + assert_eq!(response.status(), StatusCode::OK); } - (None, Vec::new()) } diff --git a/src/lib_tests.rs b/src/lib_tests.rs deleted file mode 100644 index b86ce9d..0000000 --- a/src/lib_tests.rs +++ /dev/null @@ -1,138 +0,0 @@ -use super::*; -// use axum::http::StatusCode; -// use aws_smithy_types::Blob; -// use std::collections::HashMap; -// use aws_sdk_lambda::types::InvokeWithResponseStreamResponseEvent; -// use aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamOutput; -// use aws_sdk_lambda::primitives::event_stream::EventReceiver; -// use aws_sdk_lambda::operation::invoke_with_response_stream::InvokeWithResponseStreamError; - -#[tokio::test] -async fn test_health() { - let response = health().await.into_response(); - assert_eq!(response.status(), StatusCode::OK); -} - -#[tokio::test] -async fn test_to_string_map() { - let mut headers = HeaderMap::new(); - headers.insert("Content-Type", "application/json".parse().unwrap()); - headers.insert("X-Custom-Header", "test-value".parse().unwrap()); - - let result = to_string_map(&headers); - - assert_eq!(result.len(), 2); - assert_eq!(result.get("content-type"), Some(&"application/json".to_string())); - assert_eq!(result.get("x-custom-header"), Some(&"test-value".to_string())); -} - -#[tokio::test] -async fn test_handle_buffered_response() { - let lambda_response = LambdaResponse { - status_code: 200, - status_description: Some("OK".to_string()), - is_base64_encoded: Some(false), - headers: Some(HashMap::from([ - ("Content-Type".to_string(), "text/plain".to_string()), - ])), - body: "Hello, World!".to_string(), - }; - - let payload = serde_json::to_vec(&lambda_response).unwrap(); - let invoke_output = aws_sdk_lambda::operation::invoke::InvokeOutput::builder() - .payload(Blob::new(payload)) - .status_code(200) - .build(); - - let response = handle_buffered_response(invoke_output).await; - - assert_eq!(response.status(), StatusCode::OK); - assert_eq!( - response.headers().get("Content-Type").unwrap(), - "text/plain" - ); - let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); - assert_eq!(body, "Hello, World!"); -} - -// #[tokio::test] -// async fn test_detect_metadata() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let full_payload = payload.as_bytes().to_vec(); -// let chunk = InvokeWithResponseStreamResponseEvent::PayloadChunk( -// aws_sdk_lambda::types::InvokeResponseStreamUpdate::builder() -// .payload(Blob::new(full_payload.clone())) -// .build(), -// ); - -// let event_receiver = EventReceiver { -// inner: vec![chunk], -// }; - -// let mut resp = InvokeWithResponseStreamOutput::builder() -// .event_stream(event_receiver) -// .build() -// .unwrap(); - -// // Type annotation to help the compiler -// let resp: InvokeWithResponseStreamOutput = resp; - -// let (has_metadata, first_chunk) = detect_metadata(&mut resp).await; - -// assert!(has_metadata); -// assert_eq!(first_chunk.unwrap(), full_payload); -// } - -// #[tokio::test] -// async fn test_collect_metadata() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let null_padding = vec![0u8; 8]; -// let remaining_data = b"Remaining data"; - -// let mut full_payload = payload.as_bytes().to_vec(); -// full_payload.extend_from_slice(&null_padding); -// full_payload.extend_from_slice(remaining_data); - -// let chunk = InvokeWithResponseStreamResponseEvent::PayloadChunk( -// aws_sdk_lambda::types::InvokeResponseStreamUpdate::builder() -// .payload(Blob::new(full_payload)) -// .build(), -// ); - -// let event_receiver =EventReceiver { -// inner: vec![chunk], -// }; - -// let mut resp = InvokeWithResponseStreamOutput::builder() -// .event_stream(event_receiver) -// .build() -// .unwrap(); - -// let mut metadata_buffer = Vec::new(); -// let (metadata_prelude, remaining) = collect_metadata(&mut resp, &mut metadata_buffer).await; - -// assert!(metadata_prelude.is_some()); -// let prelude = metadata_prelude.unwrap(); -// assert_eq!(prelude.status_code, StatusCode::OK); -// assert_eq!(prelude.headers.get("content-type").unwrap(), "text/plain"); -// assert_eq!(remaining, remaining_data); -// } - -// #[tokio::test] -// async fn test_process_buffer() { -// let payload = r#"{"statusCode": 200, "headers": {"Content-Type": "text/plain"}, "body": "Hello"}"#; -// let null_padding = vec![0u8; 8]; -// let remaining_data = b"Remaining data"; -// -// let mut buffer = payload.as_bytes().to_vec(); -// buffer.extend_from_slice(&null_padding); -// buffer.extend_from_slice(remaining_data); -// -// let (metadata_prelude, remaining) = process_buffer(&buffer); -// -// assert!(metadata_prelude.is_some()); -// let prelude = metadata_prelude.unwrap(); -// assert_eq!(prelude.status_code, StatusCode::OK); -// assert_eq!(prelude.headers.get("content-type").unwrap(), "text/plain"); -// assert_eq!(remaining, remaining_data); -// } diff --git a/src/main.rs b/src/main.rs index 4233464..5c57aa2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,34 @@ -use lambda_web_gateway::run_app; +use aws_config::BehaviorVersion; +use aws_sdk_lambda::Client; +use axum::{ + routing::{any, get}, + Router, +}; +use lambda_web_gateway::{health, invoke_lambda, ApplicationState, Config}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::net::TcpListener; +use tower_http::trace::TraceLayer; #[tokio::main] async fn main() { - run_app().await; + tracing_subscriber::fmt::init(); + + let state = { + let config = Arc::new(Config::load("config.yaml")); + let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await; + let client = Client::new(&aws_config); + ApplicationState { client, config } + }; + + let addr = state.config.addr.parse::().unwrap(); + let listener = TcpListener::bind(addr).await.unwrap(); + tracing::info!("Listening on {}", addr); + + let app = Router::new() + .route("/healthz", get(health)) + .route("/", any(invoke_lambda)) + .route("/*path", any(invoke_lambda)) + .layer(TraceLayer::new_for_http()) + .with_state(state); + axum::serve(listener, app).await.unwrap(); } diff --git a/src/request.rs b/src/request.rs new file mode 100644 index 0000000..51c3cf1 --- /dev/null +++ b/src/request.rs @@ -0,0 +1,99 @@ +use aws_lambda_events::query_map::QueryMap; +use axum::http::{request::Parts, HeaderMap}; +use std::collections::HashMap; + +pub(super) fn build_alb_request_body( + is_base64_encoded: bool, + query_string_parameters: QueryMap, + parts: Parts, + body: String, +) -> Result { + Ok(serde_json::json!({ + "httpMethod": parts.method.to_string(), + "path": parts.uri.path(), + "queryStringParameters": query_map_to_hash_map(query_string_parameters), + "multiValueQueryStringParameters": {}, + "headers": header_map_to_hash_map(parts.headers), + "multiValueHeaders": {}, + "requestContext": { + "elb": { + "targetGroupArn": Option::::None + } + }, + "isBase64Encoded": is_base64_encoded, + "body": body, + }) + .to_string()) + // serde_json::to_string(&AlbTargetGroupRequest { + // http_method: parts.method, + // headers: parts.headers, + // path: parts.uri.path().to_string().into(), + // query_string_parameters, + // body: body.into(), + // is_base64_encoded, + // request_context: AlbTargetGroupRequestContext { + // elb: ElbContext { target_group_arn: None }, + // }, + // // TODO: support multi-value-header mode? + // multi_value_headers: Default::default(), + // multi_value_query_string_parameters: Default::default(), + // }) +} + +// TODO: remove this after https://github.com/awslabs/aws-lambda-rust-runtime/pull/955 is merged +fn query_map_to_hash_map(map: QueryMap) -> HashMap { + map.iter() + .map(|(k, _)| { + let values = map.all(k).unwrap(); + (k.to_string(), values.iter().last().unwrap().to_string()) + }) + .collect() +} + +// TODO: remove this after https://github.com/awslabs/aws-lambda-rust-runtime/pull/955 is merged +fn header_map_to_hash_map(map: HeaderMap) -> HashMap { + map.iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::{request::Builder, Method}; + use base64::{prelude::BASE64_STANDARD, Engine}; + use std::collections::HashMap; + + #[test] + fn test_alb_body() { + let (parts, body) = Builder::new() + .method(Method::GET) + .uri("https://example.com/?k=v") + .header("key", "value") + .body("Hello, world!") + .unwrap() + .into_parts(); + let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); + + let expected = "{\"body\":\"Hello, world!\",\"headers\":{\"key\":\"value\"},\"httpMethod\":\"GET\",\"isBase64Encoded\":false,\"multiValueHeaders\":{},\"multiValueQueryStringParameters\":{},\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}}}"; + assert_eq!( + build_alb_request_body(false, query, parts, body.into()).unwrap(), + expected + ); + } + + #[test] + fn test_alb_body_base64() { + let (parts, body) = Builder::new() + .method(Method::GET) + .uri("https://example.com/?k=v") + .header("key", "value") + .body(BASE64_STANDARD.encode("Hello, world!")) + .unwrap() + .into_parts(); + let query = HashMap::from([("k".to_string(), "v".to_string())]).into(); + + let expected = "{\"body\":\"SGVsbG8sIHdvcmxkIQ==\",\"headers\":{\"key\":\"value\"},\"httpMethod\":\"GET\",\"isBase64Encoded\":true,\"multiValueHeaders\":{},\"multiValueQueryStringParameters\":{},\"path\":\"/\",\"queryStringParameters\":{\"k\":\"v\"},\"requestContext\":{\"elb\":{\"targetGroupArn\":null}}}"; + assert_eq!(build_alb_request_body(true, query, parts, body).unwrap(), expected); + } +} diff --git a/src/streaming.rs b/src/streaming.rs new file mode 100644 index 0000000..f157d84 --- /dev/null +++ b/src/streaming.rs @@ -0,0 +1,214 @@ +use crate::utils::handle_err; +use aws_sdk_lambda::{ + operation::invoke_with_response_stream::InvokeWithResponseStreamOutput, + types::{InvokeResponseStreamUpdate, InvokeWithResponseStreamResponseEvent}, +}; +use axum::{ + body::Body, + http::{response::Builder, HeaderMap, StatusCode}, + response::Response, +}; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; +use tokio_stream::{wrappers::ReceiverStream, StreamExt}; +use InvokeWithResponseStreamResponseEvent::*; + +// TODO: contribute to `lambda_runtime` crate to make this struct derive Deserialize +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct MetadataPrelude { + /// The HTTP status code. + #[serde(with = "http_serde::status_code")] + pub status_code: StatusCode, + /// The HTTP headers. + #[serde(with = "http_serde::header_map")] + pub headers: HeaderMap, + /// The HTTP cookies. + pub cookies: Vec, +} + +pub(super) async fn handle_streaming_response(mut resp: InvokeWithResponseStreamOutput) -> Response { + // collect metadata + let (metadata, buffer) = { + let mut buffer = vec![]; + loop { + let next = handle_err!("Receiving response stream", resp.event_stream.recv().await); + if let Some(PayloadChunk(InvokeResponseStreamUpdate { + payload: Some(data), .. + })) = next + { + buffer.extend_from_slice(&data.into_inner()); + + // actually this is only required for the first chunk + // but this is cheap, so we call it in the loop to simplify the flow + if !detect_metadata(&buffer) { + break (None, buffer); + } + + if let Some((prelude, remaining)) = try_parse_metadata(&buffer) { + break (Some(prelude), remaining.into()); + } + } else { + // no more chunks + break (None, buffer); + } + } + }; + + let builder = create_response_builder(metadata); + + // Spawn task to handle remaining stream + let (tx, rx) = mpsc::channel(1); + tokio::spawn(async move { + // Send remaining data after metadata first + if !buffer.is_empty() { + tx.send(Ok(buffer)).await.ok(); + } + + loop { + match resp.event_stream.recv().await { + Err(e) => { + tx.send(Err(e)).await.ok(); + } + Ok(e) => { + if let Some(update) = e { + match update { + PayloadChunk(chunk) => { + if let Some(data) = chunk.payload { + let bytes = data.into_inner(); + if !bytes.is_empty() { + tx.send(Ok(bytes)).await.ok(); + } + } + // else, no data in the chunk, just ignore + } + InvokeComplete(_) => { + break; + } + _ => { + tracing::warn!("Unhandled event type: {:?}", update); + } + } + } else { + break; + } + } + } + } + }); + + handle_err!( + "Building response", + builder.body(Body::from_stream( + ReceiverStream::new(rx).map(|res| res.map(Bytes::from)) + )) + ) +} + +#[inline] +fn detect_metadata(bytes: &[u8]) -> bool { + bytes.first() == Some(&b'{') +} + +/// If metadata prelude is found, return the metadata prelude and the remaining data. +fn try_parse_metadata(buffer: &[u8]) -> Option<(MetadataPrelude, &[u8])> { + let mut null_count = 0; + for (i, &byte) in buffer.iter().enumerate() { + if byte != 0 { + null_count = 0; + continue; + } + null_count += 1; + if null_count == 8 { + // now we have 8 continuous null bytes + let metadata_str = String::from_utf8_lossy(&buffer[..i - 7]); + // TODO: handle invalid metadata prelude + let metadata_prelude = serde_json::from_str(&metadata_str).unwrap_or_default(); + tracing::debug!(metadata_prelude=?metadata_prelude); + return Some((metadata_prelude, &buffer[i + 1..])); + } + } + None +} + +fn create_response_builder(metadata: Option) -> Builder { + if let Some(metadata) = metadata { + let mut builder = Response::builder().status(metadata.status_code); + + // apply all headers except content-length + { + let headers = builder.headers_mut().unwrap(); + *headers = metadata.headers; + headers.remove("content-length"); + } + + for cookie in &metadata.cookies { + builder = builder.header("set-cookie", cookie); + } + + builder + } else { + // Default response if no metadata + Response::builder() + .status(StatusCode::OK) + .header("content-type", "application/octet-stream") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_metadata() { + assert!(detect_metadata(b"{\"statusCode\":200,\"headers\":{}}")); + assert!(!detect_metadata(b"Hello, world!")); + } + + #[test] + fn test_try_parse_metadata() { + // incomplete + assert!(try_parse_metadata(b"{\"statusCod").is_none()); + assert!(try_parse_metadata(b"{\"statusCode\":200,\"headers\":{}}\0\0\0").is_none()); + + // complete + let (metadata_prelude, remaining) = + try_parse_metadata(b"{\"statusCode\":200,\"headers\":{}}\0\0\0\0\0\0\0\0Hello, world!").unwrap(); + assert_eq!(metadata_prelude.status_code, StatusCode::OK); + assert_eq!(metadata_prelude.headers.len(), 0); + assert_eq!(remaining, b"Hello, world!"); + } + + #[test] + fn test_create_response_builder() { + let metadata = MetadataPrelude { + status_code: StatusCode::INTERNAL_SERVER_ERROR, + headers: { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "text/plain".parse().unwrap()); + headers.insert("content-length", "0".parse().unwrap()); + headers + }, + cookies: vec!["cookie1".to_string(), "cookie2".to_string()], + }; + let builder = create_response_builder(Some(metadata)); + let response = builder.body(Body::empty()).unwrap(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(response.headers().get("content-type").unwrap(), "text/plain"); + assert_eq!(response.headers().get("content-length"), None); + assert_eq!( + response.headers().get_all("set-cookie").iter().collect::>(), + &["cookie1", "cookie2"] + ); + + let builder = create_response_builder(None); + let response = builder.body(Body::empty()).unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("content-type").unwrap(), + "application/octet-stream" + ); + assert_eq!(response.headers().get("content-length"), None); + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..b17cf31 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,75 @@ +use axum::http::HeaderMap; +use base64::{prelude::BASE64_STANDARD, Engine}; +use bytes::Bytes; + +macro_rules! handle_err { + ($name:expr, $result:expr) => {{ + match $result { + Ok(v) => v, + Err(e) => { + tracing::error!("{}: {:?}", $name, e); + return Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap(); + } + } + }}; +} +pub(crate) use handle_err; + +pub(super) fn whether_should_base64_encode(headers: &HeaderMap) -> bool { + let content_type = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or_default(); + + match content_type { + "application/json" => false, + "application/xml" => false, + "application/javascript" => false, + _ if content_type.starts_with("text/") => false, + _ => true, + } +} + +pub(super) fn transform_body(should_base64_encode: bool, body: Bytes) -> String { + if should_base64_encode { + BASE64_STANDARD.encode(body) + } else { + String::from_utf8_lossy(&body).to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_whether_base64_encoded() { + let mut headers = HeaderMap::new(); + headers.insert("content-type", "application/json".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "application/xml".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "application/javascript".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "text/html".parse().unwrap()); + assert!(!whether_should_base64_encode(&headers)); + + headers.insert("content-type", "image/png".parse().unwrap()); + assert!(whether_should_base64_encode(&headers)); + } + + #[test] + fn test_transform_body() { + let body = Bytes::from("Hello, world!"); + assert_eq!(transform_body(false, body.clone()), "Hello, world!"); + + let base64_body = Bytes::from(BASE64_STANDARD.encode(&body)); + assert_eq!(transform_body(true, body), base64_body); + } +}