diff --git a/aws_lambda_powertools/metrics/metrics.py b/aws_lambda_powertools/metrics/metrics.py index cb970fcfdc0..976380ab6a9 100644 --- a/aws_lambda_powertools/metrics/metrics.py +++ b/aws_lambda_powertools/metrics/metrics.py @@ -1,11 +1,12 @@ # NOTE: keeps for compatibility from __future__ import annotations -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict from aws_lambda_powertools.metrics.base import MetricResolution, MetricUnit from aws_lambda_powertools.metrics.provider.cloudwatch_emf.cloudwatch import AmazonCloudWatchEMFProvider from aws_lambda_powertools.metrics.provider.cloudwatch_emf.types import CloudWatchEMFOutput +from aws_lambda_powertools.shared.types import AnyCallableT class Metrics: @@ -129,16 +130,18 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None: def log_metrics( self, - lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: AnyCallableT | None = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, default_dimensions: Dict[str, str] | None = None, + **kwargs, ): return self.provider.log_metrics( lambda_handler=lambda_handler, capture_cold_start_metric=capture_cold_start_metric, raise_on_empty_metrics=raise_on_empty_metrics, default_dimensions=default_dimensions, + **kwargs, ) def set_default_dimensions(self, **dimensions) -> None: diff --git a/aws_lambda_powertools/metrics/provider/base.py b/aws_lambda_powertools/metrics/provider/base.py index edcc8e07ec3..ea61a5ec4d7 100644 --- a/aws_lambda_powertools/metrics/provider/base.py +++ b/aws_lambda_powertools/metrics/provider/base.py @@ -3,9 +3,10 @@ import functools import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional +from typing import Any from aws_lambda_powertools.metrics.provider import cold_start +from aws_lambda_powertools.shared.types import AnyCallableT from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -140,7 +141,7 @@ def add_cold_start_metric(self, context: LambdaContext) -> Any: def log_metrics( self, - lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: AnyCallableT | None = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, **kwargs, diff --git a/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py b/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py index 16be60112c3..e1366a8a725 100644 --- a/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py +++ b/aws_lambda_powertools/metrics/provider/cloudwatch_emf/cloudwatch.py @@ -7,7 +7,7 @@ import os import warnings from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List from aws_lambda_powertools.metrics.base import single_metric from aws_lambda_powertools.metrics.exceptions import MetricValueError, SchemaValidationError @@ -22,6 +22,7 @@ from aws_lambda_powertools.metrics.types import MetricNameUnitResolution from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import resolve_env_var_choice +from aws_lambda_powertools.shared.types import AnyCallableT from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -334,7 +335,7 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None: def log_metrics( self, - lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: AnyCallableT | None = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, **kwargs, diff --git a/aws_lambda_powertools/metrics/provider/datadog/datadog.py b/aws_lambda_powertools/metrics/provider/datadog/datadog.py index 320ba9317bf..1e527a1ddb9 100644 --- a/aws_lambda_powertools/metrics/provider/datadog/datadog.py +++ b/aws_lambda_powertools/metrics/provider/datadog/datadog.py @@ -7,13 +7,14 @@ import re import time import warnings -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List from aws_lambda_powertools.metrics.exceptions import MetricValueError, SchemaValidationError from aws_lambda_powertools.metrics.provider import BaseProvider from aws_lambda_powertools.metrics.provider.datadog.warnings import DatadogDataValidationWarning from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import resolve_env_var_choice +from aws_lambda_powertools.shared.types import AnyCallableT from aws_lambda_powertools.utilities.typing import LambdaContext METRIC_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_.]+$") @@ -226,7 +227,7 @@ def add_cold_start_metric(self, context: LambdaContext) -> None: def log_metrics( self, - lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: AnyCallableT | None = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, **kwargs, diff --git a/aws_lambda_powertools/metrics/provider/datadog/metrics.py b/aws_lambda_powertools/metrics/provider/datadog/metrics.py index 3ee4dc2f835..7539b0336be 100644 --- a/aws_lambda_powertools/metrics/provider/datadog/metrics.py +++ b/aws_lambda_powertools/metrics/provider/datadog/metrics.py @@ -1,9 +1,10 @@ # NOTE: keeps for compatibility from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List from aws_lambda_powertools.metrics.provider.datadog.datadog import DatadogProvider +from aws_lambda_powertools.shared.types import AnyCallableT class DatadogMetrics: @@ -90,7 +91,7 @@ def flush_metrics(self, raise_on_empty_metrics: bool = False) -> None: def log_metrics( self, - lambda_handler: Callable[[Dict, Any], Any] | Optional[Callable[[Dict, Any, Optional[Dict]], Any]] = None, + lambda_handler: AnyCallableT | None = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, default_tags: Dict[str, Any] | None = None, diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index 55874c56eff..a66fed3014d 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -124,9 +124,9 @@ def final_decorator(func: Optional[Callable] = None, **kwargs: Any): ) @functools.wraps(func) - def wrapper(event, context): + def wrapper(event, context, **handler_kwargs): try: - middleware = functools.partial(decorator, func, event, context, **kwargs) + middleware = functools.partial(decorator, func, event, context, **kwargs, **handler_kwargs) if trace_execution: tracer = Tracer(auto_patch=False) with tracer.provider.in_subsegment(name=f"## {decorator.__qualname__}"): diff --git a/aws_lambda_powertools/utilities/idempotency/idempotency.py b/aws_lambda_powertools/utilities/idempotency/idempotency.py index f38a860a6c7..94306dfa2c5 100644 --- a/aws_lambda_powertools/utilities/idempotency/idempotency.py +++ b/aws_lambda_powertools/utilities/idempotency/idempotency.py @@ -66,7 +66,7 @@ def idempotent( """ if os.getenv(constants.IDEMPOTENCY_DISABLED_ENV): - return handler(event, context) + return handler(event, context, **kwargs) config = config or IdempotencyConfig() config.register_lambda_context(context) @@ -91,6 +91,7 @@ def idempotent_function( persistence_store: BasePersistenceLayer, config: Optional[IdempotencyConfig] = None, output_serializer: Optional[Union[BaseIdempotencySerializer, Type[BaseIdempotencyModelSerializer]]] = None, + **kwargs: Any, ) -> Any: """ Decorator to handle idempotency of any function @@ -136,6 +137,7 @@ def process_order(customer_id: str, order: dict, **kwargs): persistence_store=persistence_store, config=config, output_serializer=output_serializer, + **kwargs, ), ) diff --git a/aws_lambda_powertools/utilities/parser/parser.py b/aws_lambda_powertools/utilities/parser/parser.py index a45e4880b18..9d3f8afa2c8 100644 --- a/aws_lambda_powertools/utilities/parser/parser.py +++ b/aws_lambda_powertools/utilities/parser/parser.py @@ -15,11 +15,12 @@ @lambda_handler_decorator def event_parser( - handler: Callable[[Any, LambdaContext], EventParserReturnType], + handler: Callable[..., EventParserReturnType], event: Dict[str, Any], context: LambdaContext, model: Optional[Type[Model]] = None, envelope: Optional[Type[Envelope]] = None, + **kwargs: Any, ) -> EventParserReturnType: """Lambda handler decorator to parse & validate events using Pydantic models @@ -93,9 +94,13 @@ def handler(event: Order, context: LambdaContext): "or as the type hint of `event` in the handler that it wraps", ) - parsed_event = parse(event=event, model=model, envelope=envelope) if envelope else parse(event=event, model=model) + if envelope: + parsed_event = parse(event=event, model=model, envelope=envelope) + else: + parsed_event = parse(event=event, model=model) + logger.debug(f"Calling handler {handler.__name__}") - return handler(parsed_event, context) + return handler(parsed_event, context, **kwargs) @overload @@ -104,7 +109,7 @@ def parse(event: Dict[str, Any], model: Type[Model]) -> Model: @overload -def parse(event: Dict[str, Any], model: Type[Model], envelope: Type[Envelope]): +def parse(event: Dict[str, Any], model: Type[Model], envelope: Type[Envelope]) -> Model: ... # pragma: no cover diff --git a/aws_lambda_powertools/utilities/validation/validator.py b/aws_lambda_powertools/utilities/validation/validator.py index fdd92de3687..968656ee49c 100644 --- a/aws_lambda_powertools/utilities/validation/validator.py +++ b/aws_lambda_powertools/utilities/validation/validator.py @@ -20,6 +20,7 @@ def validator( outbound_formats: Optional[Dict] = None, envelope: str = "", jmespath_options: Optional[Dict] = None, + **kwargs: Any, ) -> Any: """Lambda handler decorator to validate incoming/outbound data using a JSON Schema @@ -128,7 +129,7 @@ def handler(event, context): logger.debug("Validating inbound event") validate_data_against_schema(data=event, schema=inbound_schema, formats=inbound_formats) - response = handler(event, context) + response = handler(event, context, **kwargs) if outbound_schema: logger.debug("Validating outbound event")