diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 962fd51ccf4..75301a8928c 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -177,6 +177,7 @@ def __init__( body: Union[str, bytes, None] = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, + compress: Optional[bool] = None, ): """ @@ -199,6 +200,7 @@ def __init__( self.base64_encoded = False self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} self.cookies = cookies or [] + self.compress = compress if content_type: self.headers.setdefault("Content-Type", content_type) @@ -233,6 +235,38 @@ def _add_cache_control(self, cache_control: str): cache_control = cache_control if self.response.status_code == 200 else "no-cache" self.response.headers["Cache-Control"] = cache_control + @staticmethod + def _has_compression_enabled( + route_compression: bool, response_compression: Optional[bool], event: BaseProxyEvent + ) -> bool: + """ + Checks if compression is enabled. + + NOTE: Response compression takes precedence. + + Parameters + ---------- + route_compression: bool, optional + A boolean indicating whether compression is enabled or not in the route setting. + response_compression: bool, optional + A boolean indicating whether compression is enabled or not in the response setting. + event: BaseProxyEvent + The event object containing the request details. + + Returns + ------- + bool + True if compression is enabled and the "gzip" encoding is accepted, False otherwise. + """ + encoding: str = event.get_header_value(name="accept-encoding", default_value="", case_sensitive=False) # type: ignore[assignment] # noqa: E501 + if "gzip" in encoding: + if response_compression is not None: + return response_compression # e.g., Response(compress=False/True)) + if route_compression: + return True # e.g., @app.get(compress=True) + + return False + def _compress(self): """Compress the response body, but only if `Accept-Encoding` headers includes gzip.""" self.response.headers["Content-Encoding"] = "gzip" @@ -250,7 +284,9 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): self._add_cors(event, cors or CORSConfig()) if self.route.cache_control: self._add_cache_control(self.route.cache_control) - if self.route.compress and "gzip" in (event.get_header_value("accept-encoding", "") or ""): + if self._has_compression_enabled( + route_compression=self.route.compress, response_compression=self.response.compress, event=event + ): self._compress() def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index a862c7da454..c778040906d 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -154,6 +154,7 @@ def get_query_string_value(self, name: str, default_value: Optional[str] = None) query_string_parameters=self.query_string_parameters, name=name, default_value=default_value ) + # Maintenance: missing @overload to ensure return type is a str when default_value is set def get_header_value( self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False ) -> Optional[str]: diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 0fafba80b47..ef544a57d0f 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -360,15 +360,24 @@ You can use the `Response` class to have full control over the response. For exa ### Compress -You can compress with gzip and base64 encode your responses via `compress` parameter. +You can compress with gzip and base64 encode your responses via `compress` parameter. You have the option to pass the `compress` parameter when working with a specific route or using the Response object. + +???+ info + The `compress` parameter used in the Response object takes precedence over the one used in the route. ???+ warning The client must send the `Accept-Encoding` header, otherwise a normal response will be sent. -=== "compressing_responses.py" +=== "compressing_responses_using_route.py" ```python hl_lines="17 27" - --8<-- "examples/event_handler_rest/src/compressing_responses.py" + --8<-- "examples/event_handler_rest/src/compressing_responses_using_route.py" + ``` + +=== "compressing_responses_using_response.py" + + ```python hl_lines="24" + --8<-- "examples/event_handler_rest/src/compressing_responses_using_response.py" ``` === "compressing_responses.json" diff --git a/examples/event_handler_rest/src/compressing_responses_using_response.py b/examples/event_handler_rest/src/compressing_responses_using_response.py new file mode 100644 index 00000000000..b777ab40af9 --- /dev/null +++ b/examples/event_handler_rest/src/compressing_responses_using_response.py @@ -0,0 +1,31 @@ +import requests + +from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.event_handler import ( + APIGatewayRestResolver, + Response, + content_types, +) +from aws_lambda_powertools.logging import correlation_paths +from aws_lambda_powertools.utilities.typing import LambdaContext + +tracer = Tracer() +logger = Logger() +app = APIGatewayRestResolver() + + +@app.get("/todos") +@tracer.capture_method +def get_todos(): + todos: requests.Response = requests.get("https://jsonplaceholder.typicode.com/todos") + todos.raise_for_status() + + # for brevity, we'll limit to the first 10 only + return Response(status_code=200, content_type=content_types.APPLICATION_JSON, body=todos.json()[:10], compress=True) + + +# You can continue to use other utilities just as before +@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) +@tracer.capture_lambda_handler +def lambda_handler(event: dict, context: LambdaContext) -> dict: + return app.resolve(event, context) diff --git a/examples/event_handler_rest/src/compressing_responses.py b/examples/event_handler_rest/src/compressing_responses_using_route.py similarity index 100% rename from examples/event_handler_rest/src/compressing_responses.py rename to examples/event_handler_rest/src/compressing_responses_using_route.py diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index c17422f8d94..9d2d3c5184e 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -366,6 +366,58 @@ def test_cors_preflight_body_is_empty_not_null(): assert result["body"] == "" +def test_override_route_compress_parameter(): + # GIVEN a function that has compress=True + # AND an event with a "Accept-Encoding" that include gzip + # AND the Response object with compress=False + app = ApiGatewayResolver() + mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + expected_value = '{"test": "value"}' + + @app.get("/my/request", compress=True) + def with_compression() -> Response: + return Response(200, content_types.APPLICATION_JSON, expected_value, compress=False) + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + result = handler(mock_event, None) + + # THEN then the response is not compressed + assert result["isBase64Encoded"] is False + assert result["body"] == expected_value + assert result["multiValueHeaders"].get("Content-Encoding") is None + + +def test_response_with_compress_enabled(): + # GIVEN a function + # AND an event with a "Accept-Encoding" that include gzip + # AND the Response object with compress=True + app = ApiGatewayResolver() + mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + expected_value = '{"test": "value"}' + + @app.get("/my/request") + def route_without_compression() -> Response: + return Response(200, content_types.APPLICATION_JSON, expected_value, compress=True) + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + result = handler(mock_event, None) + + # THEN then gzip the response and base64 encode as a string + assert result["isBase64Encoded"] is True + body = result["body"] + assert isinstance(body, str) + decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8") + assert decompress == expected_value + headers = result["multiValueHeaders"] + assert headers["Content-Encoding"] == ["gzip"] + + def test_compress(): # GIVEN a function that has compress=True # AND an event with a "Accept-Encoding" that include gzip