diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 8a552b53bcb..274cd96aace 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from datetime import datetime, timedelta -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union from .exceptions import GetParameterError, TransformParameterError @@ -31,6 +31,9 @@ def __init__(self): self.store = {} + def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool: + return key in self.store and self.store[key].ttl >= datetime.now() + def get( self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options ) -> Union[str, list, dict, bytes]: @@ -70,24 +73,26 @@ def get( # an acceptable tradeoff. key = (name, transform) - if key not in self.store or self.store[key].ttl < datetime.now(): - try: - value = self._get(name, **sdk_options) - # Encapsulate all errors into a generic GetParameterError - except Exception as exc: - raise GetParameterError(str(exc)) + if self._has_not_expired(key): + return self.store[key].value + + try: + value = self._get(name, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) - if transform is not None: - value = transform_value(value, transform) + if transform is not None: + value = transform_value(value, transform) - self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),) + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),) - return self.store[key].value + return value @abstractmethod def _get(self, name: str, **sdk_options) -> str: """ - Retrieve paramater value from the underlying parameter store + Retrieve parameter value from the underlying parameter store """ raise NotImplementedError() @@ -129,29 +134,22 @@ def get_multiple( key = (path, transform) - if key not in self.store or self.store[key].ttl < datetime.now(): - try: - values = self._get_multiple(path, **sdk_options) - # Encapsulate all errors into a generic GetParameterError - except Exception as exc: - raise GetParameterError(str(exc)) + if self._has_not_expired(key): + return self.store[key].value - if transform is not None: - new_values = {} - for key, value in values.items(): - try: - new_values[key] = transform_value(value, transform) - except Exception as exc: - if raise_on_transform_error: - raise exc - else: - new_values[key] = None + try: + values: Dict[str, Union[str, bytes, dict, None]] = self._get_multiple(path, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) - values = new_values + if transform is not None: + for (key, value) in values.items(): + values[key] = transform_value(value, transform, raise_on_transform_error) - self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) + self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) - return self.store[key].value + return values @abstractmethod def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: @@ -161,16 +159,19 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: raise NotImplementedError() -def transform_value(value: str, transform: str) -> Union[dict, bytes]: +def transform_value(value: str, transform: str, raise_on_transform_error: bool = True) -> Union[dict, bytes, None]: """ Apply a transform to a value Parameters --------- value: str - Parameter alue to transform + Parameter value to transform transform: str Type of transform, supported values are "json" and "binary" + raise_on_transform_error: bool, optional + Raises an exception if any transform fails, otherwise this will + return a None value for each transform that failed Raises ------ @@ -187,4 +188,6 @@ def transform_value(value: str, transform: str) -> Union[dict, bytes]: raise ValueError(f"Invalid transform type '{transform}'") except Exception as exc: - raise TransformParameterError(str(exc)) + if raise_on_transform_error: + raise TransformParameterError(str(exc)) + return None diff --git a/aws_lambda_powertools/utilities/parameters/secrets.py b/aws_lambda_powertools/utilities/parameters/secrets.py index ee4585309fe..67cb94c340b 100644 --- a/aws_lambda_powertools/utilities/parameters/secrets.py +++ b/aws_lambda_powertools/utilities/parameters/secrets.py @@ -77,7 +77,7 @@ def _get(self, name: str, **sdk_options) -> str: ---------- name: str Name of the parameter - sdk_options: dict + sdk_options: dict, optional Dictionary of options that will be passed to the Secrets Manager get_secret_value API call """ diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py index b458f8690d0..0f39bfac9c0 100644 --- a/aws_lambda_powertools/utilities/parameters/ssm.py +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -8,7 +8,7 @@ import boto3 from botocore.config import Config -from .base import DEFAULT_PROVIDERS, BaseProvider +from .base import DEFAULT_MAX_AGE_SECS, DEFAULT_PROVIDERS, BaseProvider class SSMProvider(BaseProvider): @@ -86,6 +86,46 @@ def __init__( super().__init__() + def get( + self, + name: str, + max_age: int = DEFAULT_MAX_AGE_SECS, + transform: Optional[str] = None, + decrypt: bool = False, + **sdk_options + ) -> Union[str, list, dict, bytes]: + """ + Retrieve a parameter value or return the cached value + + Parameters + ---------- + name: str + Parameter name + max_age: int + Maximum age of the cached value + transform: str + Optional transformation of the parameter value. Supported values + are "json" for JSON strings and "binary" for base 64 encoded + values. + decrypt: bool, optional + If the parameter value should be decrypted + sdk_options: dict, optional + Arguments that will be passed directly to the underlying API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + """ + + # Add to `decrypt` sdk_options to we can have an explicit option for this + sdk_options["decrypt"] = decrypt + + return super().get(name, max_age, transform, **sdk_options) + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: """ Retrieve a parameter value from AWS Systems Manager Parameter Store @@ -144,7 +184,9 @@ def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = Fals return parameters -def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, list, dict, bytes]: +def get_parameter( + name: str, transform: Optional[str] = None, decrypt: bool = False, **sdk_options +) -> Union[str, list, dict, bytes]: """ Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store @@ -154,6 +196,8 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Name of the parameter transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameter API call @@ -190,7 +234,10 @@ def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> if "ssm" not in DEFAULT_PROVIDERS: DEFAULT_PROVIDERS["ssm"] = SSMProvider() - return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform) + # Add to `decrypt` sdk_options to we can have an explicit option for this + sdk_options["decrypt"] = decrypt + + return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform, **sdk_options) def get_parameters( @@ -205,10 +252,10 @@ def get_parameters( Path to retrieve the parameters transform: str, optional Transforms the content from a JSON object ('json') or base64 binary string ('binary') - decrypt: bool, optional - If the parameter values should be decrypted recursive: bool, optional If this should retrieve the parameter values recursively or not, defaults to True + decrypt: bool, optional + If the parameter values should be decrypted sdk_options: dict, optional Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call @@ -245,4 +292,7 @@ def get_parameters( if "ssm" not in DEFAULT_PROVIDERS: DEFAULT_PROVIDERS["ssm"] = SSMProvider() - return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, recursive=recursive, decrypt=decrypt) + sdk_options["recursive"] = recursive + sdk_options["decrypt"] = decrypt + + return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, **sdk_options) diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 7a0677b2197..abd121540a6 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1310,6 +1310,7 @@ def test_get_parameter_new(monkeypatch, mock_name, mock_value): class TestProvider(BaseProvider): def _get(self, name: str, **kwargs) -> str: assert name == mock_name + assert not kwargs["decrypt"] return mock_value def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: @@ -1355,6 +1356,8 @@ def _get(self, name: str, **kwargs) -> str: def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: assert path == mock_name + assert kwargs["recursive"] + assert not kwargs["decrypt"] return mock_value monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) @@ -1468,3 +1471,13 @@ def test_transform_value_wrong(mock_value): parameters.base.transform_value(mock_value, "INCORRECT") assert "Invalid transform type" in str(excinfo) + + +def test_transform_value_ignore_error(mock_value): + """ + Test transform_value() does not raise errors when raise_on_transform_error is False + """ + + value = parameters.base.transform_value(mock_value, "INCORRECT", raise_on_transform_error=False) + + assert value is None