Skip to content

Optimize code and performance. #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8148161
chore: Add serial_test dependency and update config tests to run seri…
DiscreteTom Oct 31, 2024
1c78815
chore: Simplify default values for Config struct and enums
DiscreteTom Oct 31, 2024
96a6368
chore: Format code
DiscreteTom Oct 31, 2024
35d8453
chore: simplify code following clippy
DiscreteTom Oct 31, 2024
8e560e2
chore: reorganize test module structure and prevent `include!` macro
DiscreteTom Oct 31, 2024
4a71321
perf: prevent unnecessary clone
DiscreteTom Oct 31, 2024
1f14c03
perf: use Arc for Config in ApplicationState for cheaper clone
DiscreteTom Oct 31, 2024
25cce5a
chore: add serial attribute to test_config_panic_on_empty_lambda_func…
DiscreteTom Oct 31, 2024
e779f11
chore: move tests to the end of files
DiscreteTom Jan 2, 2025
a87ec8b
chore: optimize import structure, simplify typing
DiscreteTom Jan 2, 2025
34860f4
perf: better performance and error handling
DiscreteTom Jan 2, 2025
55825dc
chore: split mods
DiscreteTom Jan 2, 2025
493f942
chore: simplify code with aws_lambda_events crate
DiscreteTom Jan 2, 2025
5616d4c
chore: update comments
DiscreteTom Jan 3, 2025
2900971
chore: remove unnecessary statements
DiscreteTom Jan 3, 2025
f8eb4ba
chore: extract mods and functions
DiscreteTom Jan 3, 2025
c5d5650
chore: inline `run_app` to `main`
DiscreteTom Jan 3, 2025
0b2e528
chore: rename `whether_base64_encoded` to `whether_should_base64_encode`
DiscreteTom Jan 3, 2025
a69fbb9
fix: `handle_buffered_response` should be sync
DiscreteTom Jan 3, 2025
08e3aca
tests: add unit tests
DiscreteTom Jan 3, 2025
f22ee04
chore: optimize code for streaming mod, prevent unnecessary clone
DiscreteTom Jan 6, 2025
1b58984
chore: optimize stream handling
DiscreteTom Jan 6, 2025
515542a
tests: add unit tests for `try_parse_metadata`
DiscreteTom Jan 6, 2025
d5c1fcf
chore: refactor streaming to simplify code
DiscreteTom Jan 6, 2025
3431658
chore: extract fn `create_response_builder` for streaming
DiscreteTom Jan 6, 2025
9a46da0
tests: add unit tests for `create_response_builder`
DiscreteTom Jan 6, 2025
4ddea29
chore: better streaming error handling
DiscreteTom Jan 6, 2025
7b403b3
tests: fix unit tests for auth
DiscreteTom Jan 7, 2025
d0c4a5f
tests: fix unit tests for transform_body
DiscreteTom Jan 7, 2025
493a8b7
fix: fix queryStringParameters in build_alb_request_body
DiscreteTom Jan 13, 2025
297032e
chore: optimize code following clippy
DiscreteTom Jan 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 337 additions & 5 deletions Cargo.lock

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@ serde_yaml = "0.9"
clap = { version = "4.0", features = ["derive"] }
serde_json = "1"
url = "2.5.0"
axum ={ version = "0.7.5"}
axum = { version = "0.7.5" }
aws-config = { version = "1.5.5" }
aws-sdk-lambda = { version = "1.42.0" }
aws-smithy-types = { version="1.2.2", features = ["serde-serialize"] }
aws-smithy-types = { version = "1.2.2", features = ["serde-serialize"] }
tokio = { version = "1.39.3", features = ["full"] }
tower-http = { version = "0.5.2", features = ["trace"] }
tracing-subscriber = { version= "0.3.18", features = ["json"]}
tracing ={ version = "0.1.40"}
tracing-subscriber = { version = "0.3.18", features = ["json"] }
tracing = { version = "0.1.40" }
tokio-stream = "0.1.15"
futures-util = "0.3.30"
http-serde = "2.1.1"
aws_lambda_events = "0.16.0"

[dev-dependencies]
tempfile = "3.8.1"
serial_test = "3"

[[bin]]
name = "lambda-web-gateway"
Expand Down
81 changes: 81 additions & 0 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::config::{AuthMode, Config};
use axum::http::HeaderMap;

pub(super) fn is_authorized(headers: &HeaderMap, config: &Config) -> bool {
match config.auth_mode {
AuthMode::Open => true,
AuthMode::ApiKey => {
let api_key = headers
.get("x-api-key")
.and_then(|v| v.to_str().ok())
.or_else(|| {
headers
.get("authorization")
.and_then(|v| v.to_str().ok().and_then(|s| s.strip_prefix("Bearer ")))
})
.unwrap_or_default();

config.api_keys.contains(api_key)
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;

#[test]
fn test_open_auth() {
let headers = HeaderMap::new();
let config = Config {
auth_mode: AuthMode::Open,
..Default::default()
};
assert!(is_authorized(&headers, &config));
}

#[test]
fn test_api_key_auth() {
let config = Config {
auth_mode: AuthMode::ApiKey,
api_keys: HashSet::from(["test".to_string()]),
..Default::default()
};

let mut headers = HeaderMap::new();
headers.insert("x-api-key", "test".parse().unwrap());
assert!(is_authorized(&headers, &config));

let mut headers = HeaderMap::new();
headers.insert("x-api-key", "invalid".parse().unwrap());
assert!(!is_authorized(&headers, &config));

let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer test".parse().unwrap());
assert!(is_authorized(&headers, &config));

let mut headers = HeaderMap::new();
headers.insert("authorization", "Bearer invalid".parse().unwrap());
assert!(!is_authorized(&headers, &config));
}

#[test]
fn test_multi_api_keys() {
let config = Config {
auth_mode: AuthMode::ApiKey,
api_keys: HashSet::from(["test1".to_string(), "test2".to_string()]),
..Default::default()
};

let mut headers = HeaderMap::new();
headers.insert("x-api-key", "test1".parse().unwrap());
assert!(is_authorized(&headers, &config));

headers.insert("x-api-key", "test2".parse().unwrap());
assert!(is_authorized(&headers, &config));

headers.insert("x-api-key", "invalid".parse().unwrap());
assert!(!is_authorized(&headers, &config));
}
}
91 changes: 91 additions & 0 deletions src/buffered.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use crate::utils::handle_err;
use aws_lambda_events::alb::AlbTargetGroupResponse;
use aws_sdk_lambda::operation::invoke::InvokeOutput;
use axum::{body::Body, http::StatusCode, response::Response};
use base64::{prelude::BASE64_STANDARD, Engine};

pub(super) fn handle_buffered_response(resp: InvokeOutput) -> Response {
// Parse the InvokeOutput payload to extract the LambdaResponse
let payload = resp.payload().map_or(&[] as &[u8], |v| v.as_ref());
let lambda_response = handle_err!(
"Deserializing lambda response",
serde_json::from_slice::<AlbTargetGroupResponse>(payload)
);

// Build the response using the extracted information
let mut resp_builder = Response::builder().status(handle_err!(
"Parse response status code",
StatusCode::from_u16(handle_err!(
"Parse response status code",
lambda_response.status_code.try_into()
))
));

*handle_err!(
"Setting response headers",
resp_builder.headers_mut().ok_or("Errors in builder")
) = lambda_response.headers;

let mut body = lambda_response.body.map_or(vec![], |b| b.to_vec());
if lambda_response.is_base64_encoded {
body = handle_err!("Decoding base64 body", BASE64_STANDARD.decode(body));
}
handle_err!("Building response", resp_builder.body(Body::from(body)))
}

#[cfg(test)]
mod tests {
use super::*;
use aws_smithy_types::Blob;
use axum::http::HeaderMap;

#[tokio::test]
async fn test_handle_buffered_response() {
let lambda_response = AlbTargetGroupResponse {
status_code: 200,
status_description: None,
is_base64_encoded: false,
headers: {
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "text/plain".parse().unwrap());
headers
},
body: Some("Hello, world!".into()),
..Default::default()
};
let payload = serde_json::to_vec(&lambda_response).unwrap();
let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build();

let response = handle_buffered_response(invoke_output);

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain");
let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
assert_eq!(body, "Hello, world!");
}

#[tokio::test]
async fn test_handle_buffered_response_base64() {
let lambda_response = AlbTargetGroupResponse {
status_code: 200,
status_description: None,
is_base64_encoded: true,
headers: {
let mut headers = HeaderMap::new();
headers.insert("Content-Type", "text/plain".parse().unwrap());
headers
},
body: Some("SGVsbG8sIHdvcmxkIQ==".into()),
..Default::default()
};
let payload = serde_json::to_vec(&lambda_response).unwrap();
let invoke_output = InvokeOutput::builder().payload(Blob::new(payload)).build();

let response = handle_buffered_response(invoke_output);

assert_eq!(response.status(), StatusCode::OK);
assert_eq!(response.headers().get("Content-Type").unwrap(), "text/plain");
let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
assert_eq!(body, "Hello, world!");
}
}
57 changes: 17 additions & 40 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::str::FromStr;
use std::fs;
use std::path::Path;
use std::{collections::HashSet, env, fs, path::Path, str::FromStr};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Config {
pub lambda_function_name: String,
#[serde(default = "default_lambda_invoke_mode")]
#[serde(default)]
pub lambda_invoke_mode: LambdaInvokeMode,
#[serde(default)]
pub api_keys: HashSet<String>,
#[serde(default = "default_auth_mode")]
#[serde(default)]
pub auth_mode: AuthMode,
#[serde(default = "default_addr")]
pub addr: String,
Expand All @@ -21,9 +18,9 @@ impl Default for Config {
fn default() -> Self {
Self {
lambda_function_name: String::new(),
lambda_invoke_mode: default_lambda_invoke_mode(),
lambda_invoke_mode: Default::default(),
api_keys: HashSet::new(),
auth_mode: default_auth_mode(),
auth_mode: Default::default(),
addr: default_addr(),
}
}
Expand All @@ -40,26 +37,26 @@ impl Config {
}

fn apply_env_overrides(&mut self) {
if let Ok(val) = std::env::var("LAMBDA_FUNCTION_NAME") {
if let Ok(val) = env::var("LAMBDA_FUNCTION_NAME") {
self.lambda_function_name = val;
}
if self.lambda_function_name.is_empty() {
panic!("No lambda_function_name provided. Please set it in the config file or LAMBDA_FUNCTION_NAME environment variable.");
}
if let Ok(val) = std::env::var("LAMBDA_INVOKE_MODE") {
if let Ok(val) = env::var("LAMBDA_INVOKE_MODE") {
if let Ok(mode) = val.parse() {
self.lambda_invoke_mode = mode;
}
}
if let Ok(val) = std::env::var("API_KEYS") {
if let Ok(val) = env::var("API_KEYS") {
self.api_keys = val.split(',').filter(|s| !s.is_empty()).map(String::from).collect();
}
if let Ok(val) = std::env::var("AUTH_MODE") {
if let Ok(val) = env::var("AUTH_MODE") {
if let Ok(mode) = val.parse() {
self.auth_mode = mode;
}
}
if let Ok(val) = std::env::var("ADDR") {
if let Ok(val) = env::var("ADDR") {
self.addr = val;
}
}
Expand All @@ -71,47 +68,24 @@ impl Config {
}
}

#[cfg(test)]
mod tests {
include!("config_tests.rs");
}

fn default_auth_mode() -> AuthMode {
AuthMode::Open
}

fn default_lambda_invoke_mode() -> LambdaInvokeMode {
LambdaInvokeMode::Buffered
}

fn default_addr() -> String {
"0.0.0.0:8000".to_string()
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum AuthMode {
#[default]
Open,
ApiKey,
}

impl Default for AuthMode {
fn default() -> Self {
AuthMode::Open
}
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
pub enum LambdaInvokeMode {
#[default]
Buffered,
ResponseStream,
}

impl Default for LambdaInvokeMode {
fn default() -> Self {
LambdaInvokeMode::Buffered
}
}

impl FromStr for AuthMode {
type Err = String;

Expand All @@ -135,3 +109,6 @@ impl FromStr for LambdaInvokeMode {
}
}
}

#[cfg(test)]
mod tests;
Loading