diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..bfef1fe87f --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/python/black + rev: 19.3b0 + hooks: + - id: black + language_version: python3.7 + exclude_types: ['markdown', 'ini', 'toml', 'rst'] diff --git a/.travis.yml b/.travis.yml index 4e4b1c0085..58c7e58489 100644 --- a/.travis.yml +++ b/.travis.yml @@ -21,6 +21,12 @@ matrix: install: # Install the code requirements +- mkdir $HOME/bin-black +- wget -O $HOME/bin-black/black https://github.com/python/black/releases/download/19.3b0/black +- chmod +x $HOME/bin-black/black +- export PATH=$PATH:$HOME/bin-black +- black --version + - make init # Install Docs requirements diff --git a/DEVELOPMENT_GUIDE.rst b/DEVELOPMENT_GUIDE.rst index b0be211784..d04df0b3ec 100755 --- a/DEVELOPMENT_GUIDE.rst +++ b/DEVELOPMENT_GUIDE.rst @@ -26,8 +26,23 @@ Setup Python locally using `pyenv`_ #. ``pyenv install 2.7.14`` #. Make the Python version available in the project: ``pyenv local 2.7.14`` +2. Install Additional Tooling +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Black +~~~~~~~~ +We format our code using `Black <https://github.com/python/black>`_ and verify that the source code is Black-compliant +in Appveyor during PRs. You can find installation instructions in `Black's docs <https://black.readthedocs.io/en/stable/installation_and_usage.html>`_. -2. Activate Virtualenv +After installing, you can run our formatting through our Makefile with ``make black`` or by integrating Black directly into your favorite IDE (instructions +can be found `here <https://black.readthedocs.io/en/stable/editor_integration.html>`_). + +Pre-commit +~~~~~~~~~~ +If you don't wish to run Black manually on each PR or install Black yourself, we have integrated Black into git hooks through `pre-commit <https://pre-commit.com/>`_. +After installing pre-commit, run ``pre-commit install`` in the root of the project. This will install Black for you and run the Black formatting on +every commit. + +3. Activate Virtualenv ~~~~~~~~~~~~~~~~~~~~~~ Virtualenv allows you to install required libraries outside of the Python installation. A good practice is to setup a different virtualenv for each project. `pyenv`_ comes with a handy plugin that can create virtualenv. @@ -37,7 +52,7 @@ a different virtualenv for each project. `pyenv`_ comes with a handy plugin that #. [Optional] Automatically activate the virtualenv in for this folder: ``pyenv local samtranslator27`` -3. Install dependencies +4. Install dependencies ~~~~~~~~~~~~~~~~~~~~~~~ Install dependencies by running the following command. Make sure the Virtualenv you created above is active. diff --git a/Makefile b/Makefile index f4a0fc9b4d..314446985f 100755 --- a/Makefile +++ b/Makefile @@ -39,19 +39,21 @@ init: $(info [*] Install requirements...) @pip install -r requirements/dev.txt -r requirements/base.txt -flake: - $(info [*] Running flake8...) - @flake8 samtranslator - test: $(info [*] Run the unit test with minimum code coverage of $(CODE_COVERAGE)%...)
@pytest --cov samtranslator --cov-report term-missing --cov-fail-under $(CODE_COVERAGE) tests +black: + black setup.py samtranslator/* tests/* bin/* + +black-check: + black --check setup.py samtranslator/* tests/* bin/* + # Command to run everytime you make changes to verify everything works -dev: flake test +dev: test # Verifications to run before sending a pull request -pr: init dev +pr: black-check init dev define HELP_MESSAGE @@ -64,7 +66,6 @@ TARGETS init Initialize and install the requirements and dev-requirements for this project. test Run the Unit tests. dev Run all development tests after a change. - build-docs Generate the documentation. pr Perform all checks before submitting a Pull Request. endef diff --git a/bin/sam-translate.py b/bin/sam-translate.py index 2483bd4e40..5c0c7ede2f 100755 --- a/bin/sam-translate.py +++ b/bin/sam-translate.py @@ -29,7 +29,7 @@ from docopt import docopt my_path = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, my_path + '/..') +sys.path.insert(0, my_path + "/..") from samtranslator.public.translator import ManagedPolicyLoader from samtranslator.translator.transform import transform @@ -38,18 +38,19 @@ LOG = logging.getLogger(__name__) cli_options = docopt(__doc__) -iam_client = boto3.client('iam') +iam_client = boto3.client("iam") cwd = os.getcwd() -if cli_options.get('--verbose'): +if cli_options.get("--verbose"): logging.basicConfig(level=logging.DEBUG) else: logging.basicConfig() + def execute_command(command, args): try: - aws_cmd = 'aws' if platform.system().lower() != 'windows' else 'aws.cmd' - command_with_args = [aws_cmd, 'cloudformation', command] + list(args) + aws_cmd = "aws" if platform.system().lower() != "windows" else "aws.cmd" + command_with_args = [aws_cmd, "cloudformation", command] + list(args) LOG.debug("Executing command: %s", command_with_args) @@ -63,8 +64,8 @@ def execute_command(command, args): def get_input_output_file_paths(): - input_file_option = cli_options.get('--template-file') - output_file_option = cli_options.get('--output-template') + input_file_option = cli_options.get("--template-file") + output_file_option = cli_options.get("--output-template") input_file_path = os.path.join(cwd, input_file_option) output_file_path = os.path.join(cwd, output_file_option) @@ -73,67 +74,58 @@ def get_input_output_file_paths(): def package(input_file_path, output_file_path): template_file = input_file_path - package_output_template_file = input_file_path + '._sam_packaged_.yaml' - s3_bucket = cli_options.get('--s3-bucket') + package_output_template_file = input_file_path + "._sam_packaged_.yaml" + s3_bucket = cli_options.get("--s3-bucket") args = [ - '--template-file', + "--template-file", template_file, - '--output-template-file', + "--output-template-file", package_output_template_file, - '--s3-bucket', - s3_bucket + "--s3-bucket", + s3_bucket, ] - execute_command('package', args) + execute_command("package", args) return package_output_template_file def transform_template(input_file_path, output_file_path): - with open(input_file_path, 'r') as f: + with open(input_file_path, "r") as f: sam_template = yaml_parse(f) try: - cloud_formation_template = transform( - sam_template, {}, ManagedPolicyLoader(iam_client)) - cloud_formation_template_prettified = json.dumps( - cloud_formation_template, indent=2) + cloud_formation_template = transform(sam_template, {}, ManagedPolicyLoader(iam_client)) + cloud_formation_template_prettified = json.dumps(cloud_formation_template, indent=2) - with open(output_file_path, 'w') as 
f: + with open(output_file_path, "w") as f: f.write(cloud_formation_template_prettified) - print('Wrote transformed CloudFormation template to: ' + output_file_path) + print ("Wrote transformed CloudFormation template to: " + output_file_path) except InvalidDocumentException as e: - errorMessage = reduce(lambda message, error: message + ' ' + error.message, e.causes, e.message) + errorMessage = reduce(lambda message, error: message + " " + error.message, e.causes, e.message) LOG.error(errorMessage) errors = map(lambda cause: cause.message, e.causes) LOG.error(errors) def deploy(template_file): - capabilities = cli_options.get('--capabilities') - stack_name = cli_options.get('--stack-name') - args = [ - '--template-file', - template_file, - '--capabilities', - capabilities, - '--stack-name', - stack_name - ] + capabilities = cli_options.get("--capabilities") + stack_name = cli_options.get("--stack-name") + args = ["--template-file", template_file, "--capabilities", capabilities, "--stack-name", stack_name] - execute_command('deploy', args) + execute_command("deploy", args) return package_output_template_file -if __name__ == '__main__': +if __name__ == "__main__": input_file_path, output_file_path = get_input_output_file_paths() - if cli_options.get('package'): + if cli_options.get("package"): package_output_template_file = package(input_file_path, output_file_path) transform_template(package_output_template_file, output_file_path) - elif cli_options.get('deploy'): + elif cli_options.get("deploy"): package_output_template_file = package(input_file_path, output_file_path) transform_template(package_output_template_file, output_file_path) deploy(output_file_path) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..35465959ef --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.black] +line-length = 120 +target_version = ['py27', 'py37', 'py36', 'py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.tox + | \.venv + | build + | dist + | pip-wheel-metadata + | examples + )/ +) +''' diff --git a/samtranslator/__init__.py b/samtranslator/__init__.py index bec073807b..0dcddbc87b 100644 --- a/samtranslator/__init__.py +++ b/samtranslator/__init__.py @@ -1 +1 @@ -__version__ = '1.20.0' +__version__ = "1.20.0" diff --git a/samtranslator/intrinsics/actions.py b/samtranslator/intrinsics/actions.py index d535305ed2..057bae032c 100644 --- a/samtranslator/intrinsics/actions.py +++ b/samtranslator/intrinsics/actions.py @@ -46,10 +46,12 @@ def can_handle(self, input_dict): :return: True if it matches expected structure, False otherwise """ - return input_dict is not None \ - and isinstance(input_dict, dict) \ - and len(input_dict) == 1 \ + return ( + input_dict is not None + and isinstance(input_dict, dict) + and len(input_dict) == 1 and self.intrinsic_name in input_dict + ) @classmethod def _parse_resource_reference(cls, ref_value): @@ -132,9 +134,7 @@ def resolve_resource_refs(self, input_dict, supported_resource_refs): if not resolved_value: return input_dict - return { - self.intrinsic_name: resolved_value - } + return {self.intrinsic_name: resolved_value} def resolve_resource_id_refs(self, input_dict, supported_resource_id_refs): """ @@ -161,9 +161,7 @@ def resolve_resource_id_refs(self, input_dict, supported_resource_id_refs): if not resolved_value: return input_dict - return { - self.intrinsic_name: resolved_value - } + return {self.intrinsic_name: resolved_value} class SubAction(Action): @@ 
-372,16 +370,18 @@ def handler_method(full_ref, ref_value): """ # RegExp to find pattern "${logicalId.property}" and return the word inside bracket - logical_id_regex = '[A-Za-z0-9\.]+|AWS::[A-Z][A-Za-z]*' - ref_pattern = re.compile(r'\$\{(' + logical_id_regex + ')\}') + logical_id_regex = "[A-Za-z0-9\.]+|AWS::[A-Z][A-Za-z]*" + ref_pattern = re.compile(r"\$\{(" + logical_id_regex + ")\}") # Find all the pattern, and call the handler to decide how to substitute them. # Do the substitution and return the final text - return re.sub(ref_pattern, - # Pass the handler entire string ${logicalId.property} as first parameter and "logicalId.property" - # as second parameter. Return value will be substituted - lambda match: handler_method(match.group(0), match.group(1)), - text) + return re.sub( + ref_pattern, + # Pass the handler entire string ${logicalId.property} as first parameter and "logicalId.property" + # as second parameter. Return value will be substituted + lambda match: handler_method(match.group(0), match.group(1)), + text, + ) class GetAttAction(Action): @@ -427,10 +427,14 @@ def resolve_resource_refs(self, input_dict, supported_resource_refs): if not isinstance(value, list) or len(value) < 2: return input_dict - if (not all(isinstance(entry, string_types) for entry in value)): + if not all(isinstance(entry, string_types) for entry in value): raise InvalidDocumentException( - [InvalidTemplateException('Invalid GetAtt value {}. GetAtt expects an array with 2 strings.' - .format(value))]) + [ + InvalidTemplateException( + "Invalid GetAtt value {}. GetAtt expects an array with 2 strings.".format(value) + ) + ] + ) # Value of GetAtt is an array. It can contain any number of elements, with first being the LogicalId of # resource and rest being the attributes. In a SAM template, a reference to a resource can be used in the @@ -515,6 +519,7 @@ class FindInMapAction(Action): """ This action can't be used along with other actions. """ + intrinsic_name = "Fn::FindInMap" def resolve_parameter_refs(self, input_dict, parameters): @@ -535,21 +540,29 @@ def resolve_parameter_refs(self, input_dict, parameters): # FindInMap expects an array with 3 values if not isinstance(value, list) or len(value) != 3: raise InvalidDocumentException( - [InvalidTemplateException('Invalid FindInMap value {}. FindInMap expects an array with 3 values.' - .format(value))]) + [ + InvalidTemplateException( + "Invalid FindInMap value {}. 
FindInMap expects an array with 3 values.".format(value) + ) + ] + ) map_name = self.resolve_parameter_refs(value[0], parameters) top_level_key = self.resolve_parameter_refs(value[1], parameters) second_level_key = self.resolve_parameter_refs(value[2], parameters) - if not isinstance(map_name, string_types) or \ - not isinstance(top_level_key, string_types) or \ - not isinstance(second_level_key, string_types): + if ( + not isinstance(map_name, string_types) + or not isinstance(top_level_key, string_types) + or not isinstance(second_level_key, string_types) + ): return input_dict - if map_name not in parameters or \ - top_level_key not in parameters[map_name] or \ - second_level_key not in parameters[map_name][top_level_key]: + if ( + map_name not in parameters + or top_level_key not in parameters[map_name] + or second_level_key not in parameters[map_name][top_level_key] + ): return input_dict return parameters[map_name][top_level_key][second_level_key] diff --git a/samtranslator/intrinsics/resolver.py b/samtranslator/intrinsics/resolver.py index f7fe275300..0a2c237877 100644 --- a/samtranslator/intrinsics/resolver.py +++ b/samtranslator/intrinsics/resolver.py @@ -7,7 +7,6 @@ class IntrinsicsResolver(object): - def __init__(self, parameters, supported_intrinsics=DEFAULT_SUPPORTED_INTRINSICS): """ Instantiate the resolver @@ -20,8 +19,9 @@ def __init__(self, parameters, supported_intrinsics=DEFAULT_SUPPORTED_INTRINSICS if parameters is None or not isinstance(parameters, dict): raise TypeError("parameters must be a valid dictionary") - if not isinstance(supported_intrinsics, dict) \ - or not all([isinstance(value, Action) for value in supported_intrinsics.values()]): + if not isinstance(supported_intrinsics, dict) or not all( + [isinstance(value, Action) for value in supported_intrinsics.values()] + ): raise TypeError("supported_intrinsics argument must be intrinsic names to corresponding Action classes") self.supported_intrinsics = supported_intrinsics @@ -217,6 +217,4 @@ def _is_intrinsic_dict(self, input): :return: True, if the input contains a supported intrinsic function. False otherwise """ # All intrinsic functions are dictionaries with just one key - return isinstance(input, dict) \ - and len(input) == 1 \ - and list(input.keys())[0] in self.supported_intrinsics + return isinstance(input, dict) and len(input) == 1 and list(input.keys())[0] in self.supported_intrinsics diff --git a/samtranslator/model/__init__.py b/samtranslator/model/__init__.py index e60e959e04..ed8dff9142 100644 --- a/samtranslator/model/__init__.py +++ b/samtranslator/model/__init__.py @@ -17,6 +17,7 @@ class PropertyType(object): :ivar supports_intrinsics True to allow intrinsic function support on this property. Setting this to False will raise an error when intrinsic function dictionary is supplied as value """ + def __init__(self, required, validate=lambda value: True, supports_intrinsics=True): self.required = required self.validate = validate @@ -37,9 +38,10 @@ class Resource(object): which indicate whether the property is required and the property's type. Properties that are not in this dict will \ be considered invalid. 
""" + resource_type = None property_types = None - _keywords = ['logical_id', 'relative_id', "depends_on", "resource_attributes"] + _keywords = ["logical_id", "relative_id", "depends_on", "resource_attributes"] _supported_resource_attributes = ["DeletionPolicy", "UpdatePolicy", "Condition"] @@ -113,8 +115,8 @@ def from_dict(cls, logical_id, resource_dict, relative_id=None, sam_plugins=None for name, value in properties.items(): setattr(resource, name, value) - if 'DependsOn' in resource_dict: - resource.depends_on = resource_dict['DependsOn'] + if "DependsOn" in resource_dict: + resource.depends_on = resource_dict["DependsOn"] # Parse only well known properties. This is consistent with earlier behavior where we used to ignore resource # all resource attributes ie. all attributes were unsupported before @@ -134,7 +136,7 @@ def _validate_logical_id(cls, logical_id): :rtype: bool :raises TypeError: if the logical id is invalid """ - pattern = re.compile(r'^[A-Za-z0-9]+$') + pattern = re.compile(r"^[A-Za-z0-9]+$") if logical_id is not None and pattern.match(logical_id): return True raise InvalidResourceException(logical_id, "Logical ids must be alphanumeric.") @@ -148,15 +150,16 @@ def _validate_resource_dict(cls, logical_id, resource_dict): :rtype: bool :raises InvalidResourceException: if the resource dict has an invalid format """ - if 'Type' not in resource_dict: + if "Type" not in resource_dict: raise InvalidResourceException(logical_id, "Resource dict missing key 'Type'.") - if resource_dict['Type'] != cls.resource_type: - raise InvalidResourceException(logical_id, "Resource has incorrect Type; expected '{expected}', " - "got '{actual}'".format( - expected=cls.resource_type, - actual=resource_dict['Type'])) - - if 'Properties' in resource_dict and not isinstance(resource_dict['Properties'], dict): + if resource_dict["Type"] != cls.resource_type: + raise InvalidResourceException( + logical_id, + "Resource has incorrect Type; expected '{expected}', " + "got '{actual}'".format(expected=cls.resource_type, actual=resource_dict["Type"]), + ) + + if "Properties" in resource_dict and not isinstance(resource_dict["Properties"], dict): raise InvalidResourceException(logical_id, "Properties of a resource must be an object.") def to_dict(self): @@ -195,10 +198,10 @@ def _generate_resource_dict(self): """ resource_dict = {} - resource_dict['Type'] = self.resource_type + resource_dict["Type"] = self.resource_type if self.depends_on: - resource_dict['DependsOn'] = self.depends_on + resource_dict["DependsOn"] = self.depends_on resource_dict.update(self.resource_attributes) @@ -208,7 +211,7 @@ def _generate_resource_dict(self): if value is not None: properties_dict[name] = value - resource_dict['Properties'] = properties_dict + resource_dict["Properties"] = properties_dict return resource_dict @@ -223,9 +226,12 @@ def __setattr__(self, name, value): if name in self._keywords or name in self.property_types: return super(Resource, self).__setattr__(name, value) - raise InvalidResourceException(self.logical_id, - "property {property_name} not defined for resource of type {resource_type}" - .format(resource_type=self.resource_type, property_name=name)) + raise InvalidResourceException( + self.logical_id, + "property {property_name} not defined for resource of type {resource_type}".format( + resource_type=self.resource_type, property_name=name + ), + ) def validate_properties(self): """Validates that the required properties for this Resource have been populated, and that all properties have @@ -246,13 +252,13 
@@ def validate_properties(self): if value is None: if property_type.required: raise InvalidResourceException( - self.logical_id, - "Missing required property '{property_name}'.".format(property_name=name)) + self.logical_id, "Missing required property '{property_name}'.".format(property_name=name) + ) # Otherwise, validate the value of the property. elif not property_type.validate(value, should_raise=False): raise InvalidResourceException( - self.logical_id, - "Type of property '{property_name}' is invalid.".format(property_name=name)) + self.logical_id, "Type of property '{property_name}' is invalid.".format(property_name=name) + ) def set_resource_attribute(self, attr, value): """Sets attributes on resource. Resource attributes are top-level entries of a CloudFormation resource @@ -313,8 +319,8 @@ def get_passthrough_resource_attributes(self): :return: Dictionary of resource attributes. """ attributes = None - if 'Condition' in self.resource_attributes: - attributes = {'Condition': self.resource_attributes['Condition']} + if "Condition" in self.resource_attributes: + attributes = {"Condition": self.resource_attributes["Condition"]} return attributes @@ -366,19 +372,15 @@ class SamResourceMacro(ResourceMacro): referable_properties = {} # Each resource can optionally override this tag: - _SAM_KEY = 'lambda:createdBy' - _SAM_VALUE = 'SAM' + _SAM_KEY = "lambda:createdBy" + _SAM_VALUE = "SAM" # Tags reserved by the serverless application repo - _SAR_APP_KEY = 'serverlessrepo:applicationId' - _SAR_SEMVER_KEY = 'serverlessrepo:semanticVersion' + _SAR_APP_KEY = "serverlessrepo:applicationId" + _SAR_SEMVER_KEY = "serverlessrepo:semanticVersion" # Aggregate list of all reserved tags - _RESERVED_TAGS = [ - _SAM_KEY, - _SAR_APP_KEY, - _SAR_SEMVER_KEY - ] + _RESERVED_TAGS = [_SAM_KEY, _SAR_APP_KEY, _SAR_SEMVER_KEY] def get_resource_references(self, generated_cfn_resources, supported_resource_refs): """ @@ -425,10 +427,13 @@ def _construct_tag_list(self, tags, additional_tags=None): def _check_tag(self, reserved_tag_name, tags): if reserved_tag_name in tags: - raise InvalidResourceException(self.logical_id, reserved_tag_name + " is a reserved Tag key name and " - "cannot be set on your resource. " - "Please change the tag key in the " - "input.") + raise InvalidResourceException( + self.logical_id, + reserved_tag_name + " is a reserved Tag key name and " + "cannot be set on your resource. " + "Please change the tag key in the " + "input.", + ) def _resolve_string_parameter(self, intrinsics_resolver, parameter_value, parameter_name): if not parameter_value: @@ -436,9 +441,10 @@ def _resolve_string_parameter(self, intrinsics_resolver, parameter_value, parame value = intrinsics_resolver.resolve_parameter_refs(parameter_value) if not isinstance(value, string_types) and not isinstance(value, dict): - raise InvalidResourceException(self.logical_id, - "Could not resolve parameter for '{}' or parameter is not a String." - .format(parameter_name)) + raise InvalidResourceException( + self.logical_id, + "Could not resolve parameter for '{}' or parameter is not a String.".format(parameter_name), + ) return value @@ -454,17 +460,19 @@ def __init__(self, *modules): self.resource_types = {} for module in modules: # Get all classes in the specified module which have a class variable resource_type. 
- for _, resource_class in inspect.getmembers(module, - lambda cls: inspect.isclass(cls) and - cls.__module__ == module.__name__ and - hasattr(cls, 'resource_type')): + for _, resource_class in inspect.getmembers( + module, + lambda cls: inspect.isclass(cls) + and cls.__module__ == module.__name__ + and hasattr(cls, "resource_type"), + ): self.resource_types[resource_class.resource_type] = resource_class def can_resolve(self, resource_dict): - if not isinstance(resource_dict, dict) or 'Type' not in resource_dict: + if not isinstance(resource_dict, dict) or "Type" not in resource_dict: return False - return resource_dict['Type'] in self.resource_types + return resource_dict["Type"] in self.resource_types def resolve_resource_type(self, resource_dict): """Returns the Resource class corresponding to the 'Type' key in the given resource dict. @@ -474,8 +482,11 @@ def resolve_resource_type(self, resource_dict): :rtype: class """ if not self.can_resolve(resource_dict): - raise TypeError("Resource dict has missing or invalid value for key Type. Event Type is: {}.".format( - resource_dict.get('Type'))) - if resource_dict['Type'] not in self.resource_types: - raise TypeError("Invalid resource type {resource_type}".format(resource_type=resource_dict['Type'])) - return self.resource_types[resource_dict['Type']] + raise TypeError( + "Resource dict has missing or invalid value for key Type. Event Type is: {}.".format( + resource_dict.get("Type") + ) + ) + if resource_dict["Type"] not in self.resource_types: + raise TypeError("Invalid resource type {resource_type}".format(resource_type=resource_dict["Type"])) + return self.resource_types[resource_dict["Type"]] diff --git a/samtranslator/model/api/api_generator.py b/samtranslator/model/api/api_generator.py index 26c831ec22..502bb3ba1c 100644 --- a/samtranslator/model/api/api_generator.py +++ b/samtranslator/model/api/api_generator.py @@ -1,10 +1,15 @@ from collections import namedtuple from six import string_types from samtranslator.model.intrinsics import ref, fnGetAtt -from samtranslator.model.apigateway import (ApiGatewayDeployment, ApiGatewayRestApi, - ApiGatewayStage, ApiGatewayAuthorizer, - ApiGatewayResponse, ApiGatewayDomainName, - ApiGatewayBasePathMapping) +from samtranslator.model.apigateway import ( + ApiGatewayDeployment, + ApiGatewayRestApi, + ApiGatewayStage, + ApiGatewayAuthorizer, + ApiGatewayResponse, + ApiGatewayDomainName, + ApiGatewayBasePathMapping, +) from samtranslator.model.route53 import Route53RecordSetGroup from samtranslator.model.exceptions import InvalidResourceException from samtranslator.model.s3_utils.uri_parser import parse_s3_uri @@ -17,27 +22,57 @@ from samtranslator.model.tags.resource_tagging import get_tag_list _CORS_WILDCARD = "'*'" -CorsProperties = namedtuple("_CorsProperties", ["AllowMethods", "AllowHeaders", "AllowOrigin", "MaxAge", - "AllowCredentials"]) +CorsProperties = namedtuple( + "_CorsProperties", ["AllowMethods", "AllowHeaders", "AllowOrigin", "MaxAge", "AllowCredentials"] +) # Default the Cors Properties to '*' wildcard and False AllowCredentials. 
Other properties are actually Optional CorsProperties.__new__.__defaults__ = (None, None, _CORS_WILDCARD, None, False) -AuthProperties = namedtuple("_AuthProperties", - ["Authorizers", "DefaultAuthorizer", "InvokeRole", "AddDefaultAuthorizerToCorsPreflight", - "ApiKeyRequired", "ResourcePolicy"]) +AuthProperties = namedtuple( + "_AuthProperties", + [ + "Authorizers", + "DefaultAuthorizer", + "InvokeRole", + "AddDefaultAuthorizerToCorsPreflight", + "ApiKeyRequired", + "ResourcePolicy", + ], +) AuthProperties.__new__.__defaults__ = (None, None, None, True, None, None) GatewayResponseProperties = ["ResponseParameters", "ResponseTemplates", "StatusCode"] class ApiGenerator(object): - - def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variables, depends_on, - definition_body, definition_uri, name, stage_name, tags=None, endpoint_configuration=None, - method_settings=None, binary_media=None, minimum_compression_size=None, cors=None, - auth=None, gateway_responses=None, access_log_setting=None, canary_setting=None, - tracing_enabled=None, resource_attributes=None, passthrough_resource_attributes=None, - open_api_version=None, models=None, domain=None): + def __init__( + self, + logical_id, + cache_cluster_enabled, + cache_cluster_size, + variables, + depends_on, + definition_body, + definition_uri, + name, + stage_name, + tags=None, + endpoint_configuration=None, + method_settings=None, + binary_media=None, + minimum_compression_size=None, + cors=None, + auth=None, + gateway_responses=None, + access_log_setting=None, + canary_setting=None, + tracing_enabled=None, + resource_attributes=None, + passthrough_resource_attributes=None, + open_api_version=None, + models=None, + domain=None, + ): """Constructs an API Generator class that generates API Gateway resources :param logical_id: Logical id of the SAM API Resource @@ -105,14 +140,17 @@ def _construct_rest_api(self): self._set_endpoint_configuration(rest_api, "REGIONAL") if self.definition_uri and self.definition_body: - raise InvalidResourceException(self.logical_id, - "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both") + raise InvalidResourceException( + self.logical_id, "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both" + ) if self.open_api_version: - if not SwaggerEditor.safe_compare_regex_with_string(SwaggerEditor.get_openapi_versions_supported_regex(), - self.open_api_version): - raise InvalidResourceException(self.logical_id, - "The OpenApiVersion value must be of the format \"3.0.0\"") + if not SwaggerEditor.safe_compare_regex_with_string( + SwaggerEditor.get_openapi_versions_supported_regex(), self.open_api_version + ): + raise InvalidResourceException( + self.logical_id, 'The OpenApiVersion value must be of the format "3.0.0"' + ) self._add_cors() self._add_auth() @@ -141,8 +179,9 @@ def _construct_body_s3_dict(self): if isinstance(self.definition_uri, dict): if not self.definition_uri.get("Bucket", None) or not self.definition_uri.get("Key", None): # DefinitionUri is a dictionary but does not contain Bucket or Key property - raise InvalidResourceException(self.logical_id, - "'DefinitionUri' requires Bucket and Key properties to be specified") + raise InvalidResourceException( + self.logical_id, "'DefinitionUri' requires Bucket and Key properties to be specified" + ) s3_pointer = self.definition_uri else: @@ -150,16 +189,15 @@ def _construct_body_s3_dict(self): # DefinitionUri is a string s3_pointer = parse_s3_uri(self.definition_uri) if s3_pointer is None: - raise 
InvalidResourceException(self.logical_id, - '\'DefinitionUri\' is not a valid S3 Uri of the form ' - '"s3://bucket/key" with optional versionId query parameter.') - - body_s3 = { - 'Bucket': s3_pointer['Bucket'], - 'Key': s3_pointer['Key'] - } - if 'Version' in s3_pointer: - body_s3['Version'] = s3_pointer['Version'] + raise InvalidResourceException( + self.logical_id, + "'DefinitionUri' is not a valid S3 Uri of the form " + '"s3://bucket/key" with optional versionId query parameter.', + ) + + body_s3 = {"Bucket": s3_pointer["Bucket"], "Key": s3_pointer["Key"]} + if "Version" in s3_pointer: + body_s3["Version"] = s3_pointer["Version"] return body_s3 def _construct_deployment(self, rest_api): @@ -169,11 +207,12 @@ def _construct_deployment(self, rest_api): :returns: the Deployment to which this SAM Api corresponds :rtype: model.apigateway.ApiGatewayDeployment """ - deployment = ApiGatewayDeployment(self.logical_id + 'Deployment', - attributes=self.passthrough_resource_attributes) - deployment.RestApiId = rest_api.get_runtime_attr('rest_api_id') + deployment = ApiGatewayDeployment( + self.logical_id + "Deployment", attributes=self.passthrough_resource_attributes + ) + deployment.RestApiId = rest_api.get_runtime_attr("rest_api_id") if not self.remove_extra_stage: - deployment.StageName = 'Stage' + deployment.StageName = "Stage" return deployment @@ -189,12 +228,11 @@ def _construct_stage(self, deployment, swagger): # This will NOT create duplicates because we allow only ONE stage per API resource stage_name_prefix = self.stage_name if isinstance(self.stage_name, string_types) else "" if stage_name_prefix.isalnum(): - stage_logical_id = self.logical_id + stage_name_prefix + 'Stage' + stage_logical_id = self.logical_id + stage_name_prefix + "Stage" else: - generator = logical_id_generator.LogicalIdGenerator(self.logical_id + 'Stage', stage_name_prefix) + generator = logical_id_generator.LogicalIdGenerator(self.logical_id + "Stage", stage_name_prefix) stage_logical_id = generator.gen() - stage = ApiGatewayStage(stage_logical_id, - attributes=self.passthrough_resource_attributes) + stage = ApiGatewayStage(stage_logical_id, attributes=self.passthrough_resource_attributes) stage.RestApiId = ref(self.logical_id) stage.update_deployment_ref(deployment.logical_id) stage.StageName = self.stage_name @@ -221,76 +259,79 @@ def _construct_api_domain(self, rest_api): if self.domain is None: return None, None, None - if self.domain.get('DomainName') is None or \ - self.domain.get('CertificateArn') is None: - raise InvalidResourceException(self.logical_id, - "Custom Domains only works if both DomainName and CertificateArn" - " are provided") + if self.domain.get("DomainName") is None or self.domain.get("CertificateArn") is None: + raise InvalidResourceException( + self.logical_id, "Custom Domains only works if both DomainName and CertificateArn" " are provided" + ) - self.domain['ApiDomainName'] = "{}{}".format('ApiGatewayDomainName', - logical_id_generator. 
- LogicalIdGenerator("", self.domain.get('DomainName')).gen()) + self.domain["ApiDomainName"] = "{}{}".format( + "ApiGatewayDomainName", logical_id_generator.LogicalIdGenerator("", self.domain.get("DomainName")).gen() + ) - domain = ApiGatewayDomainName(self.domain.get('ApiDomainName'), - attributes=self.passthrough_resource_attributes) - domain.DomainName = self.domain.get('DomainName') - endpoint = self.domain.get('EndpointConfiguration') + domain = ApiGatewayDomainName(self.domain.get("ApiDomainName"), attributes=self.passthrough_resource_attributes) + domain.DomainName = self.domain.get("DomainName") + endpoint = self.domain.get("EndpointConfiguration") if endpoint is None: - endpoint = 'REGIONAL' - self.domain['EndpointConfiguration'] = 'REGIONAL' - elif endpoint not in ['EDGE', 'REGIONAL']: - raise InvalidResourceException(self.logical_id, - "EndpointConfiguration for Custom Domains must be" - " one of {}".format(['EDGE', 'REGIONAL'])) - - if endpoint == 'REGIONAL': - domain.RegionalCertificateArn = self.domain.get('CertificateArn') + endpoint = "REGIONAL" + self.domain["EndpointConfiguration"] = "REGIONAL" + elif endpoint not in ["EDGE", "REGIONAL"]: + raise InvalidResourceException( + self.logical_id, + "EndpointConfiguration for Custom Domains must be" " one of {}".format(["EDGE", "REGIONAL"]), + ) + + if endpoint == "REGIONAL": + domain.RegionalCertificateArn = self.domain.get("CertificateArn") else: - domain.CertificateArn = self.domain.get('CertificateArn') + domain.CertificateArn = self.domain.get("CertificateArn") domain.EndpointConfiguration = {"Types": [endpoint]} # Create BasepathMappings - if self.domain.get('BasePath') and isinstance(self.domain.get('BasePath'), string_types): - basepaths = [self.domain.get('BasePath')] - elif self.domain.get('BasePath') and isinstance(self.domain.get('BasePath'), list): - basepaths = self.domain.get('BasePath') + if self.domain.get("BasePath") and isinstance(self.domain.get("BasePath"), string_types): + basepaths = [self.domain.get("BasePath")] + elif self.domain.get("BasePath") and isinstance(self.domain.get("BasePath"), list): + basepaths = self.domain.get("BasePath") else: basepaths = None basepath_resource_list = [] if basepaths is None: - basepath_mapping = ApiGatewayBasePathMapping(self.logical_id + 'BasePathMapping', - attributes=self.passthrough_resource_attributes) - basepath_mapping.DomainName = self.domain.get('DomainName') + basepath_mapping = ApiGatewayBasePathMapping( + self.logical_id + "BasePathMapping", attributes=self.passthrough_resource_attributes + ) + basepath_mapping.DomainName = self.domain.get("DomainName") basepath_mapping.RestApiId = ref(rest_api.logical_id) - basepath_mapping.Stage = ref(rest_api.logical_id + '.Stage') + basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage") basepath_resource_list.extend([basepath_mapping]) else: for path in basepaths: - path = ''.join(e for e in path if e.isalnum()) - logical_id = "{}{}{}".format(self.logical_id, path, 'BasePathMapping') - basepath_mapping = ApiGatewayBasePathMapping(logical_id, - attributes=self.passthrough_resource_attributes) - basepath_mapping.DomainName = self.domain.get('DomainName') + path = "".join(e for e in path if e.isalnum()) + logical_id = "{}{}{}".format(self.logical_id, path, "BasePathMapping") + basepath_mapping = ApiGatewayBasePathMapping( + logical_id, attributes=self.passthrough_resource_attributes + ) + basepath_mapping.DomainName = self.domain.get("DomainName") basepath_mapping.RestApiId = ref(rest_api.logical_id) - 
basepath_mapping.Stage = ref(rest_api.logical_id + '.Stage') + basepath_mapping.Stage = ref(rest_api.logical_id + ".Stage") basepath_mapping.BasePath = path basepath_resource_list.extend([basepath_mapping]) # Create the Route53 RecordSetGroup resource record_set_group = None - if self.domain.get('Route53') is not None: - route53 = self.domain.get('Route53') - if route53.get('HostedZoneId') is None: - raise InvalidResourceException(self.logical_id, - "HostedZoneId is required to enable Route53 support on Custom Domains.") - logical_id = logical_id_generator.LogicalIdGenerator("", route53.get('HostedZoneId')).gen() - record_set_group = Route53RecordSetGroup('RecordSetGroup' + logical_id, - attributes=self.passthrough_resource_attributes) - record_set_group.HostedZoneId = route53.get('HostedZoneId') + if self.domain.get("Route53") is not None: + route53 = self.domain.get("Route53") + if route53.get("HostedZoneId") is None: + raise InvalidResourceException( + self.logical_id, "HostedZoneId is required to enable Route53 support on Custom Domains." + ) + logical_id = logical_id_generator.LogicalIdGenerator("", route53.get("HostedZoneId")).gen() + record_set_group = Route53RecordSetGroup( + "RecordSetGroup" + logical_id, attributes=self.passthrough_resource_attributes + ) + record_set_group.HostedZoneId = route53.get("HostedZoneId") record_set_group.RecordSets = self._construct_record_sets_for_domain(self.domain) return domain, basepath_resource_list, record_set_group @@ -298,37 +339,37 @@ def _construct_api_domain(self, rest_api): def _construct_record_sets_for_domain(self, domain): recordset_list = [] recordset = {} - route53 = domain.get('Route53') + route53 = domain.get("Route53") - recordset['Name'] = domain.get('DomainName') - recordset['Type'] = 'A' - recordset['AliasTarget'] = self._construct_alias_target(self.domain) + recordset["Name"] = domain.get("DomainName") + recordset["Type"] = "A" + recordset["AliasTarget"] = self._construct_alias_target(self.domain) recordset_list.extend([recordset]) recordset_ipv6 = {} - if route53.get('IpV6') is not None and route53.get('IpV6') is True: - recordset_ipv6['Name'] = domain.get('DomainName') - recordset_ipv6['Type'] = 'AAAA' - recordset_ipv6['AliasTarget'] = self._construct_alias_target(self.domain) + if route53.get("IpV6") is not None and route53.get("IpV6") is True: + recordset_ipv6["Name"] = domain.get("DomainName") + recordset_ipv6["Type"] = "AAAA" + recordset_ipv6["AliasTarget"] = self._construct_alias_target(self.domain) recordset_list.extend([recordset_ipv6]) return recordset_list def _construct_alias_target(self, domain): alias_target = {} - route53 = domain.get('Route53') - target_health = route53.get('EvaluateTargetHealth') + route53 = domain.get("Route53") + target_health = route53.get("EvaluateTargetHealth") if target_health is not None: - alias_target['EvaluateTargetHealth'] = target_health - if domain.get('EndpointConfiguration') == 'REGIONAL': - alias_target['HostedZoneId'] = fnGetAtt(self.domain.get('ApiDomainName'), 'RegionalHostedZoneId') - alias_target['DNSName'] = fnGetAtt(self.domain.get('ApiDomainName'), 'RegionalDomainName') + alias_target["EvaluateTargetHealth"] = target_health + if domain.get("EndpointConfiguration") == "REGIONAL": + alias_target["HostedZoneId"] = fnGetAtt(self.domain.get("ApiDomainName"), "RegionalHostedZoneId") + alias_target["DNSName"] = fnGetAtt(self.domain.get("ApiDomainName"), "RegionalDomainName") else: - if route53.get('DistributionDomainName') is None: - route53['DistributionDomainName'] = 
fnGetAtt(self.domain.get('ApiDomainName'), 'DistributionDomainName') - alias_target['HostedZoneId'] = 'Z2FDTNDATAQYW2' - alias_target['DNSName'] = route53.get('DistributionDomainName') + if route53.get("DistributionDomainName") is None: + route53["DistributionDomainName"] = fnGetAtt(self.domain.get("ApiDomainName"), "DistributionDomainName") + alias_target["HostedZoneId"] = "Z2FDTNDATAQYW2" + alias_target["DNSName"] = route53.get("DistributionDomainName") return alias_target def to_cloudformation(self): @@ -363,9 +404,9 @@ def _add_cors(self): return if self.cors and not self.definition_body: - raise InvalidResourceException(self.logical_id, - "Cors works only with inline Swagger specified in " - "'DefinitionBody' property") + raise InvalidResourceException( + self.logical_id, "Cors works only with inline Swagger specified in " "'DefinitionBody' property" + ) if isinstance(self.cors, string_types) or is_instrinsic(self.cors): # Just set Origin property. Others will be defaults @@ -382,18 +423,29 @@ def _add_cors(self): raise InvalidResourceException(self.logical_id, INVALID_ERROR) if not SwaggerEditor.is_valid(self.definition_body): - raise InvalidResourceException(self.logical_id, "Unable to add Cors configuration because " - "'DefinitionBody' does not contain a valid Swagger") + raise InvalidResourceException( + self.logical_id, + "Unable to add Cors configuration because " "'DefinitionBody' does not contain a valid Swagger", + ) if properties.AllowCredentials is True and properties.AllowOrigin == _CORS_WILDCARD: - raise InvalidResourceException(self.logical_id, "Unable to add Cors configuration because " - "'AllowCredentials' can not be true when " - "'AllowOrigin' is \"'*'\" or not set") + raise InvalidResourceException( + self.logical_id, + "Unable to add Cors configuration because " + "'AllowCredentials' can not be true when " + "'AllowOrigin' is \"'*'\" or not set", + ) editor = SwaggerEditor(self.definition_body) for path in editor.iter_on_path(): - editor.add_cors(path, properties.AllowOrigin, properties.AllowHeaders, properties.AllowMethods, - max_age=properties.MaxAge, allow_credentials=properties.AllowCredentials) + editor.add_cors( + path, + properties.AllowOrigin, + properties.AllowHeaders, + properties.AllowMethods, + max_age=properties.MaxAge, + allow_credentials=properties.AllowCredentials, + ) # Assign the Swagger back to template self.definition_body = editor.swagger @@ -425,27 +477,32 @@ def _add_auth(self): return if self.auth and not self.definition_body: - raise InvalidResourceException(self.logical_id, - "Auth works only with inline Swagger specified in " - "'DefinitionBody' property") + raise InvalidResourceException( + self.logical_id, "Auth works only with inline Swagger specified in " "'DefinitionBody' property" + ) # Make sure keys in the dict are recognized if not all(key in AuthProperties._fields for key in self.auth.keys()): - raise InvalidResourceException( - self.logical_id, "Invalid value for 'Auth' property") + raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property") if not SwaggerEditor.is_valid(self.definition_body): - raise InvalidResourceException(self.logical_id, "Unable to add Auth configuration because " - "'DefinitionBody' does not contain a valid Swagger") + raise InvalidResourceException( + self.logical_id, + "Unable to add Auth configuration because " "'DefinitionBody' does not contain a valid Swagger", + ) swagger_editor = SwaggerEditor(self.definition_body) auth_properties = AuthProperties(**self.auth) authorizers 
= self._get_authorizers(auth_properties.Authorizers, auth_properties.DefaultAuthorizer) if authorizers: swagger_editor.add_authorizers_security_definitions(authorizers) - self._set_default_authorizer(swagger_editor, authorizers, auth_properties.DefaultAuthorizer, - auth_properties.AddDefaultAuthorizerToCorsPreflight, - auth_properties.Authorizers) + self._set_default_authorizer( + swagger_editor, + authorizers, + auth_properties.DefaultAuthorizer, + auth_properties.AddDefaultAuthorizerToCorsPreflight, + auth_properties.Authorizers, + ) if auth_properties.ApiKeyRequired: swagger_editor.add_apikey_security_definition() @@ -453,8 +510,9 @@ def _add_auth(self): if auth_properties.ResourcePolicy: for path in swagger_editor.iter_on_path(): - swagger_editor.add_resource_policy(auth_properties.ResourcePolicy, path, - self.logical_id, self.stage_name) + swagger_editor.add_resource_policy( + auth_properties.ResourcePolicy, path, self.logical_id, self.stage_name + ) self.definition_body = self._openapi_postprocess(swagger_editor.swagger) @@ -468,8 +526,9 @@ def _add_gateway_responses(self): if self.gateway_responses and not self.definition_body: raise InvalidResourceException( - self.logical_id, "GatewayResponses works only with inline Swagger specified in " - "'DefinitionBody' property") + self.logical_id, + "GatewayResponses works only with inline Swagger specified in " "'DefinitionBody' property", + ) # Make sure keys in the dict are recognized for responses_key, responses_value in self.gateway_responses.items(): @@ -477,12 +536,14 @@ def _add_gateway_responses(self): if response_key not in GatewayResponseProperties: raise InvalidResourceException( self.logical_id, - "Invalid property '{}' in 'GatewayResponses' property '{}'".format(response_key, responses_key)) + "Invalid property '{}' in 'GatewayResponses' property '{}'".format(response_key, responses_key), + ) if not SwaggerEditor.is_valid(self.definition_body): raise InvalidResourceException( - self.logical_id, "Unable to add Auth configuration because " - "'DefinitionBody' does not contain a valid Swagger") + self.logical_id, + "Unable to add Auth configuration because " "'DefinitionBody' does not contain a valid Swagger", + ) swagger_editor = SwaggerEditor(self.definition_body) @@ -490,9 +551,9 @@ def _add_gateway_responses(self): for response_type, response in self.gateway_responses.items(): gateway_responses[response_type] = ApiGatewayResponse( api_logical_id=self.logical_id, - response_parameters=response.get('ResponseParameters', {}), - response_templates=response.get('ResponseTemplates', {}), - status_code=response.get('StatusCode', None) + response_parameters=response.get("ResponseParameters", {}), + response_templates=response.get("ResponseTemplates", {}), + status_code=response.get("StatusCode", None), ) if gateway_responses: @@ -511,13 +572,15 @@ def _add_models(self): return if self.models and not self.definition_body: - raise InvalidResourceException(self.logical_id, - "Models works only with inline Swagger specified in " - "'DefinitionBody' property") + raise InvalidResourceException( + self.logical_id, "Models works only with inline Swagger specified in " "'DefinitionBody' property" + ) if not SwaggerEditor.is_valid(self.definition_body): - raise InvalidResourceException(self.logical_id, "Unable to add Models definitions because " - "'DefinitionBody' does not contain a valid Swagger") + raise InvalidResourceException( + self.logical_id, + "Unable to add Models definitions because " "'DefinitionBody' does not contain a valid 
Swagger", + ) if not all(isinstance(model, dict) for model in self.models.values()): raise InvalidResourceException(self.logical_id, "Invalid value for 'Models' property") @@ -536,25 +599,25 @@ def _openapi_postprocess(self, definition_body): If the is swagger defined in the definition body, we treat it as a swagger spec and do not make any openapi 3 changes to it """ - if definition_body.get('swagger') is not None: + if definition_body.get("swagger") is not None: return definition_body - if definition_body.get('openapi') is not None and self.open_api_version is None: - self.open_api_version = definition_body.get('openapi') - - if self.open_api_version and \ - SwaggerEditor.safe_compare_regex_with_string(SwaggerEditor.get_openapi_version_3_regex(), - self.open_api_version): - if definition_body.get('securityDefinitions'): - components = definition_body.get('components', {}) - components['securitySchemes'] = definition_body['securityDefinitions'] - definition_body['components'] = components - del definition_body['securityDefinitions'] - if definition_body.get('definitions'): - components = definition_body.get('components', {}) - components['schemas'] = definition_body['definitions'] - definition_body['components'] = components - del definition_body['definitions'] + if definition_body.get("openapi") is not None and self.open_api_version is None: + self.open_api_version = definition_body.get("openapi") + + if self.open_api_version and SwaggerEditor.safe_compare_regex_with_string( + SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version + ): + if definition_body.get("securityDefinitions"): + components = definition_body.get("components", {}) + components["securitySchemes"] = definition_body["securityDefinitions"] + definition_body["components"] = components + del definition_body["securityDefinitions"] + if definition_body.get("definitions"): + components = definition_body.get("components", {}) + components["schemas"] = definition_body["definitions"] + definition_body["components"] = components + del definition_body["definitions"] # removes `consumes` and `produces` options for CORS in openapi3 and # adds `schema` for the headers in responses for openapi3 if definition_body.get("paths"): @@ -568,49 +631,54 @@ def _openapi_postprocess(self, definition_body): # add schema for the headers in options section for openapi3 if field in ["responses"]: options_path = definition_body["paths"][path]["options"] - if options_path and options_path.get(field).get('200') and options_path.get(field).\ - get('200').get('headers'): - headers = definition_body["paths"][path]["options"][field]['200']['headers'] + if ( + options_path + and options_path.get(field).get("200") + and options_path.get(field).get("200").get("headers") + ): + headers = definition_body["paths"][path]["options"][field]["200"]["headers"] for header in headers.keys(): - header_value = {"schema": definition_body["paths"][path]["options"][field] - ['200']['headers'][header]} - definition_body["paths"][path]["options"][field]['200']['headers'][header] = \ - header_value + header_value = { + "schema": definition_body["paths"][path]["options"][field]["200"][ + "headers" + ][header] + } + definition_body["paths"][path]["options"][field]["200"]["headers"][ + header + ] = header_value return definition_body def _get_authorizers(self, authorizers_config, default_authorizer=None): authorizers = {} - if default_authorizer == 'AWS_IAM': + if default_authorizer == "AWS_IAM": authorizers[default_authorizer] = ApiGatewayAuthorizer( - 
api_logical_id=self.logical_id, - name=default_authorizer, - is_aws_iam_authorizer=True + api_logical_id=self.logical_id, name=default_authorizer, is_aws_iam_authorizer=True ) if not authorizers_config: - if 'AWS_IAM' in authorizers: + if "AWS_IAM" in authorizers: return authorizers return None if not isinstance(authorizers_config, dict): - raise InvalidResourceException(self.logical_id, - "Authorizers must be a dictionary") + raise InvalidResourceException(self.logical_id, "Authorizers must be a dictionary") for authorizer_name, authorizer in authorizers_config.items(): if not isinstance(authorizer, dict): - raise InvalidResourceException(self.logical_id, - "Authorizer %s must be a dictionary." % (authorizer_name)) + raise InvalidResourceException( + self.logical_id, "Authorizer %s must be a dictionary." % (authorizer_name) + ) authorizers[authorizer_name] = ApiGatewayAuthorizer( api_logical_id=self.logical_id, name=authorizer_name, - user_pool_arn=authorizer.get('UserPoolArn'), - function_arn=authorizer.get('FunctionArn'), - identity=authorizer.get('Identity'), - function_payload_type=authorizer.get('FunctionPayloadType'), - function_invoke_role=authorizer.get('FunctionInvokeRole'), - authorization_scopes=authorizer.get("AuthorizationScopes") + user_pool_arn=authorizer.get("UserPoolArn"), + function_arn=authorizer.get("FunctionArn"), + identity=authorizer.get("Identity"), + function_payload_type=authorizer.get("FunctionPayloadType"), + function_invoke_role=authorizer.get("FunctionInvokeRole"), + authorization_scopes=authorizer.get("AuthorizationScopes"), ) return authorizers @@ -621,18 +689,21 @@ def _get_permission(self, authorizer_name, authorizer_lambda_function_arn): :rtype: model.lambda_.LambdaPermission """ rest_api = ApiGatewayRestApi(self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes) - api_id = rest_api.get_runtime_attr('rest_api_id') + api_id = rest_api.get_runtime_attr("rest_api_id") partition = ArnGenerator.get_partition_name() - resource = '${__ApiId__}/authorizers/*' - source_arn = fnSub(ArnGenerator.generate_arn(partition=partition, service='execute-api', resource=resource), - {"__ApiId__": api_id}) - - lambda_permission = LambdaPermission(self.logical_id + authorizer_name + 'AuthorizerPermission', - attributes=self.passthrough_resource_attributes) - lambda_permission.Action = 'lambda:InvokeFunction' + resource = "${__ApiId__}/authorizers/*" + source_arn = fnSub( + ArnGenerator.generate_arn(partition=partition, service="execute-api", resource=resource), + {"__ApiId__": api_id}, + ) + + lambda_permission = LambdaPermission( + self.logical_id + authorizer_name + "AuthorizerPermission", attributes=self.passthrough_resource_attributes + ) + lambda_permission.Action = "lambda:InvokeFunction" lambda_permission.FunctionName = authorizer_lambda_function_arn - lambda_permission.Principal = 'apigateway.amazonaws.com' + lambda_permission.Principal = "apigateway.amazonaws.com" lambda_permission.SourceArn = source_arn return lambda_permission @@ -659,19 +730,26 @@ def _construct_authorizer_lambda_permission(self): return permissions - def _set_default_authorizer(self, swagger_editor, authorizers, default_authorizer, - add_default_auth_to_preflight=True, api_authorizers=None): + def _set_default_authorizer( + self, swagger_editor, authorizers, default_authorizer, add_default_auth_to_preflight=True, api_authorizers=None + ): if not default_authorizer: return - if not authorizers.get(default_authorizer) and default_authorizer != 'AWS_IAM': - raise 
InvalidResourceException(self.logical_id, "Unable to set DefaultAuthorizer because '" + - default_authorizer + "' was not defined in 'Authorizers'") + if not authorizers.get(default_authorizer) and default_authorizer != "AWS_IAM": + raise InvalidResourceException( + self.logical_id, + "Unable to set DefaultAuthorizer because '" + default_authorizer + "' was not defined in 'Authorizers'", + ) for path in swagger_editor.iter_on_path(): - swagger_editor.set_path_default_authorizer(path, default_authorizer, authorizers=authorizers, - add_default_auth_to_preflight=add_default_auth_to_preflight, - api_authorizers=api_authorizers) + swagger_editor.set_path_default_authorizer( + path, + default_authorizer, + authorizers=authorizers, + add_default_auth_to_preflight=add_default_auth_to_preflight, + api_authorizers=api_authorizers, + ) def _set_default_apikey_required(self, swagger_editor): for path in swagger_editor.iter_on_path(): diff --git a/samtranslator/model/api/http_api_generator.py b/samtranslator/model/api/http_api_generator.py index 4c21f1c912..b3d465cfa4 100644 --- a/samtranslator/model/api/http_api_generator.py +++ b/samtranslator/model/api/http_api_generator.py @@ -14,10 +14,20 @@ class HttpApiGenerator(object): - - def __init__(self, logical_id, stage_variables, depends_on, definition_body, definition_uri, - stage_name, tags=None, auth=None, access_log_settings=None, - resource_attributes=None, passthrough_resource_attributes=None): + def __init__( + self, + logical_id, + stage_variables, + depends_on, + definition_body, + definition_uri, + stage_name, + tags=None, + auth=None, + access_log_settings=None, + resource_attributes=None, + passthrough_resource_attributes=None, + ): """Constructs an API Generator class that generates API Gateway resources :param logical_id: Logical id of the SAM API Resource @@ -55,8 +65,9 @@ def _construct_http_api(self): http_api = ApiGatewayV2HttpApi(self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes) if self.definition_uri and self.definition_body: - raise InvalidResourceException(self.logical_id, - "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both") + raise InvalidResourceException( + self.logical_id, "Specify either 'DefinitionUri' or 'DefinitionBody' property and not both" + ) self._add_auth() @@ -65,10 +76,12 @@ def _construct_http_api(self): elif self.definition_body: http_api.Body = self.definition_body else: - raise InvalidResourceException(self.logical_id, - "'DefinitionUri' or 'DefinitionBody' are required properties of an " - "'AWS::Serverless::HttpApi'. Add a value for one of these properties or " - "add a 'HttpApi' event to an 'AWS::Serverless::Function'") + raise InvalidResourceException( + self.logical_id, + "'DefinitionUri' or 'DefinitionBody' are required properties of an " + "'AWS::Serverless::HttpApi'. 
Add a value for one of these properties or " + "add a 'HttpApi' event to an 'AWS::Serverless::Function'", + ) if self.tags is not None: http_api.Tags = get_tag_list(self.tags) @@ -83,26 +96,28 @@ def _add_auth(self): return if self.auth and not self.definition_body: - raise InvalidResourceException(self.logical_id, - "Auth works only with inline Swagger specified in " - "'DefinitionBody' property") + raise InvalidResourceException( + self.logical_id, "Auth works only with inline Swagger specified in " "'DefinitionBody' property" + ) # Make sure keys in the dict are recognized if not all(key in AuthProperties._fields for key in self.auth.keys()): - raise InvalidResourceException( - self.logical_id, "Invalid value for 'Auth' property") + raise InvalidResourceException(self.logical_id, "Invalid value for 'Auth' property") if not OpenApiEditor.is_valid(self.definition_body): - raise InvalidResourceException(self.logical_id, "Unable to add Auth configuration because " - "'DefinitionBody' does not contain a valid Swagger") + raise InvalidResourceException( + self.logical_id, + "Unable to add Auth configuration because " "'DefinitionBody' does not contain a valid Swagger", + ) open_api_editor = OpenApiEditor(self.definition_body) auth_properties = AuthProperties(**self.auth) authorizers = self._get_authorizers(auth_properties.Authorizers, auth_properties.DefaultAuthorizer) # authorizers is guaranteed to return a value or raise an exception open_api_editor.add_authorizers_security_definitions(authorizers) - self._set_default_authorizer(open_api_editor, authorizers, auth_properties.DefaultAuthorizer, - auth_properties.Authorizers) + self._set_default_authorizer( + open_api_editor, authorizers, auth_properties.DefaultAuthorizer, auth_properties.Authorizers + ) self.definition_body = open_api_editor.openapi def _set_default_authorizer(self, open_api_editor, authorizers, default_authorizer, api_authorizers): @@ -117,12 +132,15 @@ def _set_default_authorizer(self, open_api_editor, authorizers, default_authoriz return if not authorizers.get(default_authorizer): - raise InvalidResourceException(self.logical_id, "Unable to set DefaultAuthorizer because '" + - default_authorizer + "' was not defined in 'Authorizers'") + raise InvalidResourceException( + self.logical_id, + "Unable to set DefaultAuthorizer because '" + default_authorizer + "' was not defined in 'Authorizers'", + ) for path in open_api_editor.iter_on_path(): - open_api_editor.set_path_default_authorizer(path, default_authorizer, authorizers=authorizers, - api_authorizers=api_authorizers) + open_api_editor.set_path_default_authorizer( + path, default_authorizer, authorizers=authorizers, api_authorizers=api_authorizers + ) def _get_authorizers(self, authorizers_config, default_authorizer=None): """ @@ -133,21 +151,21 @@ def _get_authorizers(self, authorizers_config, default_authorizer=None): authorizers = {} if not isinstance(authorizers_config, dict): - raise InvalidResourceException(self.logical_id, - "Authorizers must be a dictionary") + raise InvalidResourceException(self.logical_id, "Authorizers must be a dictionary") for authorizer_name, authorizer in authorizers_config.items(): if not isinstance(authorizer, dict): - raise InvalidResourceException(self.logical_id, - "Authorizer %s must be a dictionary." % (authorizer_name)) + raise InvalidResourceException( + self.logical_id, "Authorizer %s must be a dictionary." 
% (authorizer_name) + ) authorizers[authorizer_name] = ApiGatewayV2Authorizer( api_logical_id=self.logical_id, name=authorizer_name, - open_id_connect_url=authorizer.get('OpenIdConnectUrl'), - authorization_scopes=authorizer.get('AuthorizationScopes'), - jwt_configuration=authorizer.get('JwtConfiguration'), - id_source=authorizer.get('IdentitySource') + open_id_connect_url=authorizer.get("OpenIdConnectUrl"), + authorization_scopes=authorizer.get("AuthorizationScopes"), + jwt_configuration=authorizer.get("JwtConfiguration"), + id_source=authorizer.get("IdentitySource"), ) return authorizers @@ -160,24 +178,24 @@ def _construct_body_s3_dict(self): if isinstance(self.definition_uri, dict): if not self.definition_uri.get("Bucket", None) or not self.definition_uri.get("Key", None): # DefinitionUri is a dictionary but does not contain Bucket or Key property - raise InvalidResourceException(self.logical_id, - "'DefinitionUri' requires Bucket and Key properties to be specified") + raise InvalidResourceException( + self.logical_id, "'DefinitionUri' requires Bucket and Key properties to be specified" + ) s3_pointer = self.definition_uri else: # DefinitionUri is a string s3_pointer = parse_s3_uri(self.definition_uri) if s3_pointer is None: - raise InvalidResourceException(self.logical_id, - '\'DefinitionUri\' is not a valid S3 Uri of the form ' - '"s3://bucket/key" with optional versionId query parameter.') - - body_s3 = { - 'Bucket': s3_pointer['Bucket'], - 'Key': s3_pointer['Key'] - } - if 'Version' in s3_pointer: - body_s3['Version'] = s3_pointer['Version'] + raise InvalidResourceException( + self.logical_id, + "'DefinitionUri' is not a valid S3 Uri of the form " + '"s3://bucket/key" with optional versionId query parameter.', + ) + + body_s3 = {"Bucket": s3_pointer["Bucket"], "Key": s3_pointer["Key"]} + if "Version" in s3_pointer: + body_s3["Version"] = s3_pointer["Version"] return body_s3 def _construct_stage(self): @@ -201,8 +219,7 @@ def _construct_stage(self): else: generator = logical_id_generator.LogicalIdGenerator(self.logical_id + "Stage", stage_name_prefix) stage_logical_id = generator.gen() - stage = ApiGatewayV2Stage(stage_logical_id, - attributes=self.passthrough_resource_attributes) + stage = ApiGatewayV2Stage(stage_logical_id, attributes=self.passthrough_resource_attributes) stage.ApiId = ref(self.logical_id) stage.StageName = self.stage_name stage.StageVariables = self.stage_variables diff --git a/samtranslator/model/apigateway.py b/samtranslator/model/apigateway.py index 43a824be4b..d767e4adcc 100644 --- a/samtranslator/model/apigateway.py +++ b/samtranslator/model/apigateway.py @@ -9,72 +9,64 @@ class ApiGatewayRestApi(Resource): - resource_type = 'AWS::ApiGateway::RestApi' + resource_type = "AWS::ApiGateway::RestApi" property_types = { - 'Body': PropertyType(False, is_type(dict)), - 'BodyS3Location': PropertyType(False, is_type(dict)), - 'CloneFrom': PropertyType(False, is_str()), - 'Description': PropertyType(False, is_str()), - 'FailOnWarnings': PropertyType(False, is_type(bool)), - 'Name': PropertyType(False, is_str()), - 'Parameters': PropertyType(False, is_type(dict)), - 'EndpointConfiguration': PropertyType(False, is_type(dict)), + "Body": PropertyType(False, is_type(dict)), + "BodyS3Location": PropertyType(False, is_type(dict)), + "CloneFrom": PropertyType(False, is_str()), + "Description": PropertyType(False, is_str()), + "FailOnWarnings": PropertyType(False, is_type(bool)), + "Name": PropertyType(False, is_str()), + "Parameters": PropertyType(False, is_type(dict)), + 
"EndpointConfiguration": PropertyType(False, is_type(dict)), "BinaryMediaTypes": PropertyType(False, is_type(list)), - "MinimumCompressionSize": PropertyType(False, is_type(int)) + "MinimumCompressionSize": PropertyType(False, is_type(int)), } - runtime_attrs = { - "rest_api_id": lambda self: ref(self.logical_id), - } + runtime_attrs = {"rest_api_id": lambda self: ref(self.logical_id)} class ApiGatewayStage(Resource): - resource_type = 'AWS::ApiGateway::Stage' + resource_type = "AWS::ApiGateway::Stage" property_types = { - 'AccessLogSetting': PropertyType(False, is_type(dict)), - 'CacheClusterEnabled': PropertyType(False, is_type(bool)), - 'CacheClusterSize': PropertyType(False, is_str()), - 'CanarySetting': PropertyType(False, is_type(dict)), - 'ClientCertificateId': PropertyType(False, is_str()), - 'DeploymentId': PropertyType(True, is_str()), - 'Description': PropertyType(False, is_str()), - 'RestApiId': PropertyType(True, is_str()), - 'StageName': PropertyType(True, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'TracingEnabled': PropertyType(False, is_type(bool)), - 'Variables': PropertyType(False, is_type(dict)), - "MethodSettings": PropertyType(False, is_type(list)) + "AccessLogSetting": PropertyType(False, is_type(dict)), + "CacheClusterEnabled": PropertyType(False, is_type(bool)), + "CacheClusterSize": PropertyType(False, is_str()), + "CanarySetting": PropertyType(False, is_type(dict)), + "ClientCertificateId": PropertyType(False, is_str()), + "DeploymentId": PropertyType(True, is_str()), + "Description": PropertyType(False, is_str()), + "RestApiId": PropertyType(True, is_str()), + "StageName": PropertyType(True, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, list_of(is_type(dict))), + "TracingEnabled": PropertyType(False, is_type(bool)), + "Variables": PropertyType(False, is_type(dict)), + "MethodSettings": PropertyType(False, is_type(list)), } - runtime_attrs = { - "stage_name": lambda self: ref(self.logical_id), - } + runtime_attrs = {"stage_name": lambda self: ref(self.logical_id)} def update_deployment_ref(self, deployment_logical_id): self.DeploymentId = ref(deployment_logical_id) class ApiGatewayAccount(Resource): - resource_type = 'AWS::ApiGateway::Account' - property_types = { - 'CloudWatchRoleArn': PropertyType(False, one_of(is_str(), is_type(dict))) - } + resource_type = "AWS::ApiGateway::Account" + property_types = {"CloudWatchRoleArn": PropertyType(False, one_of(is_str(), is_type(dict)))} class ApiGatewayDeployment(Resource): _X_HASH_DELIMITER = "||" - resource_type = 'AWS::ApiGateway::Deployment' + resource_type = "AWS::ApiGateway::Deployment" property_types = { - 'Description': PropertyType(False, is_str()), - 'RestApiId': PropertyType(True, is_str()), - 'StageDescription': PropertyType(False, is_type(dict)), - 'StageName': PropertyType(False, is_str()) + "Description": PropertyType(False, is_str()), + "RestApiId": PropertyType(True, is_str()), + "StageDescription": PropertyType(False, is_type(dict)), + "StageName": PropertyType(False, is_str()), } - runtime_attrs = { - "deployment_id": lambda self: ref(self.logical_id), - } + runtime_attrs = {"deployment_id": lambda self: ref(self.logical_id)} def make_auto_deployable(self, stage, openapi_version=None, swagger=None, domain=None): """ @@ -116,12 +108,12 @@ def __init__(self, api_logical_id=None, response_parameters=None, response_templ for response_parameter_key in response_parameters.keys(): if response_parameter_key not in 
ApiGatewayResponse.ResponseParameterProperties: raise InvalidResourceException( - api_logical_id, - "Invalid gateway response parameter '{}'".format(response_parameter_key)) + api_logical_id, "Invalid gateway response parameter '{}'".format(response_parameter_key) + ) status_code_str = self._status_code_string(status_code) # status_code must look like a status code, if present. Let's not be judgmental; just check 0-999. - if status_code and not match(r'^[0-9]{1,3}$', status_code_str): + if status_code and not match(r"^[0-9]{1,3}$", status_code_str): raise InvalidResourceException(api_logical_id, "Property 'StatusCode' must be numeric") self.api_logical_id = api_logical_id @@ -132,7 +124,7 @@ def __init__(self, api_logical_id=None, response_parameters=None, response_templ def generate_swagger(self): swagger = { "responseParameters": self._add_prefixes(self.response_parameters), - "responseTemplates": self.response_templates + "responseTemplates": self.response_templates, } # Prevent "null" being written. @@ -142,14 +134,14 @@ def generate_swagger(self): return swagger def _add_prefixes(self, response_parameters): - GATEWAY_RESPONSE_PREFIX = 'gatewayresponse.' + GATEWAY_RESPONSE_PREFIX = "gatewayresponse." prefixed_parameters = {} - for key, value in response_parameters.get('Headers', {}).items(): - prefixed_parameters[GATEWAY_RESPONSE_PREFIX + 'header.' + key] = value - for key, value in response_parameters.get('Paths', {}).items(): - prefixed_parameters[GATEWAY_RESPONSE_PREFIX + 'path.' + key] = value - for key, value in response_parameters.get('QueryStrings', {}).items(): - prefixed_parameters[GATEWAY_RESPONSE_PREFIX + 'querystring.' + key] = value + for key, value in response_parameters.get("Headers", {}).items(): + prefixed_parameters[GATEWAY_RESPONSE_PREFIX + "header." + key] = value + for key, value in response_parameters.get("Paths", {}).items(): + prefixed_parameters[GATEWAY_RESPONSE_PREFIX + "path." + key] = value + for key, value in response_parameters.get("QueryStrings", {}).items(): + prefixed_parameters[GATEWAY_RESPONSE_PREFIX + "querystring." 
+ key] = value return prefixed_parameters @@ -158,38 +150,51 @@ def _status_code_string(self, status_code): class ApiGatewayDomainName(Resource): - resource_type = 'AWS::ApiGateway::DomainName' + resource_type = "AWS::ApiGateway::DomainName" property_types = { - 'RegionalCertificateArn': PropertyType(False, is_str()), - 'DomainName': PropertyType(True, is_str()), - 'EndpointConfiguration': PropertyType(False, is_type(dict)), - 'CertificateArn': PropertyType(False, is_str()) + "RegionalCertificateArn": PropertyType(False, is_str()), + "DomainName": PropertyType(True, is_str()), + "EndpointConfiguration": PropertyType(False, is_type(dict)), + "CertificateArn": PropertyType(False, is_str()), } class ApiGatewayBasePathMapping(Resource): - resource_type = 'AWS::ApiGateway::BasePathMapping' + resource_type = "AWS::ApiGateway::BasePathMapping" property_types = { - 'BasePath': PropertyType(False, is_str()), - 'DomainName': PropertyType(True, is_str()), - 'RestApiId': PropertyType(False, is_str()), - 'Stage': PropertyType(False, is_str()) + "BasePath": PropertyType(False, is_str()), + "DomainName": PropertyType(True, is_str()), + "RestApiId": PropertyType(False, is_str()), + "Stage": PropertyType(False, is_str()), } class ApiGatewayAuthorizer(object): - _VALID_FUNCTION_PAYLOAD_TYPES = [None, 'TOKEN', 'REQUEST'] - - def __init__(self, api_logical_id=None, name=None, user_pool_arn=None, function_arn=None, identity=None, - function_payload_type=None, function_invoke_role=None, is_aws_iam_authorizer=False, - authorization_scopes=[]): + _VALID_FUNCTION_PAYLOAD_TYPES = [None, "TOKEN", "REQUEST"] + + def __init__( + self, + api_logical_id=None, + name=None, + user_pool_arn=None, + function_arn=None, + identity=None, + function_payload_type=None, + function_invoke_role=None, + is_aws_iam_authorizer=False, + authorization_scopes=[], + ): if function_payload_type not in ApiGatewayAuthorizer._VALID_FUNCTION_PAYLOAD_TYPES: - raise InvalidResourceException(api_logical_id, name + " Authorizer has invalid " - "'FunctionPayloadType': " + function_payload_type) + raise InvalidResourceException( + api_logical_id, name + " Authorizer has invalid " "'FunctionPayloadType': " + function_payload_type + ) - if function_payload_type == 'REQUEST' and self._is_missing_identity_source(identity): - raise InvalidResourceException(api_logical_id, name + " Authorizer must specify Identity with at least one " - "of Headers, QueryStrings, StageVariables, or Context.") + if function_payload_type == "REQUEST" and self._is_missing_identity_source(identity): + raise InvalidResourceException( + api_logical_id, + name + " Authorizer must specify Identity with at least one " + "of Headers, QueryStrings, StageVariables, or Context.", + ) self.api_logical_id = api_logical_id self.name = name @@ -205,10 +210,10 @@ def _is_missing_identity_source(self, identity): if not identity: return True - headers = identity.get('Headers') - query_strings = identity.get('QueryStrings') - stage_variables = identity.get('StageVariables') - context = identity.get('Context') + headers = identity.get("Headers") + query_strings = identity.get("QueryStrings") + stage_variables = identity.get("StageVariables") + context = identity.get("Context") if not headers and not query_strings and not stage_variables and not context: return True @@ -217,56 +222,57 @@ def _is_missing_identity_source(self, identity): def generate_swagger(self): authorizer_type = self._get_type() - APIGATEWAY_AUTHORIZER_KEY = 'x-amazon-apigateway-authorizer' + APIGATEWAY_AUTHORIZER_KEY = 
"x-amazon-apigateway-authorizer" swagger = { "type": "apiKey", "name": self._get_swagger_header_name(), "in": "header", - "x-amazon-apigateway-authtype": self._get_swagger_authtype() + "x-amazon-apigateway-authtype": self._get_swagger_authtype(), } - if authorizer_type == 'COGNITO_USER_POOLS': + if authorizer_type == "COGNITO_USER_POOLS": swagger[APIGATEWAY_AUTHORIZER_KEY] = { - 'type': self._get_swagger_authorizer_type(), - 'providerARNs': self._get_user_pool_arn_array() + "type": self._get_swagger_authorizer_type(), + "providerARNs": self._get_user_pool_arn_array(), } - elif authorizer_type == 'LAMBDA': - swagger[APIGATEWAY_AUTHORIZER_KEY] = { - 'type': self._get_swagger_authorizer_type() - } + elif authorizer_type == "LAMBDA": + swagger[APIGATEWAY_AUTHORIZER_KEY] = {"type": self._get_swagger_authorizer_type()} partition = ArnGenerator.get_partition_name() - resource = 'lambda:path/2015-03-31/functions/${__FunctionArn__}/invocations' - authorizer_uri = fnSub(ArnGenerator.generate_arn(partition=partition, service='apigateway', - resource=resource, include_account_id=False), - {'__FunctionArn__': self.function_arn}) - - swagger[APIGATEWAY_AUTHORIZER_KEY]['authorizerUri'] = authorizer_uri + resource = "lambda:path/2015-03-31/functions/${__FunctionArn__}/invocations" + authorizer_uri = fnSub( + ArnGenerator.generate_arn( + partition=partition, service="apigateway", resource=resource, include_account_id=False + ), + {"__FunctionArn__": self.function_arn}, + ) + + swagger[APIGATEWAY_AUTHORIZER_KEY]["authorizerUri"] = authorizer_uri reauthorize_every = self._get_reauthorize_every() function_invoke_role = self._get_function_invoke_role() if reauthorize_every is not None: - swagger[APIGATEWAY_AUTHORIZER_KEY]['authorizerResultTtlInSeconds'] = reauthorize_every + swagger[APIGATEWAY_AUTHORIZER_KEY]["authorizerResultTtlInSeconds"] = reauthorize_every if function_invoke_role: - swagger[APIGATEWAY_AUTHORIZER_KEY]['authorizerCredentials'] = function_invoke_role + swagger[APIGATEWAY_AUTHORIZER_KEY]["authorizerCredentials"] = function_invoke_role - if self._get_function_payload_type() == 'REQUEST': - swagger[APIGATEWAY_AUTHORIZER_KEY]['identitySource'] = self._get_identity_source() + if self._get_function_payload_type() == "REQUEST": + swagger[APIGATEWAY_AUTHORIZER_KEY]["identitySource"] = self._get_identity_source() # Authorizer Validation Expression is only allowed on COGNITO_USER_POOLS and LAMBDA_TOKEN - is_lambda_token_authorizer = authorizer_type == 'LAMBDA' and self._get_function_payload_type() == 'TOKEN' + is_lambda_token_authorizer = authorizer_type == "LAMBDA" and self._get_function_payload_type() == "TOKEN" - if authorizer_type == 'COGNITO_USER_POOLS' or is_lambda_token_authorizer: + if authorizer_type == "COGNITO_USER_POOLS" or is_lambda_token_authorizer: identity_validation_expression = self._get_identity_validation_expression() if identity_validation_expression: - swagger[APIGATEWAY_AUTHORIZER_KEY]['identityValidationExpression'] = identity_validation_expression + swagger[APIGATEWAY_AUTHORIZER_KEY]["identityValidationExpression"] = identity_validation_expression return swagger def _get_identity_validation_expression(self): - return self.identity and self.identity.get('ValidationExpression') + return self.identity and self.identity.get("ValidationExpression") def _get_identity_source(self): identity_source_headers = [] @@ -274,23 +280,29 @@ def _get_identity_source(self): identity_source_stage_variables = [] identity_source_context = [] - if self.identity.get('Headers'): - 
identity_source_headers = list(map(lambda h: 'method.request.header.' + h, self.identity.get('Headers'))) + if self.identity.get("Headers"): + identity_source_headers = list(map(lambda h: "method.request.header." + h, self.identity.get("Headers"))) - if self.identity.get('QueryStrings'): - identity_source_query_strings = list(map(lambda qs: 'method.request.querystring.' + qs, - self.identity.get('QueryStrings'))) + if self.identity.get("QueryStrings"): + identity_source_query_strings = list( + map(lambda qs: "method.request.querystring." + qs, self.identity.get("QueryStrings")) + ) - if self.identity.get('StageVariables'): - identity_source_stage_variables = list(map(lambda sv: 'stageVariables.' + sv, - self.identity.get('StageVariables'))) + if self.identity.get("StageVariables"): + identity_source_stage_variables = list( + map(lambda sv: "stageVariables." + sv, self.identity.get("StageVariables")) + ) - if self.identity.get('Context'): - identity_source_context = list(map(lambda c: 'context.' + c, self.identity.get('Context'))) + if self.identity.get("Context"): + identity_source_context = list(map(lambda c: "context." + c, self.identity.get("Context"))) - identity_source_array = (identity_source_headers + identity_source_query_strings + - identity_source_stage_variables + identity_source_context) - identity_source = ', '.join(identity_source_array) + identity_source_array = ( + identity_source_headers + + identity_source_query_strings + + identity_source_stage_variables + + identity_source_context + ) + identity_source = ", ".join(identity_source_array) return identity_source @@ -301,61 +313,61 @@ def _get_swagger_header_name(self): authorizer_type = self._get_type() payload_type = self._get_function_payload_type() - if authorizer_type == 'LAMBDA' and payload_type == 'REQUEST': - return 'Unused' + if authorizer_type == "LAMBDA" and payload_type == "REQUEST": + return "Unused" return self._get_identity_header() def _get_type(self): if self.is_aws_iam_authorizer: - return 'AWS_IAM' + return "AWS_IAM" if self.user_pool_arn: - return 'COGNITO_USER_POOLS' + return "COGNITO_USER_POOLS" - return 'LAMBDA' + return "LAMBDA" def _get_identity_header(self): - if not self.identity or not self.identity.get('Header'): - return 'Authorization' + if not self.identity or not self.identity.get("Header"): + return "Authorization" - return self.identity.get('Header') + return self.identity.get("Header") def _get_reauthorize_every(self): if not self.identity: return None - return self.identity.get('ReauthorizeEvery') + return self.identity.get("ReauthorizeEvery") def _get_function_invoke_role(self): - if not self.function_invoke_role or self.function_invoke_role == 'NONE': + if not self.function_invoke_role or self.function_invoke_role == "NONE": return None return self.function_invoke_role def _get_swagger_authtype(self): authorizer_type = self._get_type() - if authorizer_type == 'AWS_IAM': - return 'awsSigv4' + if authorizer_type == "AWS_IAM": + return "awsSigv4" - if authorizer_type == 'COGNITO_USER_POOLS': - return 'cognito_user_pools' + if authorizer_type == "COGNITO_USER_POOLS": + return "cognito_user_pools" - return 'custom' + return "custom" def _get_function_payload_type(self): - return 'TOKEN' if not self.function_payload_type else self.function_payload_type + return "TOKEN" if not self.function_payload_type else self.function_payload_type def _get_swagger_authorizer_type(self): authorizer_type = self._get_type() - if authorizer_type == 'COGNITO_USER_POOLS': - return 'cognito_user_pools' + if 
authorizer_type == "COGNITO_USER_POOLS": + return "cognito_user_pools" payload_type = self._get_function_payload_type() - if payload_type == 'REQUEST': - return 'request' + if payload_type == "REQUEST": + return "request" - if payload_type == 'TOKEN': - return 'token' + if payload_type == "TOKEN": + return "token" diff --git a/samtranslator/model/apigatewayv2.py b/samtranslator/model/apigatewayv2.py index 5d8ade5332..93f0fdec6d 100644 --- a/samtranslator/model/apigatewayv2.py +++ b/samtranslator/model/apigatewayv2.py @@ -5,44 +5,47 @@ class ApiGatewayV2HttpApi(Resource): - resource_type = 'AWS::ApiGatewayV2::Api' + resource_type = "AWS::ApiGatewayV2::Api" property_types = { - 'Body': PropertyType(False, is_type(dict)), - 'BodyS3Location': PropertyType(False, is_type(dict)), - 'Description': PropertyType(False, is_str()), - 'FailOnWarnings': PropertyType(False, is_type(bool)), - 'BasePath': PropertyType(False, is_str()), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'CorsConfiguration': PropertyType(False, is_type(dict)) + "Body": PropertyType(False, is_type(dict)), + "BodyS3Location": PropertyType(False, is_type(dict)), + "Description": PropertyType(False, is_str()), + "FailOnWarnings": PropertyType(False, is_type(bool)), + "BasePath": PropertyType(False, is_str()), + "Tags": PropertyType(False, list_of(is_type(dict))), + "CorsConfiguration": PropertyType(False, is_type(dict)), } - runtime_attrs = { - "http_api_id": lambda self: ref(self.logical_id), - } + runtime_attrs = {"http_api_id": lambda self: ref(self.logical_id)} class ApiGatewayV2Stage(Resource): - resource_type = 'AWS::ApiGatewayV2::Stage' + resource_type = "AWS::ApiGatewayV2::Stage" property_types = { - 'AccessLogSettings': PropertyType(False, is_type(dict)), - 'DefaultRouteSettings': PropertyType(False, is_type(dict)), - 'ClientCertificateId': PropertyType(False, is_str()), - 'Description': PropertyType(False, is_str()), - 'ApiId': PropertyType(True, is_str()), - 'StageName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'StageVariables': PropertyType(False, is_type(dict)), - 'AutoDeploy': PropertyType(False, is_type(bool)) + "AccessLogSettings": PropertyType(False, is_type(dict)), + "DefaultRouteSettings": PropertyType(False, is_type(dict)), + "ClientCertificateId": PropertyType(False, is_str()), + "Description": PropertyType(False, is_str()), + "ApiId": PropertyType(True, is_str()), + "StageName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, list_of(is_type(dict))), + "StageVariables": PropertyType(False, is_type(dict)), + "AutoDeploy": PropertyType(False, is_type(bool)), } - runtime_attrs = { - "stage_name": lambda self: ref(self.logical_id), - } + runtime_attrs = {"stage_name": lambda self: ref(self.logical_id)} class ApiGatewayV2Authorizer(object): - def __init__(self, api_logical_id=None, name=None, open_id_connect_url=None, - authorization_scopes=[], jwt_configuration={}, id_source=None): + def __init__( + self, + api_logical_id=None, + name=None, + open_id_connect_url=None, + authorization_scopes=[], + jwt_configuration={}, + id_source=None, + ): """ Creates an authorizer for use in V2 Http Apis """ @@ -73,8 +76,8 @@ def generate_openapi(self): "x-amazon-apigateway-authorizer": { "jwtConfiguration": self.jwt_configuration, "identitySource": self.id_source, - "type": "jwt" - } + "type": "jwt", + }, } if self.open_id_connect_url: openapi["x-amazon-apigateway-authorizer"]["openIdConnectUrl"] = 
self.open_id_connect_url diff --git a/samtranslator/model/cloudformation.py b/samtranslator/model/cloudformation.py index decd2f0400..22b272c8d7 100644 --- a/samtranslator/model/cloudformation.py +++ b/samtranslator/model/cloudformation.py @@ -4,16 +4,14 @@ class NestedStack(Resource): - resource_type = 'AWS::CloudFormation::Stack' + resource_type = "AWS::CloudFormation::Stack" # TODO: support passthrough parameters for stacks (Conditions, etc) property_types = { - 'TemplateURL': PropertyType(True, is_str()), - 'Parameters': PropertyType(False, is_type(dict)), - 'NotificationARNs': PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'TimeoutInMinutes': PropertyType(False, is_type(int)) + "TemplateURL": PropertyType(True, is_str()), + "Parameters": PropertyType(False, is_type(dict)), + "NotificationARNs": PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), + "Tags": PropertyType(False, list_of(is_type(dict))), + "TimeoutInMinutes": PropertyType(False, is_type(int)), } - runtime_attrs = { - "stack_id": lambda self: ref(self.logical_id) - } + runtime_attrs = {"stack_id": lambda self: ref(self.logical_id)} diff --git a/samtranslator/model/codedeploy.py b/samtranslator/model/codedeploy.py index b65a662fc2..34568786e6 100644 --- a/samtranslator/model/codedeploy.py +++ b/samtranslator/model/codedeploy.py @@ -4,28 +4,22 @@ class CodeDeployApplication(Resource): - resource_type = 'AWS::CodeDeploy::Application' - property_types = { - 'ComputePlatform': PropertyType(False, one_of(is_str(), is_type(dict))), - } + resource_type = "AWS::CodeDeploy::Application" + property_types = {"ComputePlatform": PropertyType(False, one_of(is_str(), is_type(dict)))} - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - } + runtime_attrs = {"name": lambda self: ref(self.logical_id)} class CodeDeployDeploymentGroup(Resource): - resource_type = 'AWS::CodeDeploy::DeploymentGroup' + resource_type = "AWS::CodeDeploy::DeploymentGroup" property_types = { - 'AlarmConfiguration': PropertyType(False, is_type(dict)), - 'ApplicationName': PropertyType(True, one_of(is_str(), is_type(dict))), - 'AutoRollbackConfiguration': PropertyType(False, is_type(dict)), - 'DeploymentConfigName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'DeploymentStyle': PropertyType(False, is_type(dict)), - 'ServiceRoleArn': PropertyType(True, one_of(is_str(), is_type(dict))), - 'TriggerConfigurations': PropertyType(False, is_type(list)) + "AlarmConfiguration": PropertyType(False, is_type(dict)), + "ApplicationName": PropertyType(True, one_of(is_str(), is_type(dict))), + "AutoRollbackConfiguration": PropertyType(False, is_type(dict)), + "DeploymentConfigName": PropertyType(False, one_of(is_str(), is_type(dict))), + "DeploymentStyle": PropertyType(False, is_type(dict)), + "ServiceRoleArn": PropertyType(True, one_of(is_str(), is_type(dict))), + "TriggerConfigurations": PropertyType(False, is_type(list)), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - } + runtime_attrs = {"name": lambda self: ref(self.logical_id)} diff --git a/samtranslator/model/cognito.py b/samtranslator/model/cognito.py index 633d5a2e6b..a78cfae1bd 100644 --- a/samtranslator/model/cognito.py +++ b/samtranslator/model/cognito.py @@ -4,32 +4,32 @@ class CognitoUserPool(Resource): - resource_type = 'AWS::Cognito::UserPool' + resource_type = "AWS::Cognito::UserPool" property_types = { - 'AdminCreateUserConfig': PropertyType(False, is_type(dict)), - 'AliasAttributes': 
PropertyType(False, list_of(is_str())), - 'AutoVerifiedAttributes': PropertyType(False, list_of(is_str())), - 'DeviceConfiguration': PropertyType(False, is_type(dict)), - 'EmailConfiguration': PropertyType(False, is_type(dict)), - 'EmailVerificationMessage': PropertyType(False, is_str()), - 'EmailVerificationSubject': PropertyType(False, is_str()), - 'LambdaConfig': PropertyType(False, is_type(dict)), - 'MfaConfiguration': PropertyType(False, is_str()), - 'Policies': PropertyType(False, is_type(dict)), - 'Schema': PropertyType(False, list_of(dict)), - 'SmsAuthenticationMessage': PropertyType(False, is_str()), - 'SmsConfiguration': PropertyType(False, list_of(dict)), - 'SmsVerificationMessage': PropertyType(False, is_str()), - 'UsernameAttributes': PropertyType(False, list_of(is_str())), - 'UserPoolAddOns': PropertyType(False, list_of(dict)), - 'UserPoolName': PropertyType(False, is_str()), - 'UserPoolTags': PropertyType(False, is_str()), - 'VerificationMessageTemplate': PropertyType(False, is_type(dict)) + "AdminCreateUserConfig": PropertyType(False, is_type(dict)), + "AliasAttributes": PropertyType(False, list_of(is_str())), + "AutoVerifiedAttributes": PropertyType(False, list_of(is_str())), + "DeviceConfiguration": PropertyType(False, is_type(dict)), + "EmailConfiguration": PropertyType(False, is_type(dict)), + "EmailVerificationMessage": PropertyType(False, is_str()), + "EmailVerificationSubject": PropertyType(False, is_str()), + "LambdaConfig": PropertyType(False, is_type(dict)), + "MfaConfiguration": PropertyType(False, is_str()), + "Policies": PropertyType(False, is_type(dict)), + "Schema": PropertyType(False, list_of(dict)), + "SmsAuthenticationMessage": PropertyType(False, is_str()), + "SmsConfiguration": PropertyType(False, list_of(dict)), + "SmsVerificationMessage": PropertyType(False, is_str()), + "UsernameAttributes": PropertyType(False, list_of(is_str())), + "UserPoolAddOns": PropertyType(False, list_of(dict)), + "UserPoolName": PropertyType(False, is_str()), + "UserPoolTags": PropertyType(False, is_str()), + "VerificationMessageTemplate": PropertyType(False, is_type(dict)), } runtime_attrs = { "name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn"), "provider_name": lambda self: fnGetAtt(self.logical_id, "ProviderName"), - "provider_url": lambda self: fnGetAtt(self.logical_id, "ProviderURL") + "provider_url": lambda self: fnGetAtt(self.logical_id, "ProviderURL"), } diff --git a/samtranslator/model/dynamodb.py b/samtranslator/model/dynamodb.py index 6b09c99a08..f9d38344f9 100644 --- a/samtranslator/model/dynamodb.py +++ b/samtranslator/model/dynamodb.py @@ -4,22 +4,22 @@ class DynamoDBTable(Resource): - resource_type = 'AWS::DynamoDB::Table' + resource_type = "AWS::DynamoDB::Table" property_types = { - 'AttributeDefinitions': PropertyType(True, list_of(is_type(dict))), - 'GlobalSecondaryIndexes': PropertyType(False, list_of(is_type(dict))), - 'KeySchema': PropertyType(False, list_of(is_type(dict))), - 'LocalSecondaryIndexes': PropertyType(False, list_of(is_type(dict))), - 'ProvisionedThroughput': PropertyType(False, dict_of(is_str(), one_of(is_type(int), is_type(dict)))), - 'StreamSpecification': PropertyType(False, is_type(dict)), - 'TableName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'SSESpecification': PropertyType(False, is_type(dict)), - 'BillingMode': PropertyType(False, is_str()) + "AttributeDefinitions": PropertyType(True, list_of(is_type(dict))), + 
"GlobalSecondaryIndexes": PropertyType(False, list_of(is_type(dict))), + "KeySchema": PropertyType(False, list_of(is_type(dict))), + "LocalSecondaryIndexes": PropertyType(False, list_of(is_type(dict))), + "ProvisionedThroughput": PropertyType(False, dict_of(is_str(), one_of(is_type(int), is_type(dict)))), + "StreamSpecification": PropertyType(False, is_type(dict)), + "TableName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, list_of(is_type(dict))), + "SSESpecification": PropertyType(False, is_type(dict)), + "BillingMode": PropertyType(False, is_str()), } runtime_attrs = { "name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn"), - "stream_arn": lambda self: fnGetAtt(self.logical_id, "StreamArn") + "stream_arn": lambda self: fnGetAtt(self.logical_id, "StreamArn"), } diff --git a/samtranslator/model/events.py b/samtranslator/model/events.py index d5774446b0..b9c4e05610 100644 --- a/samtranslator/model/events.py +++ b/samtranslator/model/events.py @@ -4,19 +4,16 @@ class EventsRule(Resource): - resource_type = 'AWS::Events::Rule' + resource_type = "AWS::Events::Rule" property_types = { - 'Description': PropertyType(False, is_str()), - 'EventBusName': PropertyType(False, is_str()), - 'EventPattern': PropertyType(False, is_type(dict)), - 'Name': PropertyType(False, is_str()), - 'RoleArn': PropertyType(False, is_str()), - 'ScheduleExpression': PropertyType(False, is_str()), - 'State': PropertyType(False, is_str()), - 'Targets': PropertyType(False, list_of(is_type(dict))) + "Description": PropertyType(False, is_str()), + "EventBusName": PropertyType(False, is_str()), + "EventPattern": PropertyType(False, is_type(dict)), + "Name": PropertyType(False, is_str()), + "RoleArn": PropertyType(False, is_str()), + "ScheduleExpression": PropertyType(False, is_str()), + "State": PropertyType(False, is_str()), + "Targets": PropertyType(False, list_of(is_type(dict))), } - runtime_attrs = { - "rule_id": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + runtime_attrs = {"rule_id": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} diff --git a/samtranslator/model/eventsources/cloudwatchlogs.py b/samtranslator/model/eventsources/cloudwatchlogs.py index 1765291466..8adc324391 100644 --- a/samtranslator/model/eventsources/cloudwatchlogs.py +++ b/samtranslator/model/eventsources/cloudwatchlogs.py @@ -8,12 +8,10 @@ class CloudWatchLogs(PushEventSource): """CloudWatch Logs event source for SAM Functions.""" - resource_type = 'CloudWatchLogs' - principal = 'logs.amazonaws.com' - property_types = { - 'LogGroupName': PropertyType(True, is_str()), - 'FilterPattern': PropertyType(True, is_str()) - } + + resource_type = "CloudWatchLogs" + principal = "logs.amazonaws.com" + property_types = {"LogGroupName": PropertyType(True, is_str()), "FilterPattern": PropertyType(True, is_str())} def to_cloudformation(self, **kwargs): """Returns the CloudWatch Logs Subscription Filter and Lambda Permission to which this CloudWatch Logs event source @@ -23,7 +21,7 @@ def to_cloudformation(self, **kwargs): :returns: a list of vanilla CloudFormation Resources, to which this push event expands :rtype: list """ - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -39,15 +37,17 @@ def get_source_arn(self): resource = "log-group:${__LogGroupName__}:*" partition = 
ArnGenerator.get_partition_name() - return fnSub(ArnGenerator.generate_arn(partition=partition, service='logs', resource=resource), - {'__LogGroupName__': self.LogGroupName}) + return fnSub( + ArnGenerator.generate_arn(partition=partition, service="logs", resource=resource), + {"__LogGroupName__": self.LogGroupName}, + ) def get_subscription_filter(self, function, permission): subscription_filter = SubscriptionFilter(self.logical_id, depends_on=[permission.logical_id]) subscription_filter.LogGroupName = self.LogGroupName subscription_filter.FilterPattern = self.FilterPattern subscription_filter.DestinationArn = function.get_runtime_attr("arn") - if 'Condition' in function.resource_attributes: - subscription_filter.set_resource_attribute('Condition', function.resource_attributes['Condition']) + if "Condition" in function.resource_attributes: + subscription_filter.set_resource_attribute("Condition", function.resource_attributes["Condition"]) return subscription_filter diff --git a/samtranslator/model/eventsources/pull.py b/samtranslator/model/eventsources/pull.py index c9886c74e8..2ccb40d32e 100644 --- a/samtranslator/model/eventsources/pull.py +++ b/samtranslator/model/eventsources/pull.py @@ -16,19 +16,20 @@ class PullEventSource(ResourceMacro): :cvar str policy_arn: The ARN of the AWS managed role policy corresponding to this pull event source """ + resource_type = None property_types = { - 'Stream': PropertyType(False, is_str()), - 'Queue': PropertyType(False, is_str()), - 'BatchSize': PropertyType(False, is_type(int)), - 'StartingPosition': PropertyType(False, is_str()), - 'Enabled': PropertyType(False, is_type(bool)), - 'MaximumBatchingWindowInSeconds': PropertyType(False, is_type(int)), - 'MaximumRetryAttempts': PropertyType(False, is_type(int)), - 'BisectBatchOnFunctionError': PropertyType(False, is_type(bool)), - 'MaximumRecordAgeInSeconds': PropertyType(False, is_type(int)), - 'DestinationConfig': PropertyType(False, is_type(dict)), - 'ParallelizationFactor': PropertyType(False, is_type(int)) + "Stream": PropertyType(False, is_str()), + "Queue": PropertyType(False, is_str()), + "BatchSize": PropertyType(False, is_type(int)), + "StartingPosition": PropertyType(False, is_str()), + "Enabled": PropertyType(False, is_type(bool)), + "MaximumBatchingWindowInSeconds": PropertyType(False, is_type(int)), + "MaximumRetryAttempts": PropertyType(False, is_type(int)), + "BisectBatchOnFunctionError": PropertyType(False, is_type(bool)), + "MaximumRecordAgeInSeconds": PropertyType(False, is_type(int)), + "DestinationConfig": PropertyType(False, is_type(dict)), + "ParallelizationFactor": PropertyType(False, is_type(int)), } def get_policy_arn(self): @@ -42,7 +43,7 @@ def to_cloudformation(self, **kwargs): :returns: a list of vanilla CloudFormation Resources, to which this pull event expands :rtype: list """ - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -60,11 +61,11 @@ def to_cloudformation(self, **kwargs): if not self.Stream and not self.Queue: raise InvalidEventException( - self.relative_id, "No Queue (for SQS) or Stream (for Kinesis or DynamoDB) provided.") + self.relative_id, "No Queue (for SQS) or Stream (for Kinesis or DynamoDB) provided." 
+ ) if self.Stream and not self.StartingPosition: - raise InvalidEventException( - self.relative_id, "StartingPosition is required for Kinesis and DynamoDB.") + raise InvalidEventException(self.relative_id, "StartingPosition is required for Kinesis and DynamoDB.") lambda_eventsourcemapping.FunctionName = function_name_or_arn lambda_eventsourcemapping.EventSourceArn = self.Stream or self.Queue @@ -80,32 +81,35 @@ def to_cloudformation(self, **kwargs): destination_config_policy = None if self.DestinationConfig: # `Type` property is for sam to attach the right policies - destination_type = self.DestinationConfig.get('OnFailure').get('Type') + destination_type = self.DestinationConfig.get("OnFailure").get("Type") # SAM attaches the policies for SQS or SNS only if 'Type' is given if destination_type: # the values 'SQS' and 'SNS' are allowed. No intrinsics are allowed - if destination_type not in ['SQS', 'SNS']: + if destination_type not in ["SQS", "SNS"]: raise InvalidEventException(self.logical_id, "The only valid values for 'Type' are 'SQS' and 'SNS'") - if self.DestinationConfig.get('OnFailure') is None: - raise InvalidEventException(self.logical_id, "'OnFailure' is a required field for " - "'DestinationConfig'") - if destination_type == 'SQS': - queue_arn = self.DestinationConfig.get('OnFailure').get('Destination') - destination_config_policy = IAMRolePolicies().sqs_send_message_role_policy(queue_arn, - self.logical_id) + if self.DestinationConfig.get("OnFailure") is None: + raise InvalidEventException( + self.logical_id, "'OnFailure' is a required field for " "'DestinationConfig'" + ) + if destination_type == "SQS": + queue_arn = self.DestinationConfig.get("OnFailure").get("Destination") + destination_config_policy = IAMRolePolicies().sqs_send_message_role_policy( + queue_arn, self.logical_id + ) elif destination_type == "SNS": - sns_topic_arn = self.DestinationConfig.get('OnFailure').get('Destination') - destination_config_policy = IAMRolePolicies(). 
sns_publish_role_policy(sns_topic_arn, - self.logical_id) + sns_topic_arn = self.DestinationConfig.get("OnFailure").get("Destination") + destination_config_policy = IAMRolePolicies().sns_publish_role_policy( + sns_topic_arn, self.logical_id + ) lambda_eventsourcemapping.DestinationConfig = self.DestinationConfig - if 'Condition' in function.resource_attributes: - lambda_eventsourcemapping.set_resource_attribute('Condition', function.resource_attributes['Condition']) + if "Condition" in function.resource_attributes: + lambda_eventsourcemapping.set_resource_attribute("Condition", function.resource_attributes["Condition"]) - if 'role' in kwargs: - self._link_policy(kwargs['role'], destination_config_policy) + if "role" in kwargs: + self._link_policy(kwargs["role"], destination_config_policy) return resources @@ -125,29 +129,32 @@ def _link_policy(self, role, destination_config_policy=None): role.Policies.append(destination_config_policy) if role.Policies and destination_config_policy not in role.Policies: # do not add the policy if the same policy document is already present - if not destination_config_policy.get('PolicyDocument') in [d['PolicyDocument'] for d in role.Policies]: + if not destination_config_policy.get("PolicyDocument") in [d["PolicyDocument"] for d in role.Policies]: role.Policies.append(destination_config_policy) class Kinesis(PullEventSource): """Kinesis event source.""" - resource_type = 'Kinesis' + + resource_type = "Kinesis" def get_policy_arn(self): - return ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSLambdaKinesisExecutionRole') + return ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSLambdaKinesisExecutionRole") class DynamoDB(PullEventSource): """DynamoDB Streams event source.""" - resource_type = 'DynamoDB' + + resource_type = "DynamoDB" def get_policy_arn(self): - return ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSLambdaDynamoDBExecutionRole') + return ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSLambdaDynamoDBExecutionRole") class SQS(PullEventSource): """SQS Queue event source.""" - resource_type = 'SQS' + + resource_type = "SQS" def get_policy_arn(self): - return ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSLambdaSQSQueueExecutionRole') + return ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSLambdaSQSQueueExecutionRole") diff --git a/samtranslator/model/eventsources/push.py b/samtranslator/model/eventsources/push.py index 7f0e7e8271..82b7ddeb14 100644 --- a/samtranslator/model/eventsources/push.py +++ b/samtranslator/model/eventsources/push.py @@ -20,7 +20,7 @@ from samtranslator.swagger.swagger import SwaggerEditor from samtranslator.open_api.open_api import OpenApiEditor -CONDITION = 'Condition' +CONDITION = "Condition" REQUEST_PARAMETER_PROPERTIES = ["Required", "Caching"] @@ -45,10 +45,12 @@ class PushEventSource(ResourceMacro): :cvar str principal: The AWS service principal of the source service. """ + principal = None def _construct_permission( - self, function, source_arn=None, source_account=None, suffix="", event_source_token=None, prefix=None): + self, function, source_arn=None, source_account=None, suffix="", event_source_token=None, prefix=None + ): """Constructs the Lambda Permission resource allowing the source service to invoke the function this event source triggers. 
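# Illustrative sketch (not part of this diff): roughly the AWS::Lambda::Permission resource that
# _construct_permission() produces for a push event source such as the SNS source in this file. The
# logical IDs, topic ARN and account id below are made-up placeholders; only the property names set
# in the code (Action, FunctionName, Principal, SourceArn) and the "lambda:InvokeFunction" action
# are taken from the hunks around this point.
example_sns_permission = {
    "MyFunctionNotifyPermission": {  # prefix (event source logical id) + "Permission" + suffix
        "Type": "AWS::Lambda::Permission",
        "Properties": {
            "Action": "lambda:InvokeFunction",
            "FunctionName": {"Ref": "MyFunction"},  # function name runtime attribute, assumed to resolve to a Ref
            "Principal": "sns.amazonaws.com",  # the SNS event source's principal
            "SourceArn": "arn:aws:sns:us-east-1:123456789012:my-topic",
        },
    }
}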
@@ -58,19 +60,20 @@ def _construct_permission( if prefix is None: prefix = self.logical_id if suffix.isalnum(): - permission_logical_id = prefix + 'Permission' + suffix + permission_logical_id = prefix + "Permission" + suffix else: - generator = logical_id_generator.LogicalIdGenerator(prefix + 'Permission', suffix) + generator = logical_id_generator.LogicalIdGenerator(prefix + "Permission", suffix) permission_logical_id = generator.gen() - lambda_permission = LambdaPermission(permission_logical_id, - attributes=function.get_passthrough_resource_attributes()) + lambda_permission = LambdaPermission( + permission_logical_id, attributes=function.get_passthrough_resource_attributes() + ) try: # Name will not be available for Alias resources function_name_or_arn = function.get_runtime_attr("name") except NotImplementedError: function_name_or_arn = function.get_runtime_attr("arn") - lambda_permission.Action = 'lambda:InvokeFunction' + lambda_permission.Action = "lambda:InvokeFunction" lambda_permission.FunctionName = function_name_or_arn lambda_permission.Principal = self.principal lambda_permission.SourceArn = source_arn @@ -82,14 +85,15 @@ def _construct_permission( class Schedule(PushEventSource): """Scheduled executions for SAM Functions.""" - resource_type = 'Schedule' - principal = 'events.amazonaws.com' + + resource_type = "Schedule" + principal = "events.amazonaws.com" property_types = { - 'Schedule': PropertyType(True, is_str()), - 'Input': PropertyType(False, is_str()), - 'Enabled': PropertyType(False, is_type(bool)), - 'Name': PropertyType(False, is_str()), - 'Description': PropertyType(False, is_str()) + "Schedule": PropertyType(True, is_str()), + "Input": PropertyType(False, is_str()), + "Enabled": PropertyType(False, is_type(bool)), + "Name": PropertyType(False, is_str()), + "Description": PropertyType(False, is_str()), } def to_cloudformation(self, **kwargs): @@ -99,7 +103,7 @@ def to_cloudformation(self, **kwargs): :returns: a list of vanilla CloudFormation Resources, to which this Schedule event expands :rtype: list """ - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -129,25 +133,23 @@ def _construct_target(self, function): :returns: the Target property :rtype: dict """ - target = { - 'Arn': function.get_runtime_attr("arn"), - 'Id': self.logical_id + 'LambdaTarget' - } + target = {"Arn": function.get_runtime_attr("arn"), "Id": self.logical_id + "LambdaTarget"} if self.Input is not None: - target['Input'] = self.Input + target["Input"] = self.Input return target class CloudWatchEvent(PushEventSource): """CloudWatch Events/EventBridge event source for SAM Functions.""" - resource_type = 'CloudWatchEvent' - principal = 'events.amazonaws.com' + + resource_type = "CloudWatchEvent" + principal = "events.amazonaws.com" property_types = { - 'EventBusName': PropertyType(False, is_str()), - 'Pattern': PropertyType(False, is_type(dict)), - 'Input': PropertyType(False, is_str()), - 'InputPath': PropertyType(False, is_str()) + "EventBusName": PropertyType(False, is_str()), + "Pattern": PropertyType(False, is_type(dict)), + "Input": PropertyType(False, is_str()), + "InputPath": PropertyType(False, is_str()), } def to_cloudformation(self, **kwargs): @@ -158,7 +160,7 @@ def to_cloudformation(self, **kwargs): :returns: a list of vanilla CloudFormation Resources, to which this CloudWatch Events/EventBridge event expands :rtype: list """ - function = kwargs.get('function') + function = 
kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -185,41 +187,37 @@ def _construct_target(self, function): :returns: the Target property :rtype: dict """ - target = { - 'Arn': function.get_runtime_attr("arn"), - 'Id': self.logical_id + 'LambdaTarget' - } + target = {"Arn": function.get_runtime_attr("arn"), "Id": self.logical_id + "LambdaTarget"} if self.Input is not None: - target['Input'] = self.Input + target["Input"] = self.Input if self.InputPath is not None: - target['InputPath'] = self.InputPath + target["InputPath"] = self.InputPath return target class EventBridgeRule(CloudWatchEvent): """EventBridge Rule event source for SAM Functions.""" - resource_type = 'EventBridgeRule' + + resource_type = "EventBridgeRule" class S3(PushEventSource): """S3 bucket event source for SAM Functions.""" - resource_type = 'S3' - principal = 's3.amazonaws.com' + + resource_type = "S3" + principal = "s3.amazonaws.com" property_types = { - 'Bucket': PropertyType(True, is_str()), - 'Events': PropertyType(True, one_of(is_str(), list_of(is_str()))), - 'Filter': PropertyType(False, dict_of(is_str(), is_str())) + "Bucket": PropertyType(True, is_str()), + "Events": PropertyType(True, one_of(is_str(), list_of(is_str()))), + "Filter": PropertyType(False, dict_of(is_str(), is_str())), } def resources_to_link(self, resources): - if isinstance(self.Bucket, dict) and 'Ref' in self.Bucket: - bucket_id = self.Bucket['Ref'] + if isinstance(self.Bucket, dict) and "Ref" in self.Bucket: + bucket_id = self.Bucket["Ref"] if bucket_id in resources: - return { - 'bucket': resources[bucket_id], - 'bucket_id': bucket_id - } + return {"bucket": resources[bucket_id], "bucket_id": bucket_id} raise InvalidEventException(self.relative_id, "S3 events must reference an S3 bucket in the same template.") def to_cloudformation(self, **kwargs): @@ -229,23 +227,23 @@ def to_cloudformation(self, **kwargs): :returns: a list of vanilla CloudFormation Resources, to which this S3 event expands :rtype: list """ - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") - if 'bucket' not in kwargs or kwargs['bucket'] is None: + if "bucket" not in kwargs or kwargs["bucket"] is None: raise TypeError("Missing required keyword argument: bucket") - if 'bucket_id' not in kwargs or kwargs['bucket_id'] is None: + if "bucket_id" not in kwargs or kwargs["bucket_id"] is None: raise TypeError("Missing required keyword argument: bucket_id") - bucket = kwargs['bucket'] - bucket_id = kwargs['bucket_id'] + bucket = kwargs["bucket"] + bucket_id = kwargs["bucket_id"] resources = [] - source_account = ref('AWS::AccountId') + source_account = ref("AWS::AccountId") permission = self._construct_permission(function, source_account=source_account) if CONDITION in permission.resource_attributes: self._depend_on_lambda_permissions_using_tag(bucket, permission) @@ -304,33 +302,28 @@ def _depend_on_lambda_permissions_using_tag(self, bucket, permission): dependency, so CloudFormation will automatically wait once it reaches that function, the same as if you were using a DependsOn. 
""" - properties = bucket.get('Properties', None) + properties = bucket.get("Properties", None) if properties is None: properties = {} - bucket['Properties'] = properties - tags = properties.get('Tags', None) + bucket["Properties"] = properties + tags = properties.get("Tags", None) if tags is None: tags = [] - properties['Tags'] = tags + properties["Tags"] = tags dep_tag = { - 'sam:ConditionalDependsOn:' + permission.logical_id: { - 'Fn::If': [ - permission.resource_attributes[CONDITION], - ref(permission.logical_id), - 'no dependency' - ] + "sam:ConditionalDependsOn:" + + permission.logical_id: { + "Fn::If": [permission.resource_attributes[CONDITION], ref(permission.logical_id), "no dependency"] } } - properties['Tags'] = tags + get_tag_list(dep_tag) + properties["Tags"] = tags + get_tag_list(dep_tag) return bucket def _inject_notification_configuration(self, function, bucket): - base_event_mapping = { - 'Function': function.get_runtime_attr("arn") - } + base_event_mapping = {"Function": function.get_runtime_attr("arn")} if self.Filter is not None: - base_event_mapping['Filter'] = self.Filter + base_event_mapping["Filter"] = self.Filter event_types = self.Events if isinstance(self.Events, string_types): @@ -340,25 +333,25 @@ def _inject_notification_configuration(self, function, bucket): for event_type in event_types: lambda_event = copy.deepcopy(base_event_mapping) - lambda_event['Event'] = event_type + lambda_event["Event"] = event_type if CONDITION in function.resource_attributes: lambda_event = make_conditional(function.resource_attributes[CONDITION], lambda_event) event_mappings.append(lambda_event) - properties = bucket.get('Properties', None) + properties = bucket.get("Properties", None) if properties is None: properties = {} - bucket['Properties'] = properties + bucket["Properties"] = properties - notification_config = properties.get('NotificationConfiguration', None) + notification_config = properties.get("NotificationConfiguration", None) if notification_config is None: notification_config = {} - properties['NotificationConfiguration'] = notification_config + properties["NotificationConfiguration"] = notification_config - lambda_notifications = notification_config.get('LambdaConfigurations', None) + lambda_notifications = notification_config.get("LambdaConfigurations", None) if lambda_notifications is None: lambda_notifications = [] - notification_config['LambdaConfigurations'] = lambda_notifications + notification_config["LambdaConfigurations"] = lambda_notifications for event_mapping in event_mappings: if event_mapping not in lambda_notifications: @@ -368,13 +361,14 @@ def _inject_notification_configuration(self, function, bucket): class SNS(PushEventSource): """SNS topic event source for SAM Functions.""" - resource_type = 'SNS' - principal = 'sns.amazonaws.com' + + resource_type = "SNS" + principal = "sns.amazonaws.com" property_types = { - 'Topic': PropertyType(True, is_str()), - 'Region': PropertyType(False, is_str()), - 'FilterPolicy': PropertyType(False, dict_of(is_str(), list_of(one_of(is_str(), is_type(dict))))), - 'SqsSubscription': PropertyType(False, one_of(is_type(bool), is_type(dict))) + "Topic": PropertyType(True, is_str()), + "Region": PropertyType(False, is_str()), + "FilterPolicy": PropertyType(False, dict_of(is_str(), list_of(one_of(is_str(), is_type(dict))))), + "SqsSubscription": PropertyType(False, one_of(is_type(bool), is_type(dict))), } def to_cloudformation(self, **kwargs): @@ -384,8 +378,8 @@ def to_cloudformation(self, **kwargs): :returns: a list of 
vanilla CloudFormation Resources, to which this SNS event expands :rtype: list """ - function = kwargs.get('function') - role = kwargs.get('role') + function = kwargs.get("function") + role = kwargs.get("role") if not function: raise TypeError("Missing required keyword argument: function") @@ -393,8 +387,12 @@ def to_cloudformation(self, **kwargs): # SNS -> Lambda if not self.SqsSubscription: subscription = self._inject_subscription( - 'lambda', function.get_runtime_attr("arn"), - self.Topic, self.Region, self.FilterPolicy, function.resource_attributes + "lambda", + function.get_runtime_attr("arn"), + self.Topic, + self.Region, + self.FilterPolicy, + function.resource_attributes, ) return [self._construct_permission(function, source_arn=self.Topic), subscription] @@ -402,13 +400,12 @@ def to_cloudformation(self, **kwargs): if isinstance(self.SqsSubscription, bool): resources = [] queue = self._inject_sqs_queue() - queue_arn = queue.get_runtime_attr('arn') - queue_url = queue.get_runtime_attr('queue_url') + queue_arn = queue.get_runtime_attr("arn") + queue_url = queue.get_runtime_attr("queue_url") queue_policy = self._inject_sqs_queue_policy(self.Topic, queue_arn, queue_url) subscription = self._inject_subscription( - 'sqs', queue_arn, - self.Topic, self.Region, self.FilterPolicy, function.resource_attributes + "sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function.resource_attributes ) event_source = self._inject_sqs_event_source_mapping(function, role, queue_arn) @@ -420,20 +417,18 @@ def to_cloudformation(self, **kwargs): # SNS -> SQS(Existing) -> Lambda resources = [] - queue_arn = self.SqsSubscription.get('QueueArn', None) - queue_url = self.SqsSubscription.get('QueueUrl', None) + queue_arn = self.SqsSubscription.get("QueueArn", None) + queue_url = self.SqsSubscription.get("QueueUrl", None) if not queue_arn or not queue_url: - raise InvalidEventException( - self.relative_id, "No QueueARN or QueueURL provided.") + raise InvalidEventException(self.relative_id, "No QueueARN or QueueURL provided.") - queue_policy_logical_id = self.SqsSubscription.get('QueuePolicyLogicalId', None) - batch_size = self.SqsSubscription.get('BatchSize', None) - enabled = self.SqsSubscription.get('Enabled', None) + queue_policy_logical_id = self.SqsSubscription.get("QueuePolicyLogicalId", None) + batch_size = self.SqsSubscription.get("BatchSize", None) + enabled = self.SqsSubscription.get("Enabled", None) queue_policy = self._inject_sqs_queue_policy(self.Topic, queue_arn, queue_url, queue_policy_logical_id) subscription = self._inject_subscription( - 'sqs', queue_arn, - self.Topic, self.Region, self.FilterPolicy, function.resource_attributes + "sqs", queue_arn, self.Topic, self.Region, self.FilterPolicy, function.resource_attributes ) event_source = self._inject_sqs_event_source_mapping(function, role, queue_arn, batch_size, enabled) @@ -458,38 +453,36 @@ def _inject_subscription(self, protocol, endpoint, topic, region, filterPolicy, return subscription def _inject_sqs_queue(self): - return SQSQueue(self.logical_id + 'Queue') + return SQSQueue(self.logical_id + "Queue") def _inject_sqs_event_source_mapping(self, function, role, queue_arn, batch_size=None, enabled=None): - event_source = SQS(self.logical_id + 'EventSourceMapping') + event_source = SQS(self.logical_id + "EventSourceMapping") event_source.Queue = queue_arn event_source.BatchSize = batch_size or 10 event_source.Enabled = enabled or True return event_source.to_cloudformation(function=function, role=role) def 
_inject_sqs_queue_policy(self, topic_arn, queue_arn, queue_url, logical_id=None): - policy = SQSQueuePolicy(logical_id or self.logical_id + 'QueuePolicy') - policy.PolicyDocument = SQSQueuePolicies.sns_topic_send_message_role_policy( - topic_arn, queue_arn - ) + policy = SQSQueuePolicy(logical_id or self.logical_id + "QueuePolicy") + policy.PolicyDocument = SQSQueuePolicies.sns_topic_send_message_role_policy(topic_arn, queue_arn) policy.Queues = [queue_url] return policy class Api(PushEventSource): """Api method event source for SAM Functions.""" - resource_type = 'Api' - principal = 'apigateway.amazonaws.com' - property_types = { - 'Path': PropertyType(True, is_str()), - 'Method': PropertyType(True, is_str()), + resource_type = "Api" + principal = "apigateway.amazonaws.com" + property_types = { + "Path": PropertyType(True, is_str()), + "Method": PropertyType(True, is_str()), # Api Event sources must "always" be paired with a Serverless::Api - 'RestApiId': PropertyType(True, is_str()), - 'Stage': PropertyType(False, is_str()), - 'Auth': PropertyType(False, is_type(dict)), - 'RequestModel': PropertyType(False, is_type(dict)), - 'RequestParameters': PropertyType(False, is_type(list)) + "RestApiId": PropertyType(True, is_str()), + "Stage": PropertyType(False, is_str()), + "Auth": PropertyType(False, is_type(dict)), + "RequestModel": PropertyType(False, is_type(dict)), + "RequestParameters": PropertyType(False, is_type(list)), } def resources_to_link(self, resources): @@ -513,9 +506,11 @@ def resources_to_link(self, resources): explicit_api = None if isinstance(rest_api_id, string_types): - if rest_api_id in resources \ - and "Properties" in resources[rest_api_id] \ - and "StageName" in resources[rest_api_id]["Properties"]: + if ( + rest_api_id in resources + and "Properties" in resources[rest_api_id] + and "StageName" in resources[rest_api_id]["Properties"] + ): explicit_api = resources[rest_api_id]["Properties"] permitted_stage = explicit_api["StageName"] @@ -523,22 +518,19 @@ def resources_to_link(self, resources): # Stage could be a intrinsic, in which case leave the suffix to default value if isinstance(permitted_stage, string_types): if not permitted_stage: - raise InvalidResourceException(rest_api_id, 'StageName cannot be empty.') + raise InvalidResourceException(rest_api_id, "StageName cannot be empty.") stage_suffix = permitted_stage else: stage_suffix = "Stage" else: # RestApiId is a string, not an intrinsic, but we did not find a valid API resource for this ID - raise InvalidEventException(self.relative_id, "RestApiId property of Api event must reference a valid " - "resource in the same template.") + raise InvalidEventException( + self.relative_id, + "RestApiId property of Api event must reference a valid " "resource in the same template.", + ) - return { - 'explicit_api': explicit_api, - 'explicit_api_stage': { - 'suffix': stage_suffix - } - } + return {"explicit_api": explicit_api, "explicit_api_stage": {"suffix": stage_suffix}} def to_cloudformation(self, **kwargs): """If the Api event source has a RestApi property, then simply return the Lambda Permission resource allowing @@ -552,7 +544,7 @@ def to_cloudformation(self, **kwargs): """ resources = [] - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -563,7 +555,7 @@ def to_cloudformation(self, **kwargs): resources.extend(self._get_permissions(kwargs)) - explicit_api = kwargs['explicit_api'] + explicit_api = 
kwargs["explicit_api"] self._add_swagger_integration(explicit_api, function) return resources @@ -576,8 +568,8 @@ def _get_permissions(self, resources_to_link): # all stages for an API are given permission permitted_stage = "*" suffix = "Prod" - if 'explicit_api_stage' in resources_to_link: - suffix = resources_to_link['explicit_api_stage']['suffix'] + if "explicit_api_stage" in resources_to_link: + suffix = resources_to_link["explicit_api_stage"]["suffix"] self.Stage = suffix permissions.append(self._get_permission(resources_to_link, permitted_stage, suffix)) @@ -587,23 +579,25 @@ def _get_permission(self, resources_to_link, stage, suffix): # It turns out that APIGW doesn't like trailing slashes in paths (#665) # and removes as a part of their behaviour, but this isn't documented. # The regex removes the tailing slash to ensure the permission works as intended - path = re.sub(r'^(.+)/$', r'\1', self.Path) + path = re.sub(r"^(.+)/$", r"\1", self.Path) if not stage or not suffix: raise RuntimeError("Could not add permission to lambda function.") path = SwaggerEditor.get_path_without_trailing_slash(path) - method = '*' if self.Method.lower() == 'any' else self.Method.upper() + method = "*" if self.Method.lower() == "any" else self.Method.upper() api_id = self.RestApiId # RestApiId can be a simple string or intrinsic function like !Ref. Using Fn::Sub will handle both cases - resource = '${__ApiId__}/' + '${__Stage__}/' + method + path + resource = "${__ApiId__}/" + "${__Stage__}/" + method + path partition = ArnGenerator.get_partition_name() - source_arn = fnSub(ArnGenerator.generate_arn(partition=partition, service='execute-api', resource=resource), - {"__ApiId__": api_id, "__Stage__": stage}) + source_arn = fnSub( + ArnGenerator.generate_arn(partition=partition, service="execute-api", resource=resource), + {"__ApiId__": api_id, "__Stage__": stage}, + ) - return self._construct_permission(resources_to_link['function'], source_arn=source_arn, suffix=suffix) + return self._construct_permission(resources_to_link["function"], source_arn=source_arn, suffix=suffix) def _add_swagger_integration(self, api, function): """Adds the path and method for this Api event source to the Swagger body for the provided RestApi. 
@@ -614,10 +608,15 @@ def _add_swagger_integration(self, api, function): if swagger_body is None: return - function_arn = function.get_runtime_attr('arn') + function_arn = function.get_runtime_attr("arn") partition = ArnGenerator.get_partition_name() - uri = fnSub('arn:' + partition + ':apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/' + - make_shorthand(function_arn) + '/invocations') + uri = fnSub( + "arn:" + + partition + + ":apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/" + + make_shorthand(function_arn) + + "/invocations" + ) editor = SwaggerEditor(swagger_body) if api.get("__MANAGE_SWAGGER"): @@ -626,97 +625,110 @@ def _add_swagger_integration(self, api, function): raise InvalidEventException( self.relative_id, 'API method "{method}" defined multiple times for path "{path}".'.format( - method=self.Method, path=self.Path)) + method=self.Method, path=self.Path + ), + ) condition = None if CONDITION in function.resource_attributes: condition = function.resource_attributes[CONDITION] - editor.add_lambda_integration(self.Path, self.Method, uri, self.Auth, api.get('Auth'), condition=condition) + editor.add_lambda_integration(self.Path, self.Method, uri, self.Auth, api.get("Auth"), condition=condition) if self.Auth: - method_authorizer = self.Auth.get('Authorizer') - api_auth = api.get('Auth') + method_authorizer = self.Auth.get("Authorizer") + api_auth = api.get("Auth") if method_authorizer: - api_authorizers = api_auth and api_auth.get('Authorizers') + api_authorizers = api_auth and api_auth.get("Authorizers") - if method_authorizer != 'AWS_IAM': - if method_authorizer != 'NONE' and not api_authorizers: + if method_authorizer != "AWS_IAM": + if method_authorizer != "NONE" and not api_authorizers: raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] ' - 'because the related API does not define any Authorizers.'.format( - authorizer=method_authorizer, method=self.Method, path=self.Path)) + "Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] " + "because the related API does not define any Authorizers.".format( + authorizer=method_authorizer, method=self.Method, path=self.Path + ), + ) - if method_authorizer != 'NONE' and not api_authorizers.get(method_authorizer): + if method_authorizer != "NONE" and not api_authorizers.get(method_authorizer): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] ' - 'because it wasn\'t defined in the API\'s Authorizers.'.format( - authorizer=method_authorizer, method=self.Method, path=self.Path)) - - if method_authorizer == 'NONE': - if not api_auth or not api_auth.get('DefaultAuthorizer'): + "Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] " + "because it wasn't defined in the API's Authorizers.".format( + authorizer=method_authorizer, method=self.Method, path=self.Path + ), + ) + + if method_authorizer == "NONE": + if not api_auth or not api_auth.get("DefaultAuthorizer"): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer on API method [{method}] for path [{path}] because \'NONE\' ' - 'is only a valid value when a DefaultAuthorizer on the API is specified.'.format( - method=self.Method, path=self.Path)) + "Unable to set Authorizer on API method [{method}] for path [{path}] because 'NONE' " + "is only a valid value when a DefaultAuthorizer on the API is specified.".format( + 
method=self.Method, path=self.Path + ), + ) if self.Auth.get("AuthorizationScopes") and not isinstance(self.Auth.get("AuthorizationScopes"), list): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer on API method [{method}] for path [{path}] because ' - '\'AuthorizationScopes\' must be a list of strings.'.format(method=self.Method, - path=self.Path)) + "Unable to set Authorizer on API method [{method}] for path [{path}] because " + "'AuthorizationScopes' must be a list of strings.".format(method=self.Method, path=self.Path), + ) - apikey_required_setting = self.Auth.get('ApiKeyRequired') + apikey_required_setting = self.Auth.get("ApiKeyRequired") apikey_required_setting_is_false = apikey_required_setting is not None and not apikey_required_setting - if apikey_required_setting_is_false and not api_auth.get('ApiKeyRequired'): + if apikey_required_setting_is_false and not api_auth.get("ApiKeyRequired"): raise InvalidEventException( self.relative_id, - 'Unable to set ApiKeyRequired [False] on API method [{method}] for path [{path}] ' - 'because the related API does not specify any ApiKeyRequired.'.format( - method=self.Method, path=self.Path)) + "Unable to set ApiKeyRequired [False] on API method [{method}] for path [{path}] " + "because the related API does not specify any ApiKeyRequired.".format( + method=self.Method, path=self.Path + ), + ) if method_authorizer or apikey_required_setting is not None: if editor.has_path(self.Path): editor.add_auth_to_method(api=api, path=self.Path, method_name=self.Method, auth=self.Auth) - if self.Auth.get('ResourcePolicy'): - resource_policy = self.Auth.get('ResourcePolicy') - editor.add_resource_policy(resource_policy=resource_policy, - path=self.Path, api_id=self.RestApiId.get('Ref'), stage=self.Stage) + if self.Auth.get("ResourcePolicy"): + resource_policy = self.Auth.get("ResourcePolicy") + editor.add_resource_policy( + resource_policy=resource_policy, path=self.Path, api_id=self.RestApiId.get("Ref"), stage=self.Stage + ) if api.get("__MANAGE_SWAGGER"): if self.RequestModel: - method_model = self.RequestModel.get('Model') + method_model = self.RequestModel.get("Model") if method_model: - api_models = api.get('Models') + api_models = api.get("Models") if not api_models: raise InvalidEventException( self.relative_id, - 'Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] ' - 'because the related API does not define any Models.'.format( - model=method_model, method=self.Method, path=self.Path)) + "Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] " + "because the related API does not define any Models.".format( + model=method_model, method=self.Method, path=self.Path + ), + ) if not api_models.get(method_model): raise InvalidEventException( self.relative_id, - 'Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] ' - 'because it wasn\'t defined in the API\'s Models.'.format( - model=method_model, method=self.Method, path=self.Path)) + "Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] " + "because it wasn't defined in the API's Models.".format( + model=method_model, method=self.Method, path=self.Path + ), + ) - editor.add_request_model_to_method(path=self.Path, method_name=self.Method, - request_model=self.RequestModel) + editor.add_request_model_to_method( + path=self.Path, method_name=self.Method, request_model=self.RequestModel + ) if self.RequestParameters: - default_value = { - 'Required': False, - 'Caching': 
False - } + default_value = {"Required": False, "Caching": False} parameters = [] for parameter in self.RequestParameters: @@ -725,60 +737,65 @@ def _add_swagger_integration(self, api, function): parameter_name, parameter_value = next(iter(parameter.items())) - if not re.match('method\.request\.(querystring|path|header)\.', parameter_name): + if not re.match("method\.request\.(querystring|path|header)\.", parameter_name): raise InvalidEventException( self.relative_id, "Invalid value for 'RequestParameters' property. Keys must be in the format " "'method.request.[querystring|path|header].{value}', " - "e.g 'method.request.header.Authorization'.") + "e.g 'method.request.header.Authorization'.", + ) - if not isinstance(parameter_value, dict) or not all(key in REQUEST_PARAMETER_PROPERTIES - for key in parameter_value.keys()): + if not isinstance(parameter_value, dict) or not all( + key in REQUEST_PARAMETER_PROPERTIES for key in parameter_value.keys() + ): raise InvalidEventException( self.relative_id, "Invalid value for 'RequestParameters' property. Values must be an object, " - "e.g { Required: true, Caching: false }") + "e.g { Required: true, Caching: false }", + ) settings = default_value.copy() settings.update(parameter_value) - settings.update({'Name': parameter_name}) + settings.update({"Name": parameter_name}) parameters.append(settings) elif isinstance(parameter, string_types): - if not re.match('method\.request\.(querystring|path|header)\.', parameter): + if not re.match("method\.request\.(querystring|path|header)\.", parameter): raise InvalidEventException( self.relative_id, "Invalid value for 'RequestParameters' property. Keys must be in the format " "'method.request.[querystring|path|header].{value}', " - "e.g 'method.request.header.Authorization'.") + "e.g 'method.request.header.Authorization'.", + ) settings = default_value.copy() - settings.update({'Name': parameter}) + settings.update({"Name": parameter}) parameters.append(settings) else: raise InvalidEventException( - self.relative_id, "Invalid value for 'RequestParameters' property. " - "Property must be either a string or an object") + self.relative_id, + "Invalid value for 'RequestParameters' property. 
" + "Property must be either a string or an object", + ) - editor.add_request_parameters_to_method(path=self.Path, method_name=self.Method, - request_parameters=parameters) + editor.add_request_parameters_to_method( + path=self.Path, method_name=self.Method, request_parameters=parameters + ) api["DefinitionBody"] = editor.swagger class AlexaSkill(PushEventSource): - resource_type = 'AlexaSkill' - principal = 'alexa-appkit.amazon.com' + resource_type = "AlexaSkill" + principal = "alexa-appkit.amazon.com" - property_types = { - 'SkillId': PropertyType(False, is_str()), - } + property_types = {"SkillId": PropertyType(False, is_str())} def to_cloudformation(self, **kwargs): - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") @@ -790,31 +807,29 @@ def to_cloudformation(self, **kwargs): class IoTRule(PushEventSource): - resource_type = 'IoTRule' - principal = 'iot.amazonaws.com' + resource_type = "IoTRule" + principal = "iot.amazonaws.com" - property_types = { - 'Sql': PropertyType(True, is_str()), - 'AwsIotSqlVersion': PropertyType(False, is_str()) - } + property_types = {"Sql": PropertyType(True, is_str()), "AwsIotSqlVersion": PropertyType(False, is_str())} def to_cloudformation(self, **kwargs): - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") resources = [] - resource = 'rule/${RuleName}' + resource = "rule/${RuleName}" partition = ArnGenerator.get_partition_name() - source_arn = fnSub(ArnGenerator.generate_arn(partition=partition, service='iot', resource=resource), - {'RuleName': ref(self.logical_id)}) - source_account = fnSub('${AWS::AccountId}') + source_arn = fnSub( + ArnGenerator.generate_arn(partition=partition, service="iot", resource=resource), + {"RuleName": ref(self.logical_id)}, + ) + source_account = fnSub("${AWS::AccountId}") - resources.append(self._construct_permission(function, source_arn=source_arn, - source_account=source_account)) + resources.append(self._construct_permission(function, source_arn=source_arn, source_account=source_account)) resources.append(self._construct_iot_rule(function)) return resources @@ -823,19 +838,13 @@ def _construct_iot_rule(self, function): rule = IotTopicRule(self.logical_id) payload = { - 'Sql': self.Sql, - 'RuleDisabled': False, - 'Actions': [ - { - 'Lambda': { - 'FunctionArn': function.get_runtime_attr("arn") - } - } - ] + "Sql": self.Sql, + "RuleDisabled": False, + "Actions": [{"Lambda": {"FunctionArn": function.get_runtime_attr("arn")}}], } if self.AwsIotSqlVersion: - payload['AwsIotSqlVersion'] = self.AwsIotSqlVersion + payload["AwsIotSqlVersion"] = self.AwsIotSqlVersion rule.TopicRulePayload = payload if CONDITION in function.resource_attributes: @@ -845,46 +854,43 @@ def _construct_iot_rule(self, function): class Cognito(PushEventSource): - resource_type = 'Cognito' - principal = 'cognito-idp.amazonaws.com' + resource_type = "Cognito" + principal = "cognito-idp.amazonaws.com" property_types = { - 'UserPool': PropertyType(True, is_str()), - 'Trigger': PropertyType(True, one_of(is_str(), list_of(is_str()))) + "UserPool": PropertyType(True, is_str()), + "Trigger": PropertyType(True, one_of(is_str(), list_of(is_str()))), } def resources_to_link(self, resources): - if isinstance(self.UserPool, dict) and 'Ref' in self.UserPool: - userpool_id = self.UserPool['Ref'] + if isinstance(self.UserPool, dict) and "Ref" in self.UserPool: + userpool_id 
= self.UserPool["Ref"] if userpool_id in resources: - return { - 'userpool': resources[userpool_id], - 'userpool_id': userpool_id - } + return {"userpool": resources[userpool_id], "userpool_id": userpool_id} raise InvalidEventException( - self.relative_id, - "Cognito events must reference a Cognito UserPool in the same template.") + self.relative_id, "Cognito events must reference a Cognito UserPool in the same template." + ) def to_cloudformation(self, **kwargs): - function = kwargs.get('function') + function = kwargs.get("function") if not function: raise TypeError("Missing required keyword argument: function") - if 'userpool' not in kwargs or kwargs['userpool'] is None: + if "userpool" not in kwargs or kwargs["userpool"] is None: raise TypeError("Missing required keyword argument: userpool") - if 'userpool_id' not in kwargs or kwargs['userpool_id'] is None: + if "userpool_id" not in kwargs or kwargs["userpool_id"] is None: raise TypeError("Missing required keyword argument: userpool_id") - userpool = kwargs['userpool'] - userpool_id = kwargs['userpool_id'] + userpool = kwargs["userpool"] + userpool_id = kwargs["userpool_id"] resources = [] - source_arn = fnGetAtt(userpool_id, 'Arn') + source_arn = fnGetAtt(userpool_id, "Arn") resources.append( - self._construct_permission( - function, source_arn=source_arn, prefix=function.logical_id + "Cognito")) + self._construct_permission(function, source_arn=source_arn, prefix=function.logical_id + "Cognito") + ) self._inject_lambda_config(function, userpool) resources.append(CognitoUserPool.from_dict(userpool_id, userpool)) @@ -897,37 +903,37 @@ def _inject_lambda_config(self, function, userpool): # TODO can these be conditional? - properties = userpool.get('Properties', None) + properties = userpool.get("Properties", None) if properties is None: properties = {} - userpool['Properties'] = properties + userpool["Properties"] = properties - lambda_config = properties.get('LambdaConfig', None) + lambda_config = properties.get("LambdaConfig", None) if lambda_config is None: lambda_config = {} - properties['LambdaConfig'] = lambda_config + properties["LambdaConfig"] = lambda_config for event_trigger in event_triggers: if event_trigger not in lambda_config: lambda_config[event_trigger] = function.get_runtime_attr("arn") else: raise InvalidEventException( - self.relative_id, - 'Cognito trigger "{trigger}" defined multiple times.'.format( - trigger=self.Trigger)) + self.relative_id, 'Cognito trigger "{trigger}" defined multiple times.'.format(trigger=self.Trigger) + ) return userpool class HttpApi(PushEventSource): """Api method event source for SAM Functions.""" - resource_type = 'HttpApi' - principal = 'apigateway.amazonaws.com' + + resource_type = "HttpApi" + principal = "apigateway.amazonaws.com" property_types = { - 'Path': PropertyType(False, is_str()), - 'Method': PropertyType(False, is_str()), - 'ApiId': PropertyType(False, is_str()), - 'Stage': PropertyType(False, is_str()), - 'Auth': PropertyType(False, is_type(dict)) + "Path": PropertyType(False, is_str()), + "Method": PropertyType(False, is_str()), + "ApiId": PropertyType(False, is_str()), + "Stage": PropertyType(False, is_str()), + "Auth": PropertyType(False, is_type(dict)), } def resources_to_link(self, resources): @@ -942,9 +948,7 @@ def resources_to_link(self, resources): explicit_api = resources[api_id].get("Properties") - return { - 'explicit_api': explicit_api - } + return {"explicit_api": explicit_api} def to_cloudformation(self, **kwargs): """If the Api event source has a RestApi 
property, then simply return the Lambda Permission resource allowing @@ -958,7 +962,7 @@ def to_cloudformation(self, **kwargs): """ resources = [] - function = kwargs.get('function') + function = kwargs.get("function") if self.Method is not None: # Convert to lower case so that user can specify either GET or get @@ -966,7 +970,7 @@ def to_cloudformation(self, **kwargs): resources.extend(self._get_permissions(kwargs)) - explicit_api = kwargs['explicit_api'] + explicit_api = kwargs["explicit_api"] self._add_openapi_integration(explicit_api, function, explicit_api.get("__MANAGE_SWAGGER")) return resources @@ -986,7 +990,7 @@ def _get_permission(self, resources_to_link, stage): # It turns out that APIGW doesn't like trailing slashes in paths (#665) # and removes them as part of its behaviour, but this isn't documented. # The regex removes the trailing slash to ensure the permission works as intended - path = re.sub(r'^(.+)/$', r'\1', self.Path) + path = re.sub(r"^(.+)/$", r"\1", self.Path) editor = None if resources_to_link["explicit_api"].get("DefinitionBody"): @@ -998,9 +1002,10 @@ def _get_permission(self, resources_to_link, stage): # If this is using the new $default path, keep path blank and add a * permission if path == OpenApiEditor._DEFAULT_PATH: - path = '' - elif (editor and resources_to_link.get("function").logical_id == - editor.get_integration_function_logical_id(OpenApiEditor._DEFAULT_PATH, OpenApiEditor._X_ANY_METHOD)): + path = "" + elif editor and resources_to_link.get("function").logical_id == editor.get_integration_function_logical_id( + OpenApiEditor._DEFAULT_PATH, OpenApiEditor._X_ANY_METHOD + ): # Case where default exists for this function, and so the permissions for that will apply here as well # This can save us several CFN resources (not duplicating permissions) return @@ -1008,19 +1013,21 @@ def _get_permission(self, resources_to_link, stage): path = OpenApiEditor.get_path_without_trailing_slash(path) # Handle case where Method is already the ANY ApiGateway extension - if self.Method.lower() == 'any' or self.Method.lower() == OpenApiEditor._X_ANY_METHOD: - method = '*' + if self.Method.lower() == "any" or self.Method.lower() == OpenApiEditor._X_ANY_METHOD: + method = "*" else: method = self.Method.upper() api_id = self.ApiId # ApiId can be a simple string or intrinsic function like !Ref. Using Fn::Sub will handle both cases - resource = '${__ApiId__}/' + '${__Stage__}/' + method + path - source_arn = fnSub(ArnGenerator.generate_arn(partition="${AWS::Partition}", service='execute-api', - resource=resource), {"__ApiId__": api_id, "__Stage__": stage}) + resource = "${__ApiId__}/" + "${__Stage__}/" + method + path + source_arn = fnSub( + ArnGenerator.generate_arn(partition="${AWS::Partition}", service="execute-api", resource=resource), + {"__ApiId__": api_id, "__Stage__": stage}, + ) - return self._construct_permission(resources_to_link['function'], source_arn=source_arn) + return self._construct_permission(resources_to_link["function"], source_arn=source_arn) def _add_openapi_integration(self, api, function, manage_swagger=False): """Adds the path and method for this Api event source to the OpenApi body for the provided RestApi. 
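A similarly hedged sketch of how make_shorthand and fnSub (both shown reformatted in samtranslator/model/intrinsics.py further down in this diff) combine into the API Gateway invocation URI that _add_openapi_integration writes into the OpenApi body. The "HelloWorldFunction" logical id is a made-up example, not something from this PR.

def make_shorthand(intrinsic_dict):
    # Mirrors samtranslator.model.intrinsics.make_shorthand: turns Ref/Fn::GetAtt into ${...} shorthand.
    if "Ref" in intrinsic_dict:
        return "${%s}" % intrinsic_dict["Ref"]
    elif "Fn::GetAtt" in intrinsic_dict:
        return "${%s}" % ".".join(intrinsic_dict["Fn::GetAtt"])
    raise NotImplementedError("Shorthand is only supported for Ref and Fn::GetAtt")

def fnSub(string, variables=None):
    # Same helper as in samtranslator/model/intrinsics.py.
    if variables:
        return {"Fn::Sub": [string, variables]}
    return {"Fn::Sub": string}

function_arn = {"Fn::GetAtt": ["HelloWorldFunction", "Arn"]}  # hypothetical function logical id
uri = fnSub(
    "arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/"
    + make_shorthand(function_arn)
    + "/invocations"
)
# uri == {"Fn::Sub": "arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/"
#                    "functions/${HelloWorldFunction.Arn}/invocations"}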
@@ -1031,9 +1038,12 @@ def _add_openapi_integration(self, api, function, manage_swagger=False): if open_api_body is None: return - function_arn = function.get_runtime_attr('arn') - uri = fnSub('arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/' + - make_shorthand(function_arn) + '/invocations') + function_arn = function.get_runtime_attr("arn") + uri = fnSub( + "arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/" + + make_shorthand(function_arn) + + "/invocations" + ) editor = OpenApiEditor(open_api_body) @@ -1042,13 +1052,15 @@ def _add_openapi_integration(self, api, function, manage_swagger=False): raise InvalidEventException( self.relative_id, "API method '{method}' defined multiple times for path '{path}'.".format( - method=self.Method, path=self.Path)) + method=self.Method, path=self.Path + ), + ) condition = None if CONDITION in function.resource_attributes: condition = function.resource_attributes[CONDITION] - editor.add_lambda_integration(self.Path, self.Method, uri, self.Auth, api.get('Auth'), condition=condition) + editor.add_lambda_integration(self.Path, self.Method, uri, self.Auth, api.get("Auth"), condition=condition) if self.Auth: self._add_auth_to_openapi_integration(api, editor) api["DefinitionBody"] = editor.openapi @@ -1058,44 +1070,53 @@ def _add_auth_to_openapi_integration(self, api, editor): :param api: api object :param editor: OpenApiEditor object that contains the OpenApi definition """ - method_authorizer = self.Auth.get('Authorizer') - api_auth = api.get('Auth') + method_authorizer = self.Auth.get("Authorizer") + api_auth = api.get("Auth") if not method_authorizer: if api_auth.get("DefaultAuthorizer"): self.Auth["Authorizer"] = method_authorizer = api_auth.get("DefaultAuthorizer") else: # currently, we require either a default auth or auth in the method - raise InvalidEventException(self.relative_id, "'Auth' section requires either " - "an explicit 'Authorizer' set or a 'DefaultAuthorizer' " - "configured on the HttpApi.") + raise InvalidEventException( + self.relative_id, + "'Auth' section requires either " + "an explicit 'Authorizer' set or a 'DefaultAuthorizer' " + "configured on the HttpApi.", + ) # Default auth should already be applied, so apply any other auth here or scope override to default - api_authorizers = api_auth and api_auth.get('Authorizers') + api_authorizers = api_auth and api_auth.get("Authorizers") - if method_authorizer != 'NONE' and not api_authorizers: + if method_authorizer != "NONE" and not api_authorizers: raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] ' - 'because the related API does not define any Authorizers.'.format( - authorizer=method_authorizer, method=self.Method, path=self.Path)) + "Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] " + "because the related API does not define any Authorizers.".format( + authorizer=method_authorizer, method=self.Method, path=self.Path + ), + ) - if method_authorizer != 'NONE' and not api_authorizers.get(method_authorizer): + if method_authorizer != "NONE" and not api_authorizers.get(method_authorizer): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer [{authorizer}] on API method [{method}] for path [{path}] ' - 'because it wasn\'t defined in the API\'s Authorizers.'.format( - authorizer=method_authorizer, method=self.Method, path=self.Path)) + "Unable to set Authorizer [{authorizer}] on API 
method [{method}] for path [{path}] " + "because it wasn't defined in the API's Authorizers.".format( + authorizer=method_authorizer, method=self.Method, path=self.Path + ), + ) - if method_authorizer == 'NONE' and not api_auth.get('DefaultAuthorizer'): + if method_authorizer == "NONE" and not api_auth.get("DefaultAuthorizer"): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer on API method [{method}] for path [{path}] because \'NONE\' ' - 'is only a valid value when a DefaultAuthorizer on the API is specified.'.format( - method=self.Method, path=self.Path)) + "Unable to set Authorizer on API method [{method}] for path [{path}] because 'NONE' " + "is only a valid value when a DefaultAuthorizer on the API is specified.".format( + method=self.Method, path=self.Path + ), + ) if self.Auth.get("AuthorizationScopes") and not isinstance(self.Auth.get("AuthorizationScopes"), list): raise InvalidEventException( self.relative_id, - 'Unable to set Authorizer on API method [{method}] for path [{path}] because ' - '\'AuthorizationScopes\' must be a list of strings.'.format(method=self.Method, - path=self.Path)) + "Unable to set Authorizer on API method [{method}] for path [{path}] because " + "'AuthorizationScopes' must be a list of strings.".format(method=self.Method, path=self.Path), + ) editor.add_auth_to_method(api=api, path=self.Path, method_name=self.Method, auth=self.Auth) diff --git a/samtranslator/model/exceptions.py b/samtranslator/model/exceptions.py index bca9f8fc90..153954b617 100644 --- a/samtranslator/model/exceptions.py +++ b/samtranslator/model/exceptions.py @@ -5,13 +5,15 @@ class InvalidDocumentException(Exception): message -- explanation of the error causes -- list of errors which caused this document to be invalid """ + def __init__(self, causes): self._causes = sorted(causes) @property def message(self): - return 'Invalid Serverless Application Specification document. Number of errors found: {}.'\ - .format(len(self.causes)) + return "Invalid Serverless Application Specification document. Number of errors found: {}.".format( + len(self.causes) + ) @property def causes(self): @@ -23,6 +25,7 @@ class DuplicateLogicalIdException(Exception): Attributes: message -- explanation of the error """ + def __init__(self, logical_id, duplicate_id, type): self._logical_id = logical_id self._duplicate_id = duplicate_id @@ -30,10 +33,13 @@ def __init__(self, logical_id, duplicate_id, type): @property def message(self): - return 'Transforming resource with id [{logical_id}] attempts to create a new' \ - ' resource with id [{duplicate_id}] and type "{type}". A resource with that id already' \ - ' exists within this template. Please use a different id for that resource.'.format( - logical_id=self._logical_id, type=self._type, duplicate_id=self._duplicate_id) + return ( + "Transforming resource with id [{logical_id}] attempts to create a new" + ' resource with id [{duplicate_id}] and type "{type}". A resource with that id already' + " exists within this template. Please use a different id for that resource.".format( + logical_id=self._logical_id, type=self._type, duplicate_id=self._duplicate_id + ) + ) class InvalidTemplateException(Exception): @@ -57,6 +63,7 @@ class InvalidResourceException(Exception): Attributes: message -- explanation of the error """ + def __init__(self, logical_id, message): self._logical_id = logical_id self._message = message @@ -66,7 +73,7 @@ def __lt__(self, other): @property def message(self): - return 'Resource with id [{}] is invalid. 
{}'.format(self._logical_id, self._message) + return "Resource with id [{}] is invalid. {}".format(self._logical_id, self._message) class InvalidEventException(Exception): @@ -75,16 +82,17 @@ class InvalidEventException(Exception): Attributes: message -- explanation of the error """ + def __init__(self, event_id, message): self._event_id = event_id self._message = message @property def message(self): - return 'Event with id [{}] is invalid. {}'.format(self._event_id, self._message) + return "Event with id [{}] is invalid. {}".format(self._event_id, self._message) -def prepend(exception, message, end=': '): +def prepend(exception, message, end=": "): """Prepends the first argument (i.e., the exception message) of the a BaseException with the provided message. Useful for reraising exceptions with additional information. @@ -93,6 +101,6 @@ def prepend(exception, message, end=': '): :param str end: the separator to add to the end of the provided message :returns: the exception """ - exception.args = exception.args or ('',) - exception.args = (message + end + exception.args[0], ) + exception.args[1:] + exception.args = exception.args or ("",) + exception.args = (message + end + exception.args[0],) + exception.args[1:] return exception diff --git a/samtranslator/model/function_policies.py b/samtranslator/model/function_policies.py index 843366e388..929d498ee7 100644 --- a/samtranslator/model/function_policies.py +++ b/samtranslator/model/function_policies.py @@ -101,9 +101,11 @@ def _contains_policies(self, resource_properties): :param dict resource_properties: Properties of the resource :return: True if we can process this resource. False, otherwise """ - return resource_properties is not None \ - and isinstance(resource_properties, dict) \ + return ( + resource_properties is not None + and isinstance(resource_properties, dict) and self.POLICIES_PROPERTY_NAME in resource_properties + ) def _get_type(self, policy): """ @@ -147,10 +149,12 @@ def _is_policy_template(self, policy): :return: True, if this is a policy template. False if it is not """ - return self._policy_template_processor is not None and \ - isinstance(policy, dict) and \ - len(policy) == 1 and \ - self._policy_template_processor.has(list(policy.keys())[0]) is True + return ( + self._policy_template_processor is not None + and isinstance(policy, dict) + and len(policy) == 1 + and self._policy_template_processor.has(list(policy.keys())[0]) is True + ) def _get_type_from_intrinsic_if(self, policy): """ @@ -179,14 +183,17 @@ def _get_type_from_intrinsic_if(self, policy): if is_intrinsic_no_value(else_data): return if_data_type - raise InvalidTemplateException("Different policy types within the same Fn::If statement is unsupported. " - "Separate different policy types into different Fn::If statements") + raise InvalidTemplateException( + "Different policy types within the same Fn::If statement is unsupported. 
" + "Separate different policy types into different Fn::If statements" + ) class PolicyTypes(Enum): """ Enum of different policy types supported by SAM & this plugin """ + MANAGED_POLICY = "managed_policy" POLICY_STATEMENT = "policy_statement" POLICY_TEMPLATE = "policy_template" diff --git a/samtranslator/model/iam.py b/samtranslator/model/iam.py index 1cfd6f60cd..ee27a68d56 100644 --- a/samtranslator/model/iam.py +++ b/samtranslator/model/iam.py @@ -4,45 +4,41 @@ class IAMRole(Resource): - resource_type = 'AWS::IAM::Role' + resource_type = "AWS::IAM::Role" property_types = { - 'AssumeRolePolicyDocument': PropertyType(True, is_type(dict)), - 'ManagedPolicyArns': PropertyType(False, is_type(list)), - 'Path': PropertyType(False, is_str()), - 'Policies': PropertyType(False, is_type(list)), - 'PermissionsBoundary': PropertyType(False, is_str()), - 'Tags': PropertyType(False, list_of(is_type(dict))), - } - - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") + "AssumeRolePolicyDocument": PropertyType(True, is_type(dict)), + "ManagedPolicyArns": PropertyType(False, is_type(list)), + "Path": PropertyType(False, is_str()), + "Policies": PropertyType(False, is_type(list)), + "PermissionsBoundary": PropertyType(False, is_str()), + "Tags": PropertyType(False, list_of(is_type(dict))), } + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} -class IAMRolePolicies(): +class IAMRolePolicies: @classmethod def cloud_watch_log_assume_role_policy(cls): document = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': ['sts:AssumeRole'], - 'Effect': 'Allow', - 'Principal': {'Service': ['apigateway.amazonaws.com']} - }] + "Version": "2012-10-17", + "Statement": [ + { + "Action": ["sts:AssumeRole"], + "Effect": "Allow", + "Principal": {"Service": ["apigateway.amazonaws.com"]}, + } + ], } return document @classmethod def lambda_assume_role_policy(cls): document = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': ['sts:AssumeRole'], - 'Effect': 'Allow', - 'Principal': {'Service': ['lambda.amazonaws.com']} - }] + "Version": "2012-10-17", + "Statement": [ + {"Action": ["sts:AssumeRole"], "Effect": "Allow", "Principal": {"Service": ["lambda.amazonaws.com"]}} + ], } return document @@ -53,69 +49,45 @@ def dead_letter_queue_policy(cls, action, resource): :rtype: Dict """ return { - 'PolicyName': 'DeadLetterQueuePolicy', - 'PolicyDocument': { + "PolicyName": "DeadLetterQueuePolicy", + "PolicyDocument": { "Version": "2012-10-17", - "Statement": [{ - "Action": action, - "Resource": resource, - "Effect": "Allow" - }] - } + "Statement": [{"Action": action, "Resource": resource, "Effect": "Allow"}], + }, } @classmethod def sqs_send_message_role_policy(cls, queue_arn, logical_id): document = { - 'PolicyName': logical_id + 'SQSPolicy', - 'PolicyDocument': { - 'Statement': [{ - 'Action': 'sqs:SendMessage', - 'Effect': 'Allow', - 'Resource': queue_arn - }] - } + "PolicyName": logical_id + "SQSPolicy", + "PolicyDocument": {"Statement": [{"Action": "sqs:SendMessage", "Effect": "Allow", "Resource": queue_arn}]}, } return document @classmethod def sns_publish_role_policy(cls, topic_arn, logical_id): document = { - 'PolicyName': logical_id + 'SNSPolicy', - 'PolicyDocument': { - 'Statement': [{ - 'Action': 'sns:publish', - 'Effect': 'Allow', - 'Resource': topic_arn - }] - } + "PolicyName": logical_id + "SNSPolicy", + "PolicyDocument": {"Statement": [{"Action": "sns:publish", "Effect": "Allow", 
"Resource": topic_arn}]}, } return document @classmethod def event_bus_put_events_role_policy(cls, event_bus_arn, logical_id): document = { - 'PolicyName': logical_id + 'EventBridgePolicy', - 'PolicyDocument': { - 'Statement': [{ - 'Action': 'events:PutEvents', - 'Effect': 'Allow', - 'Resource': event_bus_arn - }] - } + "PolicyName": logical_id + "EventBridgePolicy", + "PolicyDocument": { + "Statement": [{"Action": "events:PutEvents", "Effect": "Allow", "Resource": event_bus_arn}] + }, } return document @classmethod def lambda_invoke_function_role_policy(cls, function_arn, logical_id): document = { - 'PolicyName': logical_id + 'LambdaPolicy', - 'PolicyDocument': { - 'Statement': [{ - 'Action': 'lambda:InvokeFunction', - 'Effect': 'Allow', - 'Resource': function_arn - }] - } + "PolicyName": logical_id + "LambdaPolicy", + "PolicyDocument": { + "Statement": [{"Action": "lambda:InvokeFunction", "Effect": "Allow", "Resource": function_arn}] + }, } return document diff --git a/samtranslator/model/intrinsics.py b/samtranslator/model/intrinsics.py index f3b56c3580..18dccc8ee0 100644 --- a/samtranslator/model/intrinsics.py +++ b/samtranslator/model/intrinsics.py @@ -1,51 +1,41 @@ def fnGetAtt(logical_name, attribute_name): - return {'Fn::GetAtt': [logical_name, attribute_name]} + return {"Fn::GetAtt": [logical_name, attribute_name]} def ref(logical_name): - return {'Ref': logical_name} + return {"Ref": logical_name} def fnJoin(delimiter, values): - return {'Fn::Join': [delimiter, values]} + return {"Fn::Join": [delimiter, values]} def fnSub(string, variables=None): if variables: - return {'Fn::Sub': [string, variables]} - return {'Fn::Sub': string} + return {"Fn::Sub": [string, variables]} + return {"Fn::Sub": string} def fnOr(argument_list): - return {'Fn::Or': argument_list} + return {"Fn::Or": argument_list} def fnAnd(argument_list): - return {'Fn::And': argument_list} + return {"Fn::And": argument_list} -def make_conditional(condition, true_data, false_data={'Ref': 'AWS::NoValue'}): - return { - 'Fn::If': [ - condition, - true_data, - false_data - ] - } +def make_conditional(condition, true_data, false_data={"Ref": "AWS::NoValue"}): + return {"Fn::If": [condition, true_data, false_data]} def make_not_conditional(condition): - return { - 'Fn::Not': [ - {'Condition': condition} - ] - } + return {"Fn::Not": [{"Condition": condition}]} def make_condition_or_list(conditions_list): condition_or_list = [] for condition in conditions_list: - c = {'Condition': condition} + c = {"Condition": condition} condition_or_list.append(c) return condition_or_list @@ -108,7 +98,7 @@ def make_combined_condition(conditions_list, condition_name): new_condition_name = condition_name # If more than 1 new condition is needed, add a number to the end of the name if zero_based_num_conditions > 0: - new_condition_name = '{}{}'.format(condition_name, zero_based_num_conditions) + new_condition_name = "{}{}".format(condition_name, zero_based_num_conditions) zero_based_num_conditions -= 1 new_condition_content = make_or_condition(conditions_list[:max_conditions]) conditions_list = conditions_list[max_conditions:] @@ -132,7 +122,7 @@ def make_shorthand(intrinsic_dict): :raises NotImplementedError: For intrinsic functions that don't support shorthands. 
""" if "Ref" in intrinsic_dict: - return "${%s}" % intrinsic_dict['Ref'] + return "${%s}" % intrinsic_dict["Ref"] elif "Fn::GetAtt" in intrinsic_dict: return "${%s}" % ".".join(intrinsic_dict["Fn::GetAtt"]) else: @@ -148,9 +138,7 @@ def is_instrinsic(input): :return: True, if yes """ - if input is not None \ - and isinstance(input, dict) \ - and len(input) == 1: + if input is not None and isinstance(input, dict) and len(input) == 1: key = list(input.keys())[0] return key == "Ref" or key == "Condition" or key.startswith("Fn::") diff --git a/samtranslator/model/iot.py b/samtranslator/model/iot.py index 4b9b7bcb8e..49de2d50f2 100644 --- a/samtranslator/model/iot.py +++ b/samtranslator/model/iot.py @@ -4,12 +4,7 @@ class IotTopicRule(Resource): - resource_type = 'AWS::IoT::TopicRule' - property_types = { - 'TopicRulePayload': PropertyType(False, is_type(dict)) - } + resource_type = "AWS::IoT::TopicRule" + property_types = {"TopicRulePayload": PropertyType(False, is_type(dict))} - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} diff --git a/samtranslator/model/lambda_.py b/samtranslator/model/lambda_.py index 3c643fb39f..e7b8e806a9 100644 --- a/samtranslator/model/lambda_.py +++ b/samtranslator/model/lambda_.py @@ -4,102 +4,95 @@ class LambdaFunction(Resource): - resource_type = 'AWS::Lambda::Function' + resource_type = "AWS::Lambda::Function" property_types = { - 'Code': PropertyType(True, is_type(dict)), - 'DeadLetterConfig': PropertyType(False, is_type(dict)), - 'Description': PropertyType(False, is_str()), - 'FunctionName': PropertyType(False, is_str()), - 'Handler': PropertyType(True, is_str()), - 'MemorySize': PropertyType(False, is_type(int)), - 'Role': PropertyType(False, is_str()), - 'Runtime': PropertyType(False, is_str()), - 'Timeout': PropertyType(False, is_type(int)), - 'VpcConfig': PropertyType(False, is_type(dict)), - 'Environment': PropertyType(False, is_type(dict)), - 'Tags': PropertyType(False, list_of(is_type(dict))), - 'TracingConfig': PropertyType(False, is_type(dict)), - 'KmsKeyArn': PropertyType(False, one_of(is_type(dict), is_str())), - 'Layers': PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), - 'ReservedConcurrentExecutions': PropertyType(False, any_type()) + "Code": PropertyType(True, is_type(dict)), + "DeadLetterConfig": PropertyType(False, is_type(dict)), + "Description": PropertyType(False, is_str()), + "FunctionName": PropertyType(False, is_str()), + "Handler": PropertyType(True, is_str()), + "MemorySize": PropertyType(False, is_type(int)), + "Role": PropertyType(False, is_str()), + "Runtime": PropertyType(False, is_str()), + "Timeout": PropertyType(False, is_type(int)), + "VpcConfig": PropertyType(False, is_type(dict)), + "Environment": PropertyType(False, is_type(dict)), + "Tags": PropertyType(False, list_of(is_type(dict))), + "TracingConfig": PropertyType(False, is_type(dict)), + "KmsKeyArn": PropertyType(False, one_of(is_type(dict), is_str())), + "Layers": PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), + "ReservedConcurrentExecutions": PropertyType(False, any_type()), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} class LambdaVersion(Resource): - 
resource_type = 'AWS::Lambda::Version' + resource_type = "AWS::Lambda::Version" property_types = { - 'CodeSha256': PropertyType(False, is_str()), - 'Description': PropertyType(False, is_str()), - 'FunctionName': PropertyType(True, one_of(is_str(), is_type(dict))) + "CodeSha256": PropertyType(False, is_str()), + "Description": PropertyType(False, is_str()), + "FunctionName": PropertyType(True, one_of(is_str(), is_type(dict))), } runtime_attrs = { "arn": lambda self: ref(self.logical_id), - "version": lambda self: fnGetAtt(self.logical_id, "Version") + "version": lambda self: fnGetAtt(self.logical_id, "Version"), } class LambdaAlias(Resource): - resource_type = 'AWS::Lambda::Alias' + resource_type = "AWS::Lambda::Alias" property_types = { - 'Description': PropertyType(False, is_str()), - 'Name': PropertyType(False, is_str()), - 'FunctionName': PropertyType(True, one_of(is_str(), is_type(dict))), - 'FunctionVersion': PropertyType(True, one_of(is_str(), is_type(dict))), - 'ProvisionedConcurrencyConfig': PropertyType(False, is_type(dict)) + "Description": PropertyType(False, is_str()), + "Name": PropertyType(False, is_str()), + "FunctionName": PropertyType(True, one_of(is_str(), is_type(dict))), + "FunctionVersion": PropertyType(True, one_of(is_str(), is_type(dict))), + "ProvisionedConcurrencyConfig": PropertyType(False, is_type(dict)), } - runtime_attrs = { - "arn": lambda self: ref(self.logical_id) - } + runtime_attrs = {"arn": lambda self: ref(self.logical_id)} class LambdaEventSourceMapping(Resource): - resource_type = 'AWS::Lambda::EventSourceMapping' + resource_type = "AWS::Lambda::EventSourceMapping" property_types = { - 'BatchSize': PropertyType(False, is_type(int)), - 'Enabled': PropertyType(False, is_type(bool)), - 'EventSourceArn': PropertyType(True, is_str()), - 'FunctionName': PropertyType(True, is_str()), - 'MaximumBatchingWindowInSeconds': PropertyType(False, is_type(int)), - 'MaximumRetryAttempts': PropertyType(False, is_type(int)), - 'BisectBatchOnFunctionError': PropertyType(False, is_type(bool)), - 'MaximumRecordAgeInSeconds': PropertyType(False, is_type(int)), - 'DestinationConfig': PropertyType(False, is_type(dict)), - 'ParallelizationFactor': PropertyType(False, is_type(int)), - 'StartingPosition': PropertyType(False, is_str()) + "BatchSize": PropertyType(False, is_type(int)), + "Enabled": PropertyType(False, is_type(bool)), + "EventSourceArn": PropertyType(True, is_str()), + "FunctionName": PropertyType(True, is_str()), + "MaximumBatchingWindowInSeconds": PropertyType(False, is_type(int)), + "MaximumRetryAttempts": PropertyType(False, is_type(int)), + "BisectBatchOnFunctionError": PropertyType(False, is_type(bool)), + "MaximumRecordAgeInSeconds": PropertyType(False, is_type(int)), + "DestinationConfig": PropertyType(False, is_type(dict)), + "ParallelizationFactor": PropertyType(False, is_type(int)), + "StartingPosition": PropertyType(False, is_str()), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id) - } + runtime_attrs = {"name": lambda self: ref(self.logical_id)} class LambdaPermission(Resource): - resource_type = 'AWS::Lambda::Permission' + resource_type = "AWS::Lambda::Permission" property_types = { - 'Action': PropertyType(True, is_str()), - 'FunctionName': PropertyType(True, is_str()), - 'Principal': PropertyType(True, is_str()), - 'SourceAccount': PropertyType(False, is_str()), - 'SourceArn': PropertyType(False, is_str()), - 'EventSourceToken': PropertyType(False, is_str()) + "Action": PropertyType(True, is_str()), + "FunctionName": 
PropertyType(True, is_str()), + "Principal": PropertyType(True, is_str()), + "SourceAccount": PropertyType(False, is_str()), + "SourceArn": PropertyType(False, is_str()), + "EventSourceToken": PropertyType(False, is_str()), } class LambdaEventInvokeConfig(Resource): - resource_type = 'AWS::Lambda::EventInvokeConfig' + resource_type = "AWS::Lambda::EventInvokeConfig" property_types = { - 'DestinationConfig': PropertyType(False, is_type(dict)), - 'FunctionName': PropertyType(True, is_str()), - 'MaximumEventAgeInSeconds': PropertyType(False, is_type(int)), - 'MaximumRetryAttempts': PropertyType(False, is_type(int)), - 'Qualifier': PropertyType(True, is_str()) + "DestinationConfig": PropertyType(False, is_type(dict)), + "FunctionName": PropertyType(True, is_str()), + "MaximumEventAgeInSeconds": PropertyType(False, is_type(int)), + "MaximumRetryAttempts": PropertyType(False, is_type(int)), + "Qualifier": PropertyType(True, is_str()), } @@ -107,16 +100,13 @@ class LambdaLayerVersion(Resource): """ Lambda layer version resource """ - resource_type = 'AWS::Lambda::LayerVersion' + resource_type = "AWS::Lambda::LayerVersion" property_types = { - 'Content': PropertyType(True, is_type(dict)), - 'Description': PropertyType(False, is_str()), - 'LayerName': PropertyType(False, is_str()), - 'CompatibleRuntimes': PropertyType(False, list_of(is_str())), - 'LicenseInfo': PropertyType(False, is_str()) + "Content": PropertyType(True, is_type(dict)), + "Description": PropertyType(False, is_str()), + "LayerName": PropertyType(False, is_str()), + "CompatibleRuntimes": PropertyType(False, list_of(is_str())), + "LicenseInfo": PropertyType(False, is_str()), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} diff --git a/samtranslator/model/log.py b/samtranslator/model/log.py index 853a0f817f..39c41540b5 100644 --- a/samtranslator/model/log.py +++ b/samtranslator/model/log.py @@ -4,14 +4,11 @@ class SubscriptionFilter(Resource): - resource_type = 'AWS::Logs::SubscriptionFilter' + resource_type = "AWS::Logs::SubscriptionFilter" property_types = { - 'LogGroupName': PropertyType(True, is_str()), - 'FilterPattern': PropertyType(True, is_str()), - 'DestinationArn': PropertyType(True, is_str()) + "LogGroupName": PropertyType(True, is_str()), + "FilterPattern": PropertyType(True, is_str()), + "DestinationArn": PropertyType(True, is_str()), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} diff --git a/samtranslator/model/naming.py b/samtranslator/model/naming.py index 2679c2ddcf..ab73f5eac4 100644 --- a/samtranslator/model/naming.py +++ b/samtranslator/model/naming.py @@ -1,4 +1,3 @@ - class GeneratedLogicalId(object): """ Class to generate LogicalIDs for various scenarios. SAM generates LogicalIds for new resources based on code diff --git a/samtranslator/model/preferences/deployment_preference.py b/samtranslator/model/preferences/deployment_preference.py index 354f773cb6..0f2c544c66 100644 --- a/samtranslator/model/preferences/deployment_preference.py +++ b/samtranslator/model/preferences/deployment_preference.py @@ -23,9 +23,10 @@ :param trigger_configurations: Information about triggers associated with the deployment group. 
Duplicates are not allowed. """ -DeploymentPreferenceTuple = namedtuple('DeploymentPreferenceTuple', - ['deployment_type', 'pre_traffic_hook', 'post_traffic_hook', 'alarms', - 'enabled', 'role', 'trigger_configurations']) +DeploymentPreferenceTuple = namedtuple( + "DeploymentPreferenceTuple", + ["deployment_type", "pre_traffic_hook", "post_traffic_hook", "alarms", "enabled", "role", "trigger_configurations"], +) class DeploymentPreference(DeploymentPreferenceTuple): @@ -42,23 +43,25 @@ def from_dict(cls, logical_id, deployment_preference_dict): :param deployment_preference_dict: the dict object taken from the SAM template :return: """ - enabled = deployment_preference_dict.get('Enabled', True) + enabled = deployment_preference_dict.get("Enabled", True) if not enabled: return DeploymentPreference(None, None, None, None, False, None, None) - if 'Type' not in deployment_preference_dict: + if "Type" not in deployment_preference_dict: raise InvalidResourceException(logical_id, "'DeploymentPreference' is missing required Property 'Type'") - deployment_type = deployment_preference_dict['Type'] - hooks = deployment_preference_dict.get('Hooks', dict()) + deployment_type = deployment_preference_dict["Type"] + hooks = deployment_preference_dict.get("Hooks", dict()) if not isinstance(hooks, dict): - raise InvalidResourceException(logical_id, - "'Hooks' property of 'DeploymentPreference' must be a dictionary") + raise InvalidResourceException( + logical_id, "'Hooks' property of 'DeploymentPreference' must be a dictionary" + ) - pre_traffic_hook = hooks.get('PreTraffic', None) - post_traffic_hook = hooks.get('PostTraffic', None) - alarms = deployment_preference_dict.get('Alarms', None) - role = deployment_preference_dict.get('Role', None) - trigger_configurations = deployment_preference_dict.get('TriggerConfigurations', None) - return DeploymentPreference(deployment_type, pre_traffic_hook, post_traffic_hook, alarms, enabled, role, - trigger_configurations) + pre_traffic_hook = hooks.get("PreTraffic", None) + post_traffic_hook = hooks.get("PostTraffic", None) + alarms = deployment_preference_dict.get("Alarms", None) + role = deployment_preference_dict.get("Role", None) + trigger_configurations = deployment_preference_dict.get("TriggerConfigurations", None) + return DeploymentPreference( + deployment_type, pre_traffic_hook, post_traffic_hook, alarms, enabled, role, trigger_configurations + ) diff --git a/samtranslator/model/preferences/deployment_preference_collection.py b/samtranslator/model/preferences/deployment_preference_collection.py index 767d8ced08..2f0bc35ac1 100644 --- a/samtranslator/model/preferences/deployment_preference_collection.py +++ b/samtranslator/model/preferences/deployment_preference_collection.py @@ -7,18 +7,19 @@ from samtranslator.translator.arn_generator import ArnGenerator import copy -CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID = 'CodeDeployServiceRole' -CODEDEPLOY_APPLICATION_LOGICAL_ID = 'ServerlessDeploymentApplication' -CODEDEPLOY_PREDEFINED_CONFIGURATIONS_LIST = ["Canary10Percent5Minutes", - "Canary10Percent10Minutes", - "Canary10Percent15Minutes", - "Canary10Percent30Minutes", - "Linear10PercentEvery1Minute", - "Linear10PercentEvery2Minutes", - "Linear10PercentEvery3Minutes", - "Linear10PercentEvery10Minutes", - "AllAtOnce" - ] +CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID = "CodeDeployServiceRole" +CODEDEPLOY_APPLICATION_LOGICAL_ID = "ServerlessDeploymentApplication" +CODEDEPLOY_PREDEFINED_CONFIGURATIONS_LIST = [ + "Canary10Percent5Minutes", + "Canary10Percent10Minutes", + 
"Canary10Percent15Minutes", + "Canary10Percent30Minutes", + "Linear10PercentEvery1Minute", + "Linear10PercentEvery2Minutes", + "Linear10PercentEvery3Minutes", + "Linear10PercentEvery10Minutes", + "AllAtOnce", +] class DeploymentPreferenceCollection(object): @@ -49,8 +50,11 @@ def add(self, logical_id, deployment_preference_dict): :param deployment_preference_dict: the input SAM template deployment preference mapping """ if logical_id in self._resource_preferences: - raise ValueError("logical_id {logical_id} previously added to this deployment_preference_collection".format( - logical_id=logical_id)) + raise ValueError( + "logical_id {logical_id} previously added to this deployment_preference_collection".format( + logical_id=logical_id + ) + ) self._resource_preferences[logical_id] = DeploymentPreference.from_dict(logical_id, deployment_preference_dict) @@ -82,21 +86,23 @@ def enabled_logical_ids(self): def _codedeploy_application(self): codedeploy_application_resource = CodeDeployApplication(CODEDEPLOY_APPLICATION_LOGICAL_ID) - codedeploy_application_resource.ComputePlatform = 'Lambda' + codedeploy_application_resource.ComputePlatform = "Lambda" return codedeploy_application_resource def _codedeploy_iam_role(self): iam_role = IAMRole(CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID) iam_role.AssumeRolePolicyDocument = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': ['sts:AssumeRole'], - 'Effect': 'Allow', - 'Principal': {'Service': ['codedeploy.amazonaws.com']} - }] + "Version": "2012-10-17", + "Statement": [ + { + "Action": ["sts:AssumeRole"], + "Effect": "Allow", + "Principal": {"Service": ["codedeploy.amazonaws.com"]}, + } + ], } iam_role.ManagedPolicyArns = [ - ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSCodeDeployRoleForLambda') + ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSCodeDeployRoleForLambda") ] return iam_role @@ -112,21 +118,22 @@ def deployment_group(self, function_logical_id): deployment_group = CodeDeployDeploymentGroup(self.deployment_group_logical_id(function_logical_id)) if deployment_preference.alarms is not None: - deployment_group.AlarmConfiguration = {'Enabled': True, - 'Alarms': [{'Name': alarm} for alarm in - deployment_preference.alarms]} - - deployment_group.ApplicationName = self.codedeploy_application.get_runtime_attr('name') - deployment_group.AutoRollbackConfiguration = {'Enabled': True, - 'Events': ['DEPLOYMENT_FAILURE', - 'DEPLOYMENT_STOP_ON_ALARM', - 'DEPLOYMENT_STOP_ON_REQUEST']} + deployment_group.AlarmConfiguration = { + "Enabled": True, + "Alarms": [{"Name": alarm} for alarm in deployment_preference.alarms], + } + + deployment_group.ApplicationName = self.codedeploy_application.get_runtime_attr("name") + deployment_group.AutoRollbackConfiguration = { + "Enabled": True, + "Events": ["DEPLOYMENT_FAILURE", "DEPLOYMENT_STOP_ON_ALARM", "DEPLOYMENT_STOP_ON_REQUEST"], + } - deployment_group.DeploymentConfigName = self._replace_deployment_types(copy.deepcopy( - deployment_preference.deployment_type)) + deployment_group.DeploymentConfigName = self._replace_deployment_types( + copy.deepcopy(deployment_preference.deployment_type) + ) - deployment_group.DeploymentStyle = {'DeploymentType': 'BLUE_GREEN', - 'DeploymentOption': 'WITH_TRAFFIC_CONTROL'} + deployment_group.DeploymentStyle = {"DeploymentType": "BLUE_GREEN", "DeploymentOption": "WITH_TRAFFIC_CONTROL"} deployment_group.ServiceRoleArn = self.codedeploy_iam_role.get_runtime_attr("arn") if deployment_preference.role: @@ -157,14 +164,14 @@ def update_policy(self, 
function_logical_id): deployment_preference = self.get(function_logical_id) return UpdatePolicy( - self.codedeploy_application.get_runtime_attr('name'), - self.deployment_group(function_logical_id).get_runtime_attr('name'), + self.codedeploy_application.get_runtime_attr("name"), + self.deployment_group(function_logical_id).get_runtime_attr("name"), deployment_preference.pre_traffic_hook, deployment_preference.post_traffic_hook, ) def deployment_group_logical_id(self, function_logical_id): - return function_logical_id + 'DeploymentGroup' + return function_logical_id + "DeploymentGroup" def __eq__(self, other): if isinstance(other, self.__class__): diff --git a/samtranslator/model/route53.py b/samtranslator/model/route53.py index 25fd264df4..dcb94818b7 100644 --- a/samtranslator/model/route53.py +++ b/samtranslator/model/route53.py @@ -3,8 +3,5 @@ class Route53RecordSetGroup(Resource): - resource_type = 'AWS::Route53::RecordSetGroup' - property_types = { - 'HostedZoneId': PropertyType(False, is_str()), - 'RecordSets': PropertyType(False, is_type(list)), - } + resource_type = "AWS::Route53::RecordSetGroup" + property_types = {"HostedZoneId": PropertyType(False, is_str()), "RecordSets": PropertyType(False, is_type(list))} diff --git a/samtranslator/model/s3.py b/samtranslator/model/s3.py index a09be671fa..6c60c26c78 100644 --- a/samtranslator/model/s3.py +++ b/samtranslator/model/s3.py @@ -4,27 +4,24 @@ class S3Bucket(Resource): - resource_type = 'AWS::S3::Bucket' + resource_type = "AWS::S3::Bucket" property_types = { - 'AccessControl': PropertyType(False, any_type()), - 'AccelerateConfiguration': PropertyType(False, any_type()), - 'AnalyticsConfigurations': PropertyType(False, any_type()), - 'BucketEncryption': PropertyType(False, any_type()), - 'BucketName': PropertyType(False, is_str()), - 'CorsConfiguration': PropertyType(False, any_type()), - 'InventoryConfigurations': PropertyType(False, any_type()), - 'LifecycleConfiguration': PropertyType(False, any_type()), - 'LoggingConfiguration': PropertyType(False, any_type()), - 'MetricsConfigurations': PropertyType(False, any_type()), - 'NotificationConfiguration': PropertyType(False, is_type(dict)), - 'PublicAccessBlockConfiguration': PropertyType(False, is_type(dict)), - 'ReplicationConfiguration': PropertyType(False, any_type()), - 'Tags': PropertyType(False, is_type(list)), - 'VersioningConfiguration': PropertyType(False, any_type()), - 'WebsiteConfiguration': PropertyType(False, any_type()), + "AccessControl": PropertyType(False, any_type()), + "AccelerateConfiguration": PropertyType(False, any_type()), + "AnalyticsConfigurations": PropertyType(False, any_type()), + "BucketEncryption": PropertyType(False, any_type()), + "BucketName": PropertyType(False, is_str()), + "CorsConfiguration": PropertyType(False, any_type()), + "InventoryConfigurations": PropertyType(False, any_type()), + "LifecycleConfiguration": PropertyType(False, any_type()), + "LoggingConfiguration": PropertyType(False, any_type()), + "MetricsConfigurations": PropertyType(False, any_type()), + "NotificationConfiguration": PropertyType(False, is_type(dict)), + "PublicAccessBlockConfiguration": PropertyType(False, is_type(dict)), + "ReplicationConfiguration": PropertyType(False, any_type()), + "Tags": PropertyType(False, is_type(list)), + "VersioningConfiguration": PropertyType(False, any_type()), + "WebsiteConfiguration": PropertyType(False, any_type()), } - runtime_attrs = { - "name": lambda self: ref(self.logical_id), - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - 
} + runtime_attrs = {"name": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn")} diff --git a/samtranslator/model/s3_utils/uri_parser.py b/samtranslator/model/s3_utils/uri_parser.py index 371cfb923f..e55a6a6820 100644 --- a/samtranslator/model/s3_utils/uri_parser.py +++ b/samtranslator/model/s3_utils/uri_parser.py @@ -15,13 +15,10 @@ def parse_s3_uri(uri): url = urlparse(uri) query = parse_qs(url.query) - if url.scheme == 's3' and url.netloc and url.path: - s3_pointer = { - 'Bucket': url.netloc, - 'Key': url.path.lstrip('/') - } - if 'versionId' in query and len(query['versionId']) == 1: - s3_pointer['Version'] = query['versionId'][0] + if url.scheme == "s3" and url.netloc and url.path: + s3_pointer = {"Bucket": url.netloc, "Key": url.path.lstrip("/")} + if "versionId" in query and len(query["versionId"]) == 1: + s3_pointer["Version"] = query["versionId"][0] return s3_pointer else: return None @@ -61,9 +58,9 @@ def construct_s3_location_object(location_uri, logical_id, property_name): if isinstance(location_uri, dict): if not location_uri.get("Bucket") or not location_uri.get("Key"): # location_uri is a dictionary but does not contain Bucket or Key property - raise InvalidResourceException(logical_id, - "'{}' requires Bucket and Key properties to be " - "specified".format(property_name)) + raise InvalidResourceException( + logical_id, "'{}' requires Bucket and Key properties to be " "specified".format(property_name) + ) s3_pointer = location_uri @@ -72,15 +69,14 @@ def construct_s3_location_object(location_uri, logical_id, property_name): s3_pointer = parse_s3_uri(location_uri) if s3_pointer is None: - raise InvalidResourceException(logical_id, - '\'{}\' is not a valid S3 Uri of the form ' - '"s3://bucket/key" with optional versionId query ' - 'parameter.'.format(property_name)) - - code = { - 'S3Bucket': s3_pointer['Bucket'], - 'S3Key': s3_pointer['Key'] - } - if 'Version' in s3_pointer: - code['S3ObjectVersion'] = s3_pointer['Version'] + raise InvalidResourceException( + logical_id, + "'{}' is not a valid S3 Uri of the form " + '"s3://bucket/key" with optional versionId query ' + "parameter.".format(property_name), + ) + + code = {"S3Bucket": s3_pointer["Bucket"], "S3Key": s3_pointer["Key"]} + if "Version" in s3_pointer: + code["S3ObjectVersion"] = s3_pointer["Version"] return code diff --git a/samtranslator/model/sam_resources.py b/samtranslator/model/sam_resources.py index 2203525b70..ae290e6475 100644 --- a/samtranslator/model/sam_resources.py +++ b/samtranslator/model/sam_resources.py @@ -9,23 +9,32 @@ from .api.http_api_generator import HttpApiGenerator from .s3_utils.uri_parser import construct_s3_location_object from .tags.resource_tagging import get_tag_list -from samtranslator.model import (PropertyType, SamResourceMacro, - ResourceTypeResolver) +from samtranslator.model import PropertyType, SamResourceMacro, ResourceTypeResolver from samtranslator.model.apigateway import ApiGatewayDeployment, ApiGatewayStage, ApiGatewayDomainName from samtranslator.model.apigatewayv2 import ApiGatewayV2Stage from samtranslator.model.cloudformation import NestedStack from samtranslator.model.dynamodb import DynamoDBTable -from samtranslator.model.exceptions import (InvalidEventException, - InvalidResourceException) +from samtranslator.model.exceptions import InvalidEventException, InvalidResourceException from samtranslator.model.function_policies import FunctionPolicies, PolicyTypes from samtranslator.model.iam import IAMRole, IAMRolePolicies -from 
samtranslator.model.lambda_ import (LambdaFunction, LambdaVersion, LambdaAlias, - LambdaLayerVersion, LambdaEventInvokeConfig) +from samtranslator.model.lambda_ import ( + LambdaFunction, + LambdaVersion, + LambdaAlias, + LambdaLayerVersion, + LambdaEventInvokeConfig, +) from samtranslator.model.types import dict_of, is_str, is_type, list_of, one_of, any_type from samtranslator.translator import logical_id_generator from samtranslator.translator.arn_generator import ArnGenerator -from samtranslator.model.intrinsics import (is_intrinsic_if, is_intrinsic_no_value, ref, - make_not_conditional, make_conditional, make_and_condition) +from samtranslator.model.intrinsics import ( + is_intrinsic_if, + is_intrinsic_no_value, + ref, + make_not_conditional, + make_conditional, + make_and_condition, +) from samtranslator.model.sqs import SQSQueue from samtranslator.model.sns import SNSTopic @@ -34,43 +43,45 @@ class SamFunction(SamResourceMacro): """SAM function macro. """ - resource_type = 'AWS::Serverless::Function' + resource_type = "AWS::Serverless::Function" property_types = { - 'FunctionName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Handler': PropertyType(True, is_str()), - 'Runtime': PropertyType(True, is_str()), - 'CodeUri': PropertyType(False, one_of(is_str(), is_type(dict))), - 'InlineCode': PropertyType(False, one_of(is_str(), is_type(dict))), - 'DeadLetterQueue': PropertyType(False, is_type(dict)), - 'Description': PropertyType(False, is_str()), - 'MemorySize': PropertyType(False, is_type(int)), - 'Timeout': PropertyType(False, is_type(int)), - 'VpcConfig': PropertyType(False, is_type(dict)), - 'Role': PropertyType(False, is_str()), - 'AssumeRolePolicyDocument': PropertyType(False, is_type(dict)), - 'Policies': PropertyType(False, one_of(is_str(), list_of(one_of(is_str(), is_type(dict), is_type(dict))))), - 'PermissionsBoundary': PropertyType(False, is_str()), - 'Environment': PropertyType(False, dict_of(is_str(), is_type(dict))), - 'Events': PropertyType(False, dict_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, is_type(dict)), - 'Tracing': PropertyType(False, one_of(is_type(dict), is_str())), - 'KmsKeyArn': PropertyType(False, one_of(is_type(dict), is_str())), - 'DeploymentPreference': PropertyType(False, is_type(dict)), - 'ReservedConcurrentExecutions': PropertyType(False, any_type()), - 'Layers': PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), - 'EventInvokeConfig': PropertyType(False, is_type(dict)), - + "FunctionName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Handler": PropertyType(True, is_str()), + "Runtime": PropertyType(True, is_str()), + "CodeUri": PropertyType(False, one_of(is_str(), is_type(dict))), + "InlineCode": PropertyType(False, one_of(is_str(), is_type(dict))), + "DeadLetterQueue": PropertyType(False, is_type(dict)), + "Description": PropertyType(False, is_str()), + "MemorySize": PropertyType(False, is_type(int)), + "Timeout": PropertyType(False, is_type(int)), + "VpcConfig": PropertyType(False, is_type(dict)), + "Role": PropertyType(False, is_str()), + "AssumeRolePolicyDocument": PropertyType(False, is_type(dict)), + "Policies": PropertyType(False, one_of(is_str(), list_of(one_of(is_str(), is_type(dict), is_type(dict))))), + "PermissionsBoundary": PropertyType(False, is_str()), + "Environment": PropertyType(False, dict_of(is_str(), is_type(dict))), + "Events": PropertyType(False, dict_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, is_type(dict)), + "Tracing": PropertyType(False, 
one_of(is_type(dict), is_str())), + "KmsKeyArn": PropertyType(False, one_of(is_type(dict), is_str())), + "DeploymentPreference": PropertyType(False, is_type(dict)), + "ReservedConcurrentExecutions": PropertyType(False, any_type()), + "Layers": PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), + "EventInvokeConfig": PropertyType(False, is_type(dict)), # Intrinsic functions in value of Alias property are not supported, yet - 'AutoPublishAlias': PropertyType(False, one_of(is_str())), - 'VersionDescription': PropertyType(False, is_str()), - 'ProvisionedConcurrencyConfig': PropertyType(False, is_type(dict)), + "AutoPublishAlias": PropertyType(False, one_of(is_str())), + "VersionDescription": PropertyType(False, is_str()), + "ProvisionedConcurrencyConfig": PropertyType(False, is_type(dict)), } - event_resolver = ResourceTypeResolver(samtranslator.model.eventsources, samtranslator.model.eventsources.pull, - samtranslator.model.eventsources.push, - samtranslator.model.eventsources.cloudwatchlogs) + event_resolver = ResourceTypeResolver( + samtranslator.model.eventsources, + samtranslator.model.eventsources.pull, + samtranslator.model.eventsources.push, + samtranslator.model.eventsources.cloudwatchlogs, + ) # DeadLetterQueue - dead_letter_queue_policy_actions = {'SQS': 'sqs:SendMessage', 'SNS': 'sns:Publish'} + dead_letter_queue_policy_actions = {"SQS": "sqs:SendMessage", "SNS": "sns:Publish"} # # Conditions @@ -82,14 +93,12 @@ class SamFunction(SamResourceMacro): "Version": LambdaVersion.resource_type, # EventConfig auto created SQS and SNS "DestinationTopic": SNSTopic.resource_type, - "DestinationQueue": SQSQueue.resource_type + "DestinationQueue": SQSQueue.resource_type, } def resources_to_link(self, resources): try: - return { - 'event_resources': self._event_resources_to_link(resources) - } + return {"event_resources": self._event_resources_to_link(resources)} except InvalidEventException as e: raise InvalidResourceException(self.logical_id, e.message) @@ -114,8 +123,10 @@ def to_cloudformation(self, **kwargs): if self.ProvisionedConcurrencyConfig: if not self.AutoPublishAlias: - raise InvalidResourceException(self.logical_id, "To set ProvisionedConcurrencyConfig " - "AutoPublishALias must be defined on the function") + raise InvalidResourceException( + self.logical_id, + "To set ProvisionedConcurrencyConfig " "AutoPublishALias must be defined on the function", + ) lambda_alias = None alias_name = "" @@ -127,32 +138,34 @@ def to_cloudformation(self, **kwargs): resources.append(lambda_alias) if self.DeploymentPreference: - self._validate_deployment_preference_and_add_update_policy(kwargs.get('deployment_preference_collection', - None), - lambda_alias, intrinsics_resolver, - mappings_resolver) + self._validate_deployment_preference_and_add_update_policy( + kwargs.get("deployment_preference_collection", None), + lambda_alias, + intrinsics_resolver, + mappings_resolver, + ) event_invoke_policies = [] if self.EventInvokeConfig: function_name = lambda_function.logical_id - event_invoke_resources, event_invoke_policies = self._construct_event_invoke_config(function_name, - alias_name, - intrinsics_resolver, - conditions) + event_invoke_resources, event_invoke_policies = self._construct_event_invoke_config( + function_name, alias_name, intrinsics_resolver, conditions + ) resources.extend(event_invoke_resources) - managed_policy_map = kwargs.get('managed_policy_map', {}) + managed_policy_map = kwargs.get("managed_policy_map", {}) if not managed_policy_map: - raise Exception('Managed policy 
map is empty, but should not be.') + raise Exception("Managed policy map is empty, but should not be.") execution_role = None if lambda_function.Role is None: execution_role = self._construct_role(managed_policy_map, event_invoke_policies) - lambda_function.Role = execution_role.get_runtime_attr('arn') + lambda_function.Role = execution_role.get_runtime_attr("arn") resources.append(execution_role) try: - resources += self._generate_event_resources(lambda_function, execution_role, kwargs['event_resources'], - lambda_alias=lambda_alias) + resources += self._generate_event_resources( + lambda_function, execution_role, kwargs["event_resources"], lambda_alias=lambda_alias + ) except InvalidEventException as e: raise InvalidResourceException(self.logical_id, e.message) @@ -172,24 +185,24 @@ def _construct_event_invoke_config(self, function_name, lambda_alias, intrinsics lambda_event_invoke_config = LambdaEventInvokeConfig(logical_id=logical_id, attributes=self.resource_attributes) dest_config = {} - input_dest_config = resolved_event_invoke_config.get('DestinationConfig') - if input_dest_config and \ - input_dest_config.get('OnSuccess') is not None: - resource, on_success, policy = self._validate_and_inject_resource(input_dest_config.get('OnSuccess'), - "OnSuccess", logical_id, conditions) - dest_config['OnSuccess'] = on_success - self.EventInvokeConfig['DestinationConfig']['OnSuccess']['Destination'] = on_success.get('Destination') + input_dest_config = resolved_event_invoke_config.get("DestinationConfig") + if input_dest_config and input_dest_config.get("OnSuccess") is not None: + resource, on_success, policy = self._validate_and_inject_resource( + input_dest_config.get("OnSuccess"), "OnSuccess", logical_id, conditions + ) + dest_config["OnSuccess"] = on_success + self.EventInvokeConfig["DestinationConfig"]["OnSuccess"]["Destination"] = on_success.get("Destination") if resource is not None: resources.extend([resource]) if policy is not None: policy_document.append(policy) - if input_dest_config and \ - input_dest_config.get('OnFailure') is not None: - resource, on_failure, policy = self._validate_and_inject_resource(input_dest_config.get('OnFailure'), - "OnFailure", logical_id, conditions) - dest_config['OnFailure'] = on_failure - self.EventInvokeConfig['DestinationConfig']['OnFailure']['Destination'] = on_failure.get('Destination') + if input_dest_config and input_dest_config.get("OnFailure") is not None: + resource, on_failure, policy = self._validate_and_inject_resource( + input_dest_config.get("OnFailure"), "OnFailure", logical_id, conditions + ) + dest_config["OnFailure"] = on_failure + self.EventInvokeConfig["DestinationConfig"]["OnFailure"]["Destination"] = on_failure.get("Destination") if resource is not None: resources.extend([resource]) if policy is not None: @@ -199,11 +212,12 @@ def _construct_event_invoke_config(self, function_name, lambda_alias, intrinsics if lambda_alias: lambda_event_invoke_config.Qualifier = lambda_alias else: - lambda_event_invoke_config.Qualifier = '$LATEST' + lambda_event_invoke_config.Qualifier = "$LATEST" lambda_event_invoke_config.DestinationConfig = dest_config - lambda_event_invoke_config.MaximumEventAgeInSeconds = \ - resolved_event_invoke_config.get('MaximumEventAgeInSeconds') - lambda_event_invoke_config.MaximumRetryAttempts = resolved_event_invoke_config.get('MaximumRetryAttempts') + lambda_event_invoke_config.MaximumEventAgeInSeconds = resolved_event_invoke_config.get( + "MaximumEventAgeInSeconds" + ) + 
lambda_event_invoke_config.MaximumRetryAttempts = resolved_event_invoke_config.get("MaximumRetryAttempts") resources.extend([lambda_event_invoke_config]) return resources, policy_document @@ -215,47 +229,50 @@ def _validate_and_inject_resource(self, dest_config, event, logical_id, conditio ARN property, so to handle conditional ifs we have to inject if conditions in the auto created SQS/SNS resources as well as in the policy documents. """ - accepted_types_list = ['SQS', 'SNS', 'EventBridge', 'Lambda'] - auto_inject_list = ['SQS', 'SNS'] + accepted_types_list = ["SQS", "SNS", "EventBridge", "Lambda"] + auto_inject_list = ["SQS", "SNS"] resource = None policy = {} destination = {} - destination['Destination'] = dest_config.get('Destination') + destination["Destination"] = dest_config.get("Destination") resource_logical_id = logical_id + event - if dest_config.get('Type') is None or \ - dest_config.get('Type') not in accepted_types_list: - raise InvalidResourceException(self.logical_id, - "'Type: {}' must be one of {}" - .format(dest_config.get('Type'), accepted_types_list)) - - property_condition, dest_arn = self._get_or_make_condition(dest_config.get('Destination'), - logical_id, conditions) - if dest_config.get('Destination') is None or property_condition is not None: - combined_condition = self._make_and_conditions(self.get_passthrough_resource_attributes(), - property_condition, conditions) - if dest_config.get('Type') in auto_inject_list: - if dest_config.get('Type') == 'SQS': - resource = SQSQueue(resource_logical_id + 'Queue') - if dest_config.get('Type') == 'SNS': - resource = SNSTopic(resource_logical_id + 'Topic') + if dest_config.get("Type") is None or dest_config.get("Type") not in accepted_types_list: + raise InvalidResourceException( + self.logical_id, "'Type: {}' must be one of {}".format(dest_config.get("Type"), accepted_types_list) + ) + + property_condition, dest_arn = self._get_or_make_condition( + dest_config.get("Destination"), logical_id, conditions + ) + if dest_config.get("Destination") is None or property_condition is not None: + combined_condition = self._make_and_conditions( + self.get_passthrough_resource_attributes(), property_condition, conditions + ) + if dest_config.get("Type") in auto_inject_list: + if dest_config.get("Type") == "SQS": + resource = SQSQueue(resource_logical_id + "Queue") + if dest_config.get("Type") == "SNS": + resource = SNSTopic(resource_logical_id + "Topic") if combined_condition: - resource.set_resource_attribute('Condition', combined_condition) + resource.set_resource_attribute("Condition", combined_condition) if property_condition: - destination['Destination'] = make_conditional(property_condition, - resource.get_runtime_attr('arn'), - dest_arn) + destination["Destination"] = make_conditional( + property_condition, resource.get_runtime_attr("arn"), dest_arn + ) else: - destination['Destination'] = resource.get_runtime_attr('arn') - policy = self._add_event_invoke_managed_policy(dest_config, resource_logical_id, property_condition, - destination['Destination']) + destination["Destination"] = resource.get_runtime_attr("arn") + policy = self._add_event_invoke_managed_policy( + dest_config, resource_logical_id, property_condition, destination["Destination"] + ) else: - raise InvalidResourceException(self.logical_id, - "Destination is required if Type is not {}" - .format(auto_inject_list)) - if dest_config.get('Destination') is not None and property_condition is None: - policy = self._add_event_invoke_managed_policy(dest_config, 
resource_logical_id, - None, dest_config.get('Destination')) + raise InvalidResourceException( + self.logical_id, "Destination is required if Type is not {}".format(auto_inject_list) + ) + if dest_config.get("Destination") is not None and property_condition is None: + policy = self._add_event_invoke_managed_policy( + dest_config, resource_logical_id, None, dest_config.get("Destination") + ) return resource, destination, policy @@ -264,11 +281,12 @@ def _make_and_conditions(self, resource_condition, property_condition, condition return property_condition if property_condition is None: - return resource_condition['Condition'] + return resource_condition["Condition"] - and_condition = make_and_condition([resource_condition, {'Condition': property_condition}]) - condition_name = self._make_gen_condition_name(resource_condition.get('Condition') + 'AND' + property_condition, - self.logical_id) + and_condition = make_and_condition([resource_condition, {"Condition": property_condition}]) + condition_name = self._make_gen_condition_name( + resource_condition.get("Condition") + "AND" + property_condition, self.logical_id + ) conditions[condition_name] = and_condition return condition_name @@ -290,14 +308,14 @@ def _get_or_make_condition(self, destination, logical_id, conditions): if destination is None: return None, None if is_intrinsic_if(destination): - dest_list = destination.get('Fn::If') + dest_list = destination.get("Fn::If") if is_intrinsic_no_value(dest_list[1]) and is_intrinsic_no_value(dest_list[2]): return None, None if is_intrinsic_no_value(dest_list[1]): return dest_list[0], dest_list[2] if is_intrinsic_no_value(dest_list[2]): condition = dest_list[0] - not_condition = self._make_gen_condition_name('NOT' + condition, logical_id) + not_condition = self._make_gen_condition_name("NOT" + condition, logical_id) conditions[not_condition] = make_not_conditional(condition) return not_condition, dest_list[1] return None, None @@ -328,9 +346,9 @@ def _get_resolved_alias_name(self, property_name, original_alias_value, intrinsi if not isinstance(resolved_alias_name, string_types): # This is still a dictionary which means we are not able to completely resolve intrinsics - raise InvalidResourceException(self.logical_id, - "'{}' must be a string or a Ref to a template parameter" - .format(property_name)) + raise InvalidResourceException( + self.logical_id, "'{}' must be a string or a Ref to a template parameter".format(property_name) + ) return resolved_alias_name @@ -340,8 +358,9 @@ def _construct_lambda_function(self): :returns: a list containing the Lambda function and execution role resources :rtype: list """ - lambda_function = LambdaFunction(self.logical_id, depends_on=self.depends_on, - attributes=self.resource_attributes) + lambda_function = LambdaFunction( + self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes + ) if self.FunctionName: lambda_function.FunctionName = self.FunctionName @@ -364,25 +383,22 @@ def _construct_lambda_function(self): lambda_function.TracingConfig = {"Mode": self.Tracing} if self.DeadLetterQueue: - lambda_function.DeadLetterConfig = {"TargetArn": self.DeadLetterQueue['TargetArn']} + lambda_function.DeadLetterConfig = {"TargetArn": self.DeadLetterQueue["TargetArn"]} return lambda_function def _add_event_invoke_managed_policy(self, dest_config, logical_id, condition, dest_arn): policy = {} - if dest_config and dest_config.get('Type'): - if dest_config.get('Type') == 'SQS': - policy = IAMRolePolicies.sqs_send_message_role_policy(dest_arn, - 
logical_id) - if dest_config.get('Type') == 'SNS': - policy = IAMRolePolicies.sns_publish_role_policy(dest_arn, - logical_id) + if dest_config and dest_config.get("Type"): + if dest_config.get("Type") == "SQS": + policy = IAMRolePolicies.sqs_send_message_role_policy(dest_arn, logical_id) + if dest_config.get("Type") == "SNS": + policy = IAMRolePolicies.sns_publish_role_policy(dest_arn, logical_id) # Event Bridge and Lambda Arns are passthrough. - if dest_config.get('Type') == 'EventBridge': + if dest_config.get("Type") == "EventBridge": policy = IAMRolePolicies.event_bus_put_events_role_policy(dest_arn, logical_id) - if dest_config.get('Type') == 'Lambda': - policy = IAMRolePolicies.lambda_invoke_function_role_policy(dest_arn, - logical_id) + if dest_config.get("Type") == "Lambda": + policy = IAMRolePolicies.lambda_invoke_function_role_policy(dest_arn, logical_id) return policy def _construct_role(self, managed_policy_map, event_invoke_policies): @@ -391,30 +407,35 @@ def _construct_role(self, managed_policy_map, event_invoke_policies): :returns: the generated IAM Role :rtype: model.iam.IAMRole """ - execution_role = IAMRole(self.logical_id + 'Role', attributes=self.get_passthrough_resource_attributes()) + execution_role = IAMRole(self.logical_id + "Role", attributes=self.get_passthrough_resource_attributes()) if self.AssumeRolePolicyDocument is not None: execution_role.AssumeRolePolicyDocument = self.AssumeRolePolicyDocument else: execution_role.AssumeRolePolicyDocument = IAMRolePolicies.lambda_assume_role_policy() - managed_policy_arns = [ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSLambdaBasicExecutionRole')] + managed_policy_arns = [ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSLambdaBasicExecutionRole")] if self.Tracing: - managed_policy_arns.append(ArnGenerator.generate_aws_managed_policy_arn('AWSXrayWriteOnlyAccess')) + managed_policy_arns.append(ArnGenerator.generate_aws_managed_policy_arn("AWSXrayWriteOnlyAccess")) if self.VpcConfig: managed_policy_arns.append( - ArnGenerator.generate_aws_managed_policy_arn('service-role/AWSLambdaVPCAccessExecutionRole') + ArnGenerator.generate_aws_managed_policy_arn("service-role/AWSLambdaVPCAccessExecutionRole") ) - function_policies = FunctionPolicies({"Policies": self.Policies}, - # No support for policy templates in the "core" - policy_template_processor=None) + function_policies = FunctionPolicies( + {"Policies": self.Policies}, + # No support for policy templates in the "core" + policy_template_processor=None, + ) policy_documents = [] if self.DeadLetterQueue: - policy_documents.append(IAMRolePolicies.dead_letter_queue_policy( - self.dead_letter_queue_policy_actions[self.DeadLetterQueue['Type']], - self.DeadLetterQueue['TargetArn'])) + policy_documents.append( + IAMRolePolicies.dead_letter_queue_policy( + self.dead_letter_queue_policy_actions[self.DeadLetterQueue["Type"]], + self.DeadLetterQueue["TargetArn"], + ) + ) if self.EventInvokeConfig: if event_invoke_policies is not None: @@ -431,25 +452,27 @@ def _construct_role(self, managed_policy_map, event_invoke_policies): if not is_intrinsic_no_value(then_statement): then_statement = { - 'PolicyName': execution_role.logical_id + 'Policy' + str(index), - 'PolicyDocument': then_statement + "PolicyName": execution_role.logical_id + "Policy" + str(index), + "PolicyDocument": then_statement, } intrinsic_if["Fn::If"][1] = then_statement if not is_intrinsic_no_value(else_statement): else_statement = { - 'PolicyName': execution_role.logical_id + 'Policy' + 
str(index), - 'PolicyDocument': else_statement + "PolicyName": execution_role.logical_id + "Policy" + str(index), + "PolicyDocument": else_statement, } intrinsic_if["Fn::If"][2] = else_statement policy_documents.append(intrinsic_if) else: - policy_documents.append({ - 'PolicyName': execution_role.logical_id + 'Policy' + str(index), - 'PolicyDocument': policy_entry.data - }) + policy_documents.append( + { + "PolicyName": execution_role.logical_id + "Policy" + str(index), + "PolicyDocument": policy_entry.data, + } + ) elif policy_entry.type is PolicyTypes.MANAGED_POLICY: @@ -473,8 +496,8 @@ def _construct_role(self, managed_policy_map, event_invoke_policies): else: # Policy Templates are not supported here in the "core" raise InvalidResourceException( - self.logical_id, - "Policy at index {} in the 'Policies' property is not valid".format(index)) + self.logical_id, "Policy at index {} in the 'Policies' property is not valid".format(index) + ) execution_role.ManagedPolicyArns = list(managed_policy_arns) execution_role.Policies = policy_documents or None @@ -489,15 +512,17 @@ def _validate_dlq(self): """ # Validate required logical ids valid_dlq_types = str(list(self.dead_letter_queue_policy_actions.keys())) - if not self.DeadLetterQueue.get('Type') or not self.DeadLetterQueue.get('TargetArn'): - raise InvalidResourceException(self.logical_id, - "'DeadLetterQueue' requires Type and TargetArn properties to be specified" - .format(valid_dlq_types)) + if not self.DeadLetterQueue.get("Type") or not self.DeadLetterQueue.get("TargetArn"): + raise InvalidResourceException( + self.logical_id, + "'DeadLetterQueue' requires Type and TargetArn properties to be specified".format(valid_dlq_types), + ) # Validate required Types - if not self.DeadLetterQueue['Type'] in self.dead_letter_queue_policy_actions: - raise InvalidResourceException(self.logical_id, - "'DeadLetterQueue' requires Type of {}".format(valid_dlq_types)) + if not self.DeadLetterQueue["Type"] in self.dead_letter_queue_policy_actions: + raise InvalidResourceException( + self.logical_id, "'DeadLetterQueue' requires Type of {}".format(valid_dlq_types) + ) def _event_resources_to_link(self, resources): event_resources = {} @@ -505,7 +530,8 @@ def _event_resources_to_link(self, resources): for logical_id, event_dict in self.Events.items(): try: event_source = self.event_resolver.resolve_resource_type(event_dict).from_dict( - self.logical_id + logical_id, event_dict, logical_id) + self.logical_id + logical_id, event_dict, logical_id + ) except (TypeError, AttributeError) as e: raise InvalidEventException(logical_id, "{}".format(e)) event_resources[logical_id] = event_source.resources_to_link(resources) @@ -545,14 +571,15 @@ def _generate_event_resources(self, lambda_function, execution_role, event_resou for logical_id, event_dict in sorted(self.Events.items(), key=SamFunction.order_events): try: eventsource = self.event_resolver.resolve_resource_type(event_dict).from_dict( - lambda_function.logical_id + logical_id, event_dict, logical_id) + lambda_function.logical_id + logical_id, event_dict, logical_id + ) except TypeError as e: raise InvalidEventException(logical_id, "{}".format(e)) kwargs = { # When Alias is provided, connect all event sources to the alias and *not* the function - 'function': lambda_alias or lambda_function, - 'role': execution_role, + "function": lambda_alias or lambda_function, + "role": execution_role, } for name, resource in event_resources[logical_id].items(): @@ -563,11 +590,9 @@ def _generate_event_resources(self, 
lambda_function, execution_role, event_resou def _construct_code_dict(self): if self.InlineCode: - return { - "ZipFile": self.InlineCode - } + return {"ZipFile": self.InlineCode} elif self.CodeUri: - return construct_s3_location_object(self.CodeUri, self.logical_id, 'CodeUri') + return construct_s3_location_object(self.CodeUri, self.logical_id, "CodeUri") else: raise InvalidResourceException(self.logical_id, "Either 'InlineCode' or 'CodeUri' must be set") @@ -618,7 +643,7 @@ def _construct_version(self, function, intrinsics_resolver): attributes["DeletionPolicy"] = "Retain" lambda_version = LambdaVersion(logical_id=logical_id, attributes=attributes) - lambda_version.FunctionName = function.get_runtime_attr('name') + lambda_version.FunctionName = function.get_runtime_attr("name") lambda_version.Description = self.VersionDescription return lambda_version @@ -639,86 +664,87 @@ def _construct_alias(self, name, function, version): logical_id = "{id}Alias{suffix}".format(id=function.logical_id, suffix=name) alias = LambdaAlias(logical_id=logical_id, attributes=self.get_passthrough_resource_attributes()) alias.Name = name - alias.FunctionName = function.get_runtime_attr('name') + alias.FunctionName = function.get_runtime_attr("name") alias.FunctionVersion = version.get_runtime_attr("version") if self.ProvisionedConcurrencyConfig: alias.ProvisionedConcurrencyConfig = self.ProvisionedConcurrencyConfig return alias - def _validate_deployment_preference_and_add_update_policy(self, deployment_preference_collection, lambda_alias, - intrinsics_resolver, mappings_resolver): - if 'Enabled' in self.DeploymentPreference: + def _validate_deployment_preference_and_add_update_policy( + self, deployment_preference_collection, lambda_alias, intrinsics_resolver, mappings_resolver + ): + if "Enabled" in self.DeploymentPreference: # resolve intrinsics and mappings for Type - enabled = self.DeploymentPreference['Enabled'] + enabled = self.DeploymentPreference["Enabled"] enabled = intrinsics_resolver.resolve_parameter_refs(enabled) enabled = mappings_resolver.resolve_parameter_refs(enabled) - self.DeploymentPreference['Enabled'] = enabled + self.DeploymentPreference["Enabled"] = enabled - if 'Type' in self.DeploymentPreference: + if "Type" in self.DeploymentPreference: # resolve intrinsics and mappings for Type - preference_type = self.DeploymentPreference['Type'] + preference_type = self.DeploymentPreference["Type"] preference_type = intrinsics_resolver.resolve_parameter_refs(preference_type) preference_type = mappings_resolver.resolve_parameter_refs(preference_type) - self.DeploymentPreference['Type'] = preference_type + self.DeploymentPreference["Type"] = preference_type if deployment_preference_collection is None: - raise ValueError('deployment_preference_collection required for parsing the deployment preference') + raise ValueError("deployment_preference_collection required for parsing the deployment preference") deployment_preference_collection.add(self.logical_id, self.DeploymentPreference) if deployment_preference_collection.get(self.logical_id).enabled: if self.AutoPublishAlias is None: raise InvalidResourceException( - self.logical_id, - "'DeploymentPreference' requires AutoPublishAlias property to be specified") + self.logical_id, "'DeploymentPreference' requires AutoPublishAlias property to be specified" + ) if lambda_alias is None: - raise ValueError('lambda_alias expected for updating it with the appropriate update policy') + raise ValueError("lambda_alias expected for updating it with the appropriate 
update policy") - lambda_alias.set_resource_attribute("UpdatePolicy", - deployment_preference_collection.update_policy( - self.logical_id).to_dict()) + lambda_alias.set_resource_attribute( + "UpdatePolicy", deployment_preference_collection.update_policy(self.logical_id).to_dict() + ) class SamApi(SamResourceMacro): """SAM rest API macro. """ - resource_type = 'AWS::Serverless::Api' + + resource_type = "AWS::Serverless::Api" property_types = { # Internal property set only by Implicit API plugin. If set to True, the API Event Source code will inject # Lambda Integration URI to the Swagger. To preserve backwards compatibility, this must be set only for # Implicit APIs. For Explicit APIs, customer is expected to set integration URI themselves. # In the future, we might rename and expose this property to customers so they can have SAM manage Explicit APIs # Swagger. - '__MANAGE_SWAGGER': PropertyType(False, is_type(bool)), - - 'Name': PropertyType(False, one_of(is_str(), is_type(dict))), - 'StageName': PropertyType(True, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, is_type(dict)), - 'DefinitionBody': PropertyType(False, is_type(dict)), - 'DefinitionUri': PropertyType(False, one_of(is_str(), is_type(dict))), - 'CacheClusterEnabled': PropertyType(False, is_type(bool)), - 'CacheClusterSize': PropertyType(False, is_str()), - 'Variables': PropertyType(False, is_type(dict)), - 'EndpointConfiguration': PropertyType(False, is_str()), - 'MethodSettings': PropertyType(False, is_type(list)), - 'BinaryMediaTypes': PropertyType(False, is_type(list)), - 'MinimumCompressionSize': PropertyType(False, is_type(int)), - 'Cors': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Auth': PropertyType(False, is_type(dict)), - 'GatewayResponses': PropertyType(False, is_type(dict)), - 'AccessLogSetting': PropertyType(False, is_type(dict)), - 'CanarySetting': PropertyType(False, is_type(dict)), - 'TracingEnabled': PropertyType(False, is_type(bool)), - 'OpenApiVersion': PropertyType(False, is_str()), - 'Models': PropertyType(False, is_type(dict)), - 'Domain': PropertyType(False, is_type(dict)) + "__MANAGE_SWAGGER": PropertyType(False, is_type(bool)), + "Name": PropertyType(False, one_of(is_str(), is_type(dict))), + "StageName": PropertyType(True, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, is_type(dict)), + "DefinitionBody": PropertyType(False, is_type(dict)), + "DefinitionUri": PropertyType(False, one_of(is_str(), is_type(dict))), + "CacheClusterEnabled": PropertyType(False, is_type(bool)), + "CacheClusterSize": PropertyType(False, is_str()), + "Variables": PropertyType(False, is_type(dict)), + "EndpointConfiguration": PropertyType(False, is_str()), + "MethodSettings": PropertyType(False, is_type(list)), + "BinaryMediaTypes": PropertyType(False, is_type(list)), + "MinimumCompressionSize": PropertyType(False, is_type(int)), + "Cors": PropertyType(False, one_of(is_str(), is_type(dict))), + "Auth": PropertyType(False, is_type(dict)), + "GatewayResponses": PropertyType(False, is_type(dict)), + "AccessLogSetting": PropertyType(False, is_type(dict)), + "CanarySetting": PropertyType(False, is_type(dict)), + "TracingEnabled": PropertyType(False, is_type(bool)), + "OpenApiVersion": PropertyType(False, is_str()), + "Models": PropertyType(False, is_type(dict)), + "Domain": PropertyType(False, is_type(dict)), } referable_properties = { "Stage": ApiGatewayStage.resource_type, "Deployment": ApiGatewayDeployment.resource_type, - "DomainName": ApiGatewayDomainName.resource_type + 
"DomainName": ApiGatewayDomainName.resource_type, } def to_cloudformation(self, **kwargs): @@ -735,31 +761,33 @@ def to_cloudformation(self, **kwargs): self.BinaryMediaTypes = intrinsics_resolver.resolve_parameter_refs(self.BinaryMediaTypes) self.Domain = intrinsics_resolver.resolve_parameter_refs(self.Domain) - api_generator = ApiGenerator(self.logical_id, - self.CacheClusterEnabled, - self.CacheClusterSize, - self.Variables, - self.depends_on, - self.DefinitionBody, - self.DefinitionUri, - self.Name, - self.StageName, - tags=self.Tags, - endpoint_configuration=self.EndpointConfiguration, - method_settings=self.MethodSettings, - binary_media=self.BinaryMediaTypes, - minimum_compression_size=self.MinimumCompressionSize, - cors=self.Cors, - auth=self.Auth, - gateway_responses=self.GatewayResponses, - access_log_setting=self.AccessLogSetting, - canary_setting=self.CanarySetting, - tracing_enabled=self.TracingEnabled, - resource_attributes=self.resource_attributes, - passthrough_resource_attributes=self.get_passthrough_resource_attributes(), - open_api_version=self.OpenApiVersion, - models=self.Models, - domain=self.Domain) + api_generator = ApiGenerator( + self.logical_id, + self.CacheClusterEnabled, + self.CacheClusterSize, + self.Variables, + self.depends_on, + self.DefinitionBody, + self.DefinitionUri, + self.Name, + self.StageName, + tags=self.Tags, + endpoint_configuration=self.EndpointConfiguration, + method_settings=self.MethodSettings, + binary_media=self.BinaryMediaTypes, + minimum_compression_size=self.MinimumCompressionSize, + cors=self.Cors, + auth=self.Auth, + gateway_responses=self.GatewayResponses, + access_log_setting=self.AccessLogSetting, + canary_setting=self.CanarySetting, + tracing_enabled=self.TracingEnabled, + resource_attributes=self.resource_attributes, + passthrough_resource_attributes=self.get_passthrough_resource_attributes(), + open_api_version=self.OpenApiVersion, + models=self.Models, + domain=self.Domain, + ) rest_api, deployment, stage, permissions, domain, basepath_mapping, route53 = api_generator.to_cloudformation() @@ -777,28 +805,26 @@ def to_cloudformation(self, **kwargs): class SamHttpApi(SamResourceMacro): """SAM rest API macro. """ - resource_type = 'AWS::Serverless::HttpApi' + + resource_type = "AWS::Serverless::HttpApi" property_types = { # Internal property set only by Implicit HTTP API plugin. If set to True, the API Event Source code will # inject Lambda Integration URI to the OpenAPI. To preserve backwards compatibility, this must be set only for # Implicit APIs. For Explicit APIs, this is managed by the DefaultDefinitionBody Plugin. # In the future, we might rename and expose this property to customers so they can have SAM manage Explicit APIs # Swagger. 
- '__MANAGE_SWAGGER': PropertyType(False, is_type(bool)), - - 'StageName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, is_type(dict)), - 'DefinitionBody': PropertyType(False, is_type(dict)), - 'DefinitionUri': PropertyType(False, one_of(is_str(), is_type(dict))), - 'StageVariables': PropertyType(False, is_type(dict)), - 'Cors': PropertyType(False, one_of(is_str(), is_type(dict))), - 'AccessLogSettings': PropertyType(False, is_type(dict)), - 'Auth': PropertyType(False, is_type(dict)) + "__MANAGE_SWAGGER": PropertyType(False, is_type(bool)), + "StageName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, is_type(dict)), + "DefinitionBody": PropertyType(False, is_type(dict)), + "DefinitionUri": PropertyType(False, one_of(is_str(), is_type(dict))), + "StageVariables": PropertyType(False, is_type(dict)), + "Cors": PropertyType(False, one_of(is_str(), is_type(dict))), + "AccessLogSettings": PropertyType(False, is_type(dict)), + "Auth": PropertyType(False, is_type(dict)), } - referable_properties = { - "Stage": ApiGatewayV2Stage.resource_type, - } + referable_properties = {"Stage": ApiGatewayV2Stage.resource_type} def to_cloudformation(self, **kwargs): """Returns the API Gateway RestApi, Deployment, and Stage to which this SAM Api corresponds. @@ -810,17 +836,19 @@ def to_cloudformation(self, **kwargs): """ resources = [] - api_generator = HttpApiGenerator(self.logical_id, - self.StageVariables, - self.depends_on, - self.DefinitionBody, - self.DefinitionUri, - self.StageName, - tags=self.Tags, - auth=self.Auth, - access_log_settings=self.AccessLogSettings, - resource_attributes=self.resource_attributes, - passthrough_resource_attributes=self.get_passthrough_resource_attributes()) + api_generator = HttpApiGenerator( + self.logical_id, + self.StageVariables, + self.depends_on, + self.DefinitionBody, + self.DefinitionUri, + self.StageName, + tags=self.Tags, + auth=self.Auth, + access_log_settings=self.AccessLogSettings, + resource_attributes=self.resource_attributes, + passthrough_resource_attributes=self.get_passthrough_resource_attributes(), + ) http_api, stage = api_generator.to_cloudformation() @@ -836,19 +864,16 @@ def to_cloudformation(self, **kwargs): class SamSimpleTable(SamResourceMacro): """SAM simple table macro. 
""" - resource_type = 'AWS::Serverless::SimpleTable' + + resource_type = "AWS::Serverless::SimpleTable" property_types = { - 'PrimaryKey': PropertyType(False, dict_of(is_str(), is_str())), - 'ProvisionedThroughput': PropertyType(False, dict_of(is_str(), one_of(is_type(int), is_type(dict)))), - 'TableName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Tags': PropertyType(False, is_type(dict)), - 'SSESpecification': PropertyType(False, is_type(dict)) - } - attribute_type_conversions = { - 'String': 'S', - 'Number': 'N', - 'Binary': 'B' + "PrimaryKey": PropertyType(False, dict_of(is_str(), is_str())), + "ProvisionedThroughput": PropertyType(False, dict_of(is_str(), one_of(is_type(int), is_type(dict)))), + "TableName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Tags": PropertyType(False, is_type(dict)), + "SSESpecification": PropertyType(False, is_type(dict)), } + attribute_type_conversions = {"String": "S", "Number": "N", "Binary": "B"} def to_cloudformation(self, **kwargs): dynamodb_resources = self._construct_dynamodb_table() @@ -859,29 +884,25 @@ def _construct_dynamodb_table(self): dynamodb_table = DynamoDBTable(self.logical_id, depends_on=self.depends_on, attributes=self.resource_attributes) if self.PrimaryKey: - if 'Name' not in self.PrimaryKey or 'Type' not in self.PrimaryKey: + if "Name" not in self.PrimaryKey or "Type" not in self.PrimaryKey: raise InvalidResourceException( - self.logical_id, - '\'PrimaryKey\' is missing required Property \'Name\' or \'Type\'.' + self.logical_id, "'PrimaryKey' is missing required Property 'Name' or 'Type'." ) primary_key = { - 'AttributeName': self.PrimaryKey['Name'], - 'AttributeType': self._convert_attribute_type(self.PrimaryKey['Type']) + "AttributeName": self.PrimaryKey["Name"], + "AttributeType": self._convert_attribute_type(self.PrimaryKey["Type"]), } else: - primary_key = {'AttributeName': 'id', 'AttributeType': 'S'} + primary_key = {"AttributeName": "id", "AttributeType": "S"} dynamodb_table.AttributeDefinitions = [primary_key] - dynamodb_table.KeySchema = [{ - 'AttributeName': primary_key['AttributeName'], - 'KeyType': 'HASH' - }] + dynamodb_table.KeySchema = [{"AttributeName": primary_key["AttributeName"], "KeyType": "HASH"}] if self.ProvisionedThroughput: dynamodb_table.ProvisionedThroughput = self.ProvisionedThroughput else: - dynamodb_table.BillingMode = 'PAY_PER_REQUEST' + dynamodb_table.BillingMode = "PAY_PER_REQUEST" if self.SSESpecification: dynamodb_table.SSESpecification = self.SSESpecification @@ -897,26 +918,26 @@ def _construct_dynamodb_table(self): def _convert_attribute_type(self, attribute_type): if attribute_type in self.attribute_type_conversions: return self.attribute_type_conversions[attribute_type] - raise InvalidResourceException(self.logical_id, 'Invalid \'Type\' "{actual}".'.format(actual=attribute_type)) + raise InvalidResourceException(self.logical_id, "Invalid 'Type' \"{actual}\".".format(actual=attribute_type)) class SamApplication(SamResourceMacro): """SAM application macro. 
""" - APPLICATION_ID_KEY = 'ApplicationId' - SEMANTIC_VERSION_KEY = 'SemanticVersion' + APPLICATION_ID_KEY = "ApplicationId" + SEMANTIC_VERSION_KEY = "SemanticVersion" - resource_type = 'AWS::Serverless::Application' + resource_type = "AWS::Serverless::Application" # The plugin will always insert the TemplateUrl parameter property_types = { - 'Location': PropertyType(True, one_of(is_str(), is_type(dict))), - 'TemplateUrl': PropertyType(False, is_str()), - 'Parameters': PropertyType(False, is_type(dict)), - 'NotificationARNs': PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), - 'Tags': PropertyType(False, is_type(dict)), - 'TimeoutInMinutes': PropertyType(False, is_type(int)) + "Location": PropertyType(True, one_of(is_str(), is_type(dict))), + "TemplateUrl": PropertyType(False, is_str()), + "Parameters": PropertyType(False, is_type(dict)), + "NotificationARNs": PropertyType(False, list_of(one_of(is_str(), is_type(dict)))), + "Tags": PropertyType(False, is_type(dict)), + "TimeoutInMinutes": PropertyType(False, is_type(int)), } def to_cloudformation(self, **kwargs): @@ -928,8 +949,9 @@ def to_cloudformation(self, **kwargs): def _construct_nested_stack(self): """Constructs a AWS::CloudFormation::Stack resource """ - nested_stack = NestedStack(self.logical_id, depends_on=self.depends_on, - attributes=self.get_passthrough_resource_attributes()) + nested_stack = NestedStack( + self.logical_id, depends_on=self.depends_on, attributes=self.get_passthrough_resource_attributes() + ) nested_stack.Parameters = self.Parameters nested_stack.NotificationARNs = self.NotificationARNs application_tags = self._get_application_tags() @@ -944,11 +966,12 @@ def _get_application_tags(self): """ application_tags = {} if isinstance(self.Location, dict): - if (self.APPLICATION_ID_KEY in self.Location.keys() and - self.Location[self.APPLICATION_ID_KEY] is not None): + if self.APPLICATION_ID_KEY in self.Location.keys() and self.Location[self.APPLICATION_ID_KEY] is not None: application_tags[self._SAR_APP_KEY] = self.Location[self.APPLICATION_ID_KEY] - if (self.SEMANTIC_VERSION_KEY in self.Location.keys() and - self.Location[self.SEMANTIC_VERSION_KEY] is not None): + if ( + self.SEMANTIC_VERSION_KEY in self.Location.keys() + and self.Location[self.SEMANTIC_VERSION_KEY] is not None + ): application_tags[self._SAR_SEMVER_KEY] = self.Location[self.SEMANTIC_VERSION_KEY] return application_tags @@ -956,18 +979,19 @@ def _get_application_tags(self): class SamLayerVersion(SamResourceMacro): """ SAM Layer macro """ - resource_type = 'AWS::Serverless::LayerVersion' + + resource_type = "AWS::Serverless::LayerVersion" property_types = { - 'LayerName': PropertyType(False, one_of(is_str(), is_type(dict))), - 'Description': PropertyType(False, is_str()), - 'ContentUri': PropertyType(True, one_of(is_str(), is_type(dict))), - 'CompatibleRuntimes': PropertyType(False, list_of(is_str())), - 'LicenseInfo': PropertyType(False, is_str()), - 'RetentionPolicy': PropertyType(False, is_str()) + "LayerName": PropertyType(False, one_of(is_str(), is_type(dict))), + "Description": PropertyType(False, is_str()), + "ContentUri": PropertyType(True, one_of(is_str(), is_type(dict))), + "CompatibleRuntimes": PropertyType(False, list_of(is_str())), + "LicenseInfo": PropertyType(False, is_str()), + "RetentionPolicy": PropertyType(False, is_str()), } - RETAIN = 'Retain' - DELETE = 'Delete' + RETAIN = "Retain" + DELETE = "Delete" retention_policy_options = [RETAIN.lower(), DELETE.lower()] def to_cloudformation(self, **kwargs): @@ -993,18 
+1017,19 @@ def _construct_lambda_layer(self, intrinsics_resolver): :rtype: list """ # Resolve intrinsics if applicable: - self.LayerName = self._resolve_string_parameter(intrinsics_resolver, self.LayerName, 'LayerName') - self.LicenseInfo = self._resolve_string_parameter(intrinsics_resolver, self.LicenseInfo, 'LicenseInfo') - self.Description = self._resolve_string_parameter(intrinsics_resolver, self.Description, 'Description') - self.RetentionPolicy = self._resolve_string_parameter(intrinsics_resolver, self.RetentionPolicy, - 'RetentionPolicy') + self.LayerName = self._resolve_string_parameter(intrinsics_resolver, self.LayerName, "LayerName") + self.LicenseInfo = self._resolve_string_parameter(intrinsics_resolver, self.LicenseInfo, "LicenseInfo") + self.Description = self._resolve_string_parameter(intrinsics_resolver, self.Description, "Description") + self.RetentionPolicy = self._resolve_string_parameter( + intrinsics_resolver, self.RetentionPolicy, "RetentionPolicy" + ) retention_policy_value = self._get_retention_policy_value() attributes = self.get_passthrough_resource_attributes() if attributes is None: attributes = {} - attributes['DeletionPolicy'] = retention_policy_value + attributes["DeletionPolicy"] = retention_policy_value old_logical_id = self.logical_id new_logical_id = logical_id_generator.LogicalIdGenerator(old_logical_id, self.to_dict()).gen() @@ -1026,7 +1051,7 @@ def _construct_lambda_layer(self, intrinsics_resolver): lambda_layer.LayerName = self.LayerName lambda_layer.Description = self.Description - lambda_layer.Content = construct_s3_location_object(self.ContentUri, self.logical_id, 'ContentUri') + lambda_layer.Content = construct_s3_location_object(self.ContentUri, self.logical_id, "ContentUri") lambda_layer.CompatibleRuntimes = self.CompatibleRuntimes lambda_layer.LicenseInfo = self.LicenseInfo @@ -1044,6 +1069,7 @@ def _get_retention_policy_value(self): elif self.RetentionPolicy.lower() == self.DELETE.lower(): return self.DELETE elif self.RetentionPolicy.lower() not in self.retention_policy_options: - raise InvalidResourceException(self.logical_id, - "'{}' must be one of the following options: {}." 
- .format('RetentionPolicy', [self.RETAIN, self.DELETE])) + raise InvalidResourceException( + self.logical_id, + "'{}' must be one of the following options: {}.".format("RetentionPolicy", [self.RETAIN, self.DELETE]), + ) diff --git a/samtranslator/model/sns.py b/samtranslator/model/sns.py index a0d0d95ef4..312f57296a 100644 --- a/samtranslator/model/sns.py +++ b/samtranslator/model/sns.py @@ -4,21 +4,17 @@ class SNSSubscription(Resource): - resource_type = 'AWS::SNS::Subscription' + resource_type = "AWS::SNS::Subscription" property_types = { - 'Endpoint': PropertyType(True, is_str()), - 'Protocol': PropertyType(True, is_str()), - 'TopicArn': PropertyType(True, is_str()), - 'Region': PropertyType(False, is_str()), - 'FilterPolicy': PropertyType(False, is_type(dict)) + "Endpoint": PropertyType(True, is_str()), + "Protocol": PropertyType(True, is_str()), + "TopicArn": PropertyType(True, is_str()), + "Region": PropertyType(False, is_str()), + "FilterPolicy": PropertyType(False, is_type(dict)), } class SNSTopic(Resource): - resource_type = 'AWS::SNS::Topic' - property_types = { - 'TopicName': PropertyType(False, is_str()) - } - runtime_attrs = { - "arn": lambda self: ref(self.logical_id) - } + resource_type = "AWS::SNS::Topic" + property_types = {"TopicName": PropertyType(False, is_str())} + runtime_attrs = {"arn": lambda self: ref(self.logical_id)} diff --git a/samtranslator/model/sqs.py b/samtranslator/model/sqs.py index a262d6c2ab..b2691d309f 100644 --- a/samtranslator/model/sqs.py +++ b/samtranslator/model/sqs.py @@ -4,9 +4,8 @@ class SQSQueue(Resource): - resource_type = 'AWS::SQS::Queue' - property_types = { - } + resource_type = "AWS::SQS::Queue" + property_types = {} runtime_attrs = { "queue_url": lambda self: ref(self.logical_id), "arn": lambda self: fnGetAtt(self.logical_id, "Arn"), @@ -14,31 +13,24 @@ class SQSQueue(Resource): class SQSQueuePolicy(Resource): - resource_type = 'AWS::SQS::QueuePolicy' - property_types = { - 'PolicyDocument': PropertyType(True, is_type(dict)), - 'Queues': PropertyType(True, list_of(str)), - } - runtime_attrs = { - "arn": lambda self: fnGetAtt(self.logical_id, "Arn") - } + resource_type = "AWS::SQS::QueuePolicy" + property_types = {"PolicyDocument": PropertyType(True, is_type(dict)), "Queues": PropertyType(True, list_of(str))} + runtime_attrs = {"arn": lambda self: fnGetAtt(self.logical_id, "Arn")} class SQSQueuePolicies: @classmethod def sns_topic_send_message_role_policy(cls, topic_arn, queue_arn): document = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': 'sqs:SendMessage', - 'Effect': 'Allow', - 'Principal': '*', - 'Resource': queue_arn, - 'Condition': { - 'ArnEquals': { - 'aws:SourceArn': topic_arn - } + "Version": "2012-10-17", + "Statement": [ + { + "Action": "sqs:SendMessage", + "Effect": "Allow", + "Principal": "*", + "Resource": queue_arn, + "Condition": {"ArnEquals": {"aws:SourceArn": topic_arn}}, } - }] + ], } return document diff --git a/samtranslator/model/tags/resource_tagging.py b/samtranslator/model/tags/resource_tagging.py index 16db78941a..79ad3bf30a 100644 --- a/samtranslator/model/tags/resource_tagging.py +++ b/samtranslator/model/tags/resource_tagging.py @@ -1,4 +1,3 @@ - # Constants for Tagging _KEY = "Key" _VALUE = "Value" diff --git a/samtranslator/model/types.py b/samtranslator/model/types.py index d92ec1ab7e..36b44257fe 100644 --- a/samtranslator/model/types.py +++ b/samtranslator/model/types.py @@ -19,13 +19,18 @@ def is_type(valid_type): :returns: a function which returns True its input is an instance of 
valid_type, and raises TypeError otherwise :rtype: callable """ + def validate(value, should_raise=True): if not isinstance(value, valid_type): if should_raise: - raise TypeError("Expected value of type {expected}, actual value was of type {actual}.".format( - expected=valid_type, actual=type(value))) + raise TypeError( + "Expected value of type {expected}, actual value was of type {actual}.".format( + expected=valid_type, actual=type(value) + ) + ) return False return True + return validate @@ -37,6 +42,7 @@ def list_of(validate_item): :returns: a function which returns True its input is an list of valid items, and raises TypeError otherwise :rtype: callable """ + def validate(value, should_raise=True): validate_type = is_type(list) if not validate_type(value, should_raise=should_raise): @@ -51,6 +57,7 @@ def validate(value, should_raise=True): raise return False return True + return validate @@ -63,6 +70,7 @@ def dict_of(validate_key, validate_item): :returns: a function which returns True its input is an dict of valid items, and raises TypeError otherwise :rtype: callable """ + def validate(value, should_raise=True): validate_type = is_type(dict) if not validate_type(value, should_raise=should_raise): @@ -85,6 +93,7 @@ def validate(value, should_raise=True): raise return False return True + return validate @@ -96,6 +105,7 @@ def one_of(*validators): otherwise :rtype: callable """ + def validate(value, should_raise=True): if any(validate(value, should_raise=False) for validate in validators): return True @@ -103,6 +113,7 @@ def validate(value, should_raise=True): if should_raise: raise TypeError("value did not match any allowable type") return False + return validate diff --git a/samtranslator/model/update_policy.py b/samtranslator/model/update_policy.py index 30414e0f63..45a1ca3dc9 100644 --- a/samtranslator/model/update_policy.py +++ b/samtranslator/model/update_policy.py @@ -2,9 +2,10 @@ from samtranslator.model.intrinsics import ref -CodeDeployLambdaAliasUpdate = namedtuple('CodeDeployLambdaAliasUpdate', - ['ApplicationName', 'DeploymentGroupName', 'BeforeAllowTrafficHook', - 'AfterAllowTrafficHook']) +CodeDeployLambdaAliasUpdate = namedtuple( + "CodeDeployLambdaAliasUpdate", + ["ApplicationName", "DeploymentGroupName", "BeforeAllowTrafficHook", "AfterAllowTrafficHook"], +) """ This class is a model for the update policy which becomes present on any function alias for which there is an enabled @@ -25,6 +26,7 @@ def to_dict(self): :return: a dict that can be used as part of a cloudformation template """ dict_with_nones = self._asdict() - codedeploy_lambda_alias_update_dict = dict((k, v) for k, v in dict_with_nones.items() - if v != ref(None) and v is not None) - return {'CodeDeployLambdaAliasUpdate': codedeploy_lambda_alias_update_dict} + codedeploy_lambda_alias_update_dict = dict( + (k, v) for k, v in dict_with_nones.items() if v != ref(None) and v is not None + ) + return {"CodeDeployLambdaAliasUpdate": codedeploy_lambda_alias_update_dict} diff --git a/samtranslator/open_api/open_api.py b/samtranslator/open_api/open_api.py index 5112a44c9b..2d8ff46b4e 100644 --- a/samtranslator/open_api/open_api.py +++ b/samtranslator/open_api/open_api.py @@ -15,9 +15,9 @@ class OpenApiEditor(object): empty skeleton. 
""" - _X_APIGW_INTEGRATION = 'x-amazon-apigateway-integration' + _X_APIGW_INTEGRATION = "x-amazon-apigateway-integration" _CONDITIONAL_IF = "Fn::If" - _X_ANY_METHOD = 'x-amazon-apigateway-any-method' + _X_ANY_METHOD = "x-amazon-apigateway-any-method" _ALL_HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"] _DEFAULT_PATH = "$default" @@ -30,13 +30,15 @@ def __init__(self, doc): :raises ValueError: If the input Swagger document does not meet the basic Swagger requirements. """ if not OpenApiEditor.is_valid(doc): - raise ValueError("Invalid OpenApi document. " - "Invalid values or missing keys for 'openapi' or 'paths' in 'DefinitionBody'.") + raise ValueError( + "Invalid OpenApi document. " + "Invalid values or missing keys for 'openapi' or 'paths' in 'DefinitionBody'." + ) self._doc = copy.deepcopy(doc) self.paths = self._doc["paths"] self.security_schemes = self._doc.get("components", {}).get("securitySchemes", {}) - self.definitions = self._doc.get('definitions', {}) + self.definitions = self._doc.get("definitions", {}) def get_path(self, path): """ @@ -141,9 +143,11 @@ def has_integration(self, path, method): method = self._normalize_method_name(method) path_dict = self.get_path(path) - return self.has_path(path, method) and \ - isinstance(path_dict[method], dict) and \ - self.method_has_integration(path_dict[method]) # Integration present and non-empty + return ( + self.has_path(path, method) + and isinstance(path_dict[method], dict) + and self.method_has_integration(path_dict[method]) + ) # Integration present and non-empty def add_path(self, path, method=None): """ @@ -160,16 +164,21 @@ def add_path(self, path, method=None): if not isinstance(path_dict, dict): # Either customers has provided us an invalid Swagger, or this class has messed it somehow raise InvalidDocumentException( - [InvalidTemplateException("Value of '{}' path must be a dictionary according to Swagger spec." - .format(path))]) + [ + InvalidTemplateException( + "Value of '{}' path must be a dictionary according to Swagger spec.".format(path) + ) + ] + ) if self._CONDITIONAL_IF in path_dict: path_dict = path_dict[self._CONDITIONAL_IF][1] path_dict.setdefault(method, {}) - def add_lambda_integration(self, path, method, integration_uri, - method_auth_config=None, api_auth_config=None, condition=None): + def add_lambda_integration( + self, path, method, integration_uri, method_auth_config=None, api_auth_config=None, condition=None + ): """ Adds aws_proxy APIGW integration to the given path+method. 
@@ -192,17 +201,17 @@ def add_lambda_integration(self, path, method, integration_uri, path_dict = self.get_path(path) path_dict[method][self._X_APIGW_INTEGRATION] = { - 'type': 'aws_proxy', - 'httpMethod': 'POST', - 'payloadFormatVersion': '1.0', - 'uri': integration_uri + "type": "aws_proxy", + "httpMethod": "POST", + "payloadFormatVersion": "1.0", + "uri": integration_uri, } if path == self._DEFAULT_PATH and method == self._X_ANY_METHOD: path_dict[method]["isDefaultRoute"] = True # If 'responses' key is *not* present, add it with an empty dict as value - path_dict[method].setdefault('responses', {}) + path_dict[method].setdefault("responses", {}) # If a condition is present, wrap all method contents up into the condition if condition: @@ -262,21 +271,22 @@ def set_path_default_authorizer(self, path, default_authorizer, authorizers, api # If no integration given, then we don't need to process this definition (could be AWS::NoValue) if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) + existing_security = method_definition.get("security", []) if existing_security: return authorizer_list = [] if authorizers: authorizer_list.extend(authorizers.keys()) security_dict = dict() - security_dict[default_authorizer] = self._get_authorization_scopes(api_authorizers, - default_authorizer) + security_dict[default_authorizer] = self._get_authorization_scopes( + api_authorizers, default_authorizer + ) authorizer_security = [security_dict] security = authorizer_security if security: - method_definition['security'] = security + method_definition["security"] = security def add_auth_to_method(self, path, method_name, auth, api): """ @@ -315,12 +325,12 @@ def _set_method_authorizer(self, path, method_name, authorizer_name, authorizers if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) + existing_security = method_definition.get("security", []) security_dict = dict() security_dict[authorizer_name] = [] - if authorizer_name != 'NONE': + if authorizer_name != "NONE": method_authorization_scopes = authorizers[authorizer_name].get("AuthorizationScopes") if authorization_scopes: method_authorization_scopes = authorization_scopes @@ -332,7 +342,7 @@ def _set_method_authorizer(self, path, method_name, authorizer_name, authorizers # This assumes there are no authorizers already configured in the existing security block security = existing_security + authorizer_security if security: - method_definition['security'] = security + method_definition["security"] = security @property def openapi(self): @@ -360,10 +370,11 @@ def is_valid(data): :return: True, if data is valid OpenApi """ - if bool(data) and isinstance(data, dict) and isinstance(data.get('paths'), dict): + if bool(data) and isinstance(data, dict) and isinstance(data.get("paths"), dict): if bool(data.get("openapi")): return OpenApiEditor.safe_compare_regex_with_string( - OpenApiEditor.get_openapi_version_3_regex(), data["openapi"]) + OpenApiEditor.get_openapi_version_3_regex(), data["openapi"] + ) return False @staticmethod @@ -373,15 +384,7 @@ def gen_skeleton(): :return dict: Dictionary of a skeleton swagger document """ - return { - 'openapi': '3.0.1', - 'info': { - 'version': '1.0', - 'title': ref('AWS::StackName') - }, - 'paths': { - } - } + return {"openapi": "3.0.1", "info": {"version": "1.0", "title": ref("AWS::StackName")}, "paths": {}} @staticmethod def 
_get_authorization_scopes(authorizers, default_authorizer): @@ -391,8 +394,10 @@ def _get_authorization_scopes(authorizers, default_authorizer): :param default_authorizer: name of the default authorizer """ if authorizers is not None: - if authorizers[default_authorizer] \ - and authorizers[default_authorizer].get("AuthorizationScopes") is not None: + if ( + authorizers[default_authorizer] + and authorizers[default_authorizer].get("AuthorizationScopes") is not None + ): return authorizers[default_authorizer].get("AuthorizationScopes") return [] @@ -411,7 +416,7 @@ def _normalize_method_name(method): return method method = method.lower() - if method == 'any': + if method == "any": return OpenApiEditor._X_ANY_METHOD else: return method @@ -427,4 +432,4 @@ def safe_compare_regex_with_string(regex, data): @staticmethod def get_path_without_trailing_slash(path): - return re.sub(r'{([a-zA-Z0-9._-]+|proxy\+)}', '*', path) + return re.sub(r"{([a-zA-Z0-9._-]+|proxy\+)}", "*", path) diff --git a/samtranslator/parser/parser.py b/samtranslator/parser/parser.py index 2672ae6457..6f13766207 100644 --- a/samtranslator/parser/parser.py +++ b/samtranslator/parser/parser.py @@ -22,17 +22,21 @@ def _validate(self, sam_template, parameter_values): if parameter_values is None: raise ValueError("`parameter_values` argument is required") - if ("Resources" not in sam_template or not isinstance(sam_template["Resources"], dict) or not - sam_template["Resources"]): + if ( + "Resources" not in sam_template + or not isinstance(sam_template["Resources"], dict) + or not sam_template["Resources"] + ): + raise InvalidDocumentException([InvalidTemplateException("'Resources' section is required")]) + + if not all(isinstance(sam_resource, dict) for sam_resource in sam_template["Resources"].values()): raise InvalidDocumentException( - [InvalidTemplateException("'Resources' section is required")]) - - if (not all(isinstance(sam_resource, dict) for sam_resource in sam_template["Resources"].values())): - raise InvalidDocumentException( - [InvalidTemplateException( - "All 'Resources' must be Objects. If you're using YAML, this may be an " - "indentation issue." - )]) + [ + InvalidTemplateException( + "All 'Resources' must be Objects. If you're using YAML, this may be an " "indentation issue." + ) + ] + ) sam_template_instance = SamTemplate(sam_template) @@ -40,11 +44,15 @@ def _validate(self, sam_template, parameter_values): # NOTE: Properties isn't required for SimpleTable, so we can't check # `not isinstance(sam_resources.get("Properties"), dict)` as this would be a breaking change. # sam_resource.properties defaults to {} in SamTemplate init - if (not isinstance(sam_resource.properties, dict)): + if not isinstance(sam_resource.properties, dict): raise InvalidDocumentException( - [InvalidResourceException(resource_logical_id, - "All 'Resources' must be Objects and have a 'Properties' Object. If " - "you're using YAML, this may be an indentation issue." - )]) + [ + InvalidResourceException( + resource_logical_id, + "All 'Resources' must be Objects and have a 'Properties' Object. 
If " + "you're using YAML, this may be an indentation issue.", + ) + ] + ) SamTemplateValidator.validate(sam_template) diff --git a/samtranslator/plugins/__init__.py b/samtranslator/plugins/__init__.py index 27b49bd4f5..3209100285 100644 --- a/samtranslator/plugins/__init__.py +++ b/samtranslator/plugins/__init__.py @@ -125,8 +125,9 @@ def act(self, event, *args, **kwargs): for plugin in self._plugins: if not hasattr(plugin, method_name): - raise NameError("'{}' method is not found in the plugin with name '{}'" - .format(method_name, plugin.name)) + raise NameError( + "'{}' method is not found in the plugin with name '{}'".format(method_name, plugin.name) + ) try: getattr(plugin, method_name)(*args, **kwargs) @@ -150,6 +151,7 @@ class LifeCycleEvents(Enum): """ Enum of LifeCycleEvents """ + before_transform_template = "before_transform_template" before_transform_resource = "before_transform_resource" after_transform_template = "after_transform_template" diff --git a/samtranslator/plugins/api/default_definition_body_plugin.py b/samtranslator/plugins/api/default_definition_body_plugin.py index f3265dadb9..8ff6c4c28e 100644 --- a/samtranslator/plugins/api/default_definition_body_plugin.py +++ b/samtranslator/plugins/api/default_definition_body_plugin.py @@ -32,16 +32,16 @@ def on_before_transform_template(self, template_dict): for api_type in [SamResourceType.Api.value, SamResourceType.HttpApi.value]: for logicalId, api in template.iterate(api_type): - if api.properties.get('DefinitionBody') or api.properties.get('DefinitionUri'): + if api.properties.get("DefinitionBody") or api.properties.get("DefinitionUri"): continue if api_type is SamResourceType.HttpApi.value: # If "Properties" is not set in the template, set them here if not api.properties: template.set(logicalId, api) - api.properties['DefinitionBody'] = OpenApiEditor.gen_skeleton() + api.properties["DefinitionBody"] = OpenApiEditor.gen_skeleton() if api_type is SamResourceType.Api.value: - api.properties['DefinitionBody'] = SwaggerEditor.gen_skeleton() + api.properties["DefinitionBody"] = SwaggerEditor.gen_skeleton() - api.properties['__MANAGE_SWAGGER'] = True + api.properties["__MANAGE_SWAGGER"] = True diff --git a/samtranslator/plugins/api/implicit_api_plugin.py b/samtranslator/plugins/api/implicit_api_plugin.py index 5ddf1b1143..8d8895e645 100644 --- a/samtranslator/plugins/api/implicit_api_plugin.py +++ b/samtranslator/plugins/api/implicit_api_plugin.py @@ -42,8 +42,9 @@ def __init__(self, name): self._setup_api_properties() def _setup_api_properties(self): - raise NotImplementedError("Method _setup_api_properties() must be implemented in a " - "subclass of ImplicitApiPlugin") + raise NotImplementedError( + "Method _setup_api_properties() must be implemented in a " "subclass of ImplicitApiPlugin" + ) def on_before_transform_template(self, template_dict): """ @@ -102,10 +103,11 @@ def _get_api_events(self, function): } """ - if not (function.valid() and - isinstance(function.properties, dict) and - isinstance(function.properties.get("Events"), dict) - ): + if not ( + function.valid() + and isinstance(function.properties, dict) + and isinstance(function.properties.get("Events"), dict) + ): # Function resource structure is invalid. 
return {} @@ -127,8 +129,9 @@ def _process_api_events(self, function, api_events, template, condition=None): :param SamTemplate template: SAM Template where Serverless::Api resources can be found :param str condition: optional; this is the condition that is on the function with the API event """ - raise NotImplementedError("Method _setup_api_properties() must be implemented in a " - "subclass of ImplicitApiPlugin") + raise NotImplementedError( + "Method _setup_api_properties() must be implemented in a " "subclass of ImplicitApiPlugin" + ) def _add_implicit_api_id_if_necessary(self, event_properties): """ @@ -138,8 +141,9 @@ def _add_implicit_api_id_if_necessary(self, event_properties): :param dict event_properties: Dictionary of event properties """ - raise NotImplementedError("Method _setup_api_properties() must be implemented in a " - "subclass of ImplicitApiPlugin") + raise NotImplementedError( + "Method _setup_api_properties() must be implemented in a " "subclass of ImplicitApiPlugin" + ) def _add_api_to_swagger(self, event_id, event_properties, template): """ @@ -156,15 +160,18 @@ def _add_api_to_swagger(self, event_id, event_properties, template): # RestApiId is not pointing to a valid API resource if isinstance(api_id, dict) or not template.get(api_id): - raise InvalidEventException(event_id, - "RestApiId must be a valid reference to an 'AWS::Serverless::Api' resource " - "in same template") + raise InvalidEventException( + event_id, + "RestApiId must be a valid reference to an 'AWS::Serverless::Api' resource " "in same template", + ) # Make sure Swagger is valid resource = template.get(api_id) - if not (resource and - isinstance(resource.properties, dict) and - self.editor.is_valid(resource.properties.get("DefinitionBody"))): + if not ( + resource + and isinstance(resource.properties, dict) + and self.editor.is_valid(resource.properties.get("DefinitionBody")) + ): # This does not have an inline Swagger. Nothing can be done about it. return @@ -212,23 +219,28 @@ def _maybe_add_condition_to_implicit_api(self, template_dict): # Add a condition to the API resource IFF all of its resource+methods are associated with serverless functions # containing conditions. implicit_api_conditions = self.api_conditions[self.implicit_api_logical_id] - all_resource_method_conditions = set([condition - for path, method_conditions in implicit_api_conditions.items() - for method, condition in method_conditions.items()]) + all_resource_method_conditions = set( + [ + condition + for path, method_conditions in implicit_api_conditions.items() + for method, condition in method_conditions.items() + ] + ) at_least_one_resource_method = len(all_resource_method_conditions) > 0 all_resource_methods_contain_conditions = None not in all_resource_method_conditions if at_least_one_resource_method and all_resource_methods_contain_conditions: - implicit_api_resource = template_dict.get('Resources').get(self.implicit_api_logical_id) + implicit_api_resource = template_dict.get("Resources").get(self.implicit_api_logical_id) if len(all_resource_method_conditions) == 1: condition = all_resource_method_conditions.pop() - implicit_api_resource['Condition'] = condition + implicit_api_resource["Condition"] = condition else: # If multiple functions with multiple different conditions reference the Implicit Api, we need to # aggregate those conditions in order to conditionally create the Implicit Api. 
See RFC: # https://github.com/awslabs/serverless-application-model/issues/758 - implicit_api_resource['Condition'] = self.implicit_api_condition + implicit_api_resource["Condition"] = self.implicit_api_condition self._add_combined_condition_to_template( - template_dict, self.implicit_api_condition, all_resource_method_conditions) + template_dict, self.implicit_api_condition, all_resource_method_conditions + ) def _add_combined_condition_to_template(self, template_dict, condition_name, conditions_to_combine): """ @@ -241,9 +253,9 @@ def _add_combined_condition_to_template(self, template_dict, condition_name, con """ # defensive precondition check if not conditions_to_combine or len(conditions_to_combine) < 2: - raise ValueError('conditions_to_combine must have at least 2 conditions') + raise ValueError("conditions_to_combine must have at least 2 conditions") - template_conditions = template_dict.setdefault('Conditions', {}) + template_conditions = template_dict.setdefault("Conditions", {}) new_template_conditions = make_combined_condition(sorted(list(conditions_to_combine)), condition_name) for name, definition in new_template_conditions.items(): template_conditions[name] = definition @@ -261,7 +273,7 @@ def _maybe_add_conditions_to_implicit_api_paths(self, template): """ for api_id, api in template.iterate(self.api_type): - if not api.properties.get('__MANAGE_SWAGGER'): + if not api.properties.get("__MANAGE_SWAGGER"): continue swagger = api.properties.get("DefinitionBody") @@ -279,7 +291,8 @@ def _maybe_add_conditions_to_implicit_api_paths(self, template): else: path_condition_name = self._path_condition_name(api_id, path) self._add_combined_condition_to_template( - template.template_dict, path_condition_name, all_method_conditions) + template.template_dict, path_condition_name, all_method_conditions + ) editor.make_path_conditional(path, path_condition_name) api.properties["DefinitionBody"] = self._get_api_definition_from_editor(editor) # TODO make static method @@ -289,8 +302,9 @@ def _get_api_definition_from_editor(self, editor): """ Required function that returns the api body from the respective editor """ - raise NotImplementedError("Method _setup_api_properties() must be implemented in a " - "subclass of ImplicitApiPlugin") + raise NotImplementedError( + "Method _setup_api_properties() must be implemented in a " "subclass of ImplicitApiPlugin" + ) def _path_condition_name(self, api_id, path): """ @@ -299,8 +313,8 @@ def _path_condition_name(self, api_id, path): # only valid characters for CloudFormation logical id are [A-Za-z0-9], but swagger paths can contain # slashes and curly braces for templated params, e.g., /foo/{customerId}. So we'll replace # non-alphanumeric characters. 
- path_logical_id = path.replace('/', 'SLASH').replace('{', 'OB').replace('}', 'CB') - return '{}{}PathCondition'.format(api_id, path_logical_id) + path_logical_id = path.replace("/", "SLASH").replace("{", "OB").replace("}", "CB") + return "{}{}PathCondition".format(api_id, path_logical_id) def _maybe_remove_implicit_api(self, template): """ @@ -326,5 +340,6 @@ def _generate_implicit_api_resource(self): """ Helper function implemented by child classes that create a new implicit API resource """ - raise NotImplementedError("Method _setup_api_properties() must be implemented in a " - "subclass of ImplicitApiPlugin") + raise NotImplementedError( + "Method _setup_api_properties() must be implemented in a " "subclass of ImplicitApiPlugin" + ) diff --git a/samtranslator/plugins/api/implicit_http_api_plugin.py b/samtranslator/plugins/api/implicit_http_api_plugin.py index 1a12e980b3..2e53190361 100644 --- a/samtranslator/plugins/api/implicit_http_api_plugin.py +++ b/samtranslator/plugins/api/implicit_http_api_plugin.py @@ -78,9 +78,8 @@ def _process_api_events(self, function, api_events, template, condition=None): raise InvalidEventException(logicalId, "Api Event must have a String specified for '{}'.".format(key)) # !Ref is resolved by this time. If it is still a dict, we can't parse/use this Api. - if (isinstance(api_id, dict)): - raise InvalidEventException(logicalId, - "Api Event must reference an Api in the same template.") + if isinstance(api_id, dict): + raise InvalidEventException(logicalId, "Api Event must reference an Api in the same template.") api_dict = self.api_conditions.setdefault(api_id, {}) method_conditions = api_dict.setdefault(path, {}) @@ -133,8 +132,8 @@ def __init__(self): "DefinitionBody": open_api, # Internal property that means Event source code can add Events. Used only for implicit APIs, to # prevent back compatibility issues for explicit APIs - "__MANAGE_SWAGGER": True - } + "__MANAGE_SWAGGER": True, + }, } super(ImplicitHttpApiResource, self).__init__(resource) diff --git a/samtranslator/plugins/api/implicit_rest_api_plugin.py b/samtranslator/plugins/api/implicit_rest_api_plugin.py index e6079de5dc..ece4980fd0 100644 --- a/samtranslator/plugins/api/implicit_rest_api_plugin.py +++ b/samtranslator/plugins/api/implicit_rest_api_plugin.py @@ -72,17 +72,14 @@ def _process_api_events(self, function, api_events, template, condition=None): except KeyError as e: raise InvalidEventException(logicalId, "Event is missing key {}.".format(e)) - if (not isinstance(path, six.string_types)): - raise InvalidEventException(logicalId, - "Api Event must have a String specified for 'Path'.") - if (not isinstance(method, six.string_types)): - raise InvalidEventException(logicalId, - "Api Event must have a String specified for 'Method'.") + if not isinstance(path, six.string_types): + raise InvalidEventException(logicalId, "Api Event must have a String specified for 'Path'.") + if not isinstance(method, six.string_types): + raise InvalidEventException(logicalId, "Api Event must have a String specified for 'Method'.") # !Ref is resolved by this time. If it is still a dict, we can't parse/use this Api. 
- if (isinstance(api_id, dict)): - raise InvalidEventException(logicalId, - "Api Event must reference an Api in the same template.") + if isinstance(api_id, dict): + raise InvalidEventException(logicalId, "Api Event must reference an Api in the same template.") api_dict = self.api_conditions.setdefault(api_id, {}) method_conditions = api_dict.setdefault(path, {}) @@ -131,17 +128,15 @@ def __init__(self): resource = { "Type": SamResourceType.Api.value, "Properties": { - # Because we set the StageName to be constant value here, customers cannot override StageName with # Globals. This is because, if a property is specified in both Globals and the resource, the resource # one takes precedence. "StageName": "Prod", - "DefinitionBody": swagger, # Internal property that means Event source code can add Events. Used only for implicit APIs, to # prevent back compatibility issues for explicit APIs - "__MANAGE_SWAGGER": True - } + "__MANAGE_SWAGGER": True, + }, } super(ImplicitApiResource, self).__init__(resource) diff --git a/samtranslator/plugins/application/serverless_app_plugin.py b/samtranslator/plugins/application/serverless_app_plugin.py index feadfa1090..246125ba56 100644 --- a/samtranslator/plugins/application/serverless_app_plugin.py +++ b/samtranslator/plugins/application/serverless_app_plugin.py @@ -36,10 +36,10 @@ class ServerlessAppPlugin(BasePlugin): # CloudFormation times out on transforms after 2 minutes, so setting this # timeout below that to leave some buffer TEMPLATE_WAIT_TIMEOUT_SECONDS = 105 - APPLICATION_ID_KEY = 'ApplicationId' - SEMANTIC_VERSION_KEY = 'SemanticVersion' - LOCATION_KEY = 'Location' - TEMPLATE_URL_KEY = 'TemplateUrl' + APPLICATION_ID_KEY = "ApplicationId" + SEMANTIC_VERSION_KEY = "SemanticVersion" + LOCATION_KEY = "Location" + TEMPLATE_URL_KEY = "TemplateUrl" def __init__(self, sar_client=None, wait_for_template_active_status=False, validate_only=False, parameters={}): """ @@ -75,7 +75,7 @@ def on_before_transform_template(self, template_dict): :return: Nothing """ template = SamTemplate(template_dict) - intrinsic_resolvers = self._get_intrinsic_resolvers(template_dict.get('Mappings', {})) + intrinsic_resolvers = self._get_intrinsic_resolvers(template_dict.get("Mappings", {})) service_call = None if self._validate_only: @@ -87,11 +87,13 @@ def on_before_transform_template(self, template_dict): # Handle these cases in the on_before_transform_resource event continue - app_id = self._replace_value(app.properties[self.LOCATION_KEY], - self.APPLICATION_ID_KEY, intrinsic_resolvers) + app_id = self._replace_value( + app.properties[self.LOCATION_KEY], self.APPLICATION_ID_KEY, intrinsic_resolvers + ) - semver = self._replace_value(app.properties[self.LOCATION_KEY], - self.SEMANTIC_VERSION_KEY, intrinsic_resolvers) + semver = self._replace_value( + app.properties[self.LOCATION_KEY], self.SEMANTIC_VERSION_KEY, intrinsic_resolvers + ) if isinstance(app_id, dict) or isinstance(semver, dict): key = (json.dumps(app_id), json.dumps(semver)) @@ -104,7 +106,7 @@ def on_before_transform_template(self, template_dict): try: # Lazy initialization of the client- create it when it is needed if not self._sar_client: - self._sar_client = boto3.client('serverlessrepo') + self._sar_client = boto3.client("serverlessrepo") service_call(app_id, semver, key, logical_id) except InvalidResourceException as e: # Catch all InvalidResourceExceptions, raise those in the before_resource_transform target. 
@@ -116,8 +118,10 @@ def _replace_value(self, input_dict, key, intrinsic_resolvers): return value def _get_intrinsic_resolvers(self, mappings): - return [IntrinsicsResolver(self._parameters), - IntrinsicsResolver(mappings, {FindInMapAction.intrinsic_name: FindInMapAction()})] + return [ + IntrinsicsResolver(self._parameters), + IntrinsicsResolver(mappings, {FindInMapAction.intrinsic_name: FindInMapAction()}), + ] def _resolve_location_value(self, value, intrinsic_resolvers): resolved_value = copy.deepcopy(value) @@ -131,12 +135,14 @@ def _can_process_application(self, app): :param dict app: the application and its properties """ - return (self.LOCATION_KEY in app.properties and - isinstance(app.properties[self.LOCATION_KEY], dict) and - self.APPLICATION_ID_KEY in app.properties[self.LOCATION_KEY] and - app.properties[self.LOCATION_KEY][self.APPLICATION_ID_KEY] is not None and - self.SEMANTIC_VERSION_KEY in app.properties[self.LOCATION_KEY] and - app.properties[self.LOCATION_KEY][self.SEMANTIC_VERSION_KEY] is not None) + return ( + self.LOCATION_KEY in app.properties + and isinstance(app.properties[self.LOCATION_KEY], dict) + and self.APPLICATION_ID_KEY in app.properties[self.LOCATION_KEY] + and app.properties[self.LOCATION_KEY][self.APPLICATION_ID_KEY] is not None + and self.SEMANTIC_VERSION_KEY in app.properties[self.LOCATION_KEY] + and app.properties[self.LOCATION_KEY][self.SEMANTIC_VERSION_KEY] is not None + ) def _handle_get_application_request(self, app_id, semver, key, logical_id): """ @@ -150,17 +156,17 @@ def _handle_get_application_request(self, app_id, semver, key, logical_id): :param string key: The dictionary key consisting of (ApplicationId, SemanticVersion) :param string logical_id: the logical_id of this application resource """ - get_application = (lambda app_id, semver: self._sar_client.get_application( - ApplicationId=self._sanitize_sar_str_param(app_id), - SemanticVersion=self._sanitize_sar_str_param(semver))) + get_application = lambda app_id, semver: self._sar_client.get_application( + ApplicationId=self._sanitize_sar_str_param(app_id), SemanticVersion=self._sanitize_sar_str_param(semver) + ) try: self._sar_service_call(get_application, logical_id, app_id, semver) - self._applications[key] = {'Available'} + self._applications[key] = {"Available"} except EndpointConnectionError as e: # No internet connection. Don't break verification, but do show a warning. warning_message = "{}. 
Unable to verify access to {}/{}.".format(e, app_id, semver) LOG.warning(warning_message) - self._applications[key] = {'Unable to verify'} + self._applications[key] = {"Unable to verify"} def _handle_create_cfn_template_request(self, app_id, semver, key, logical_id): """ @@ -171,14 +177,13 @@ def _handle_create_cfn_template_request(self, app_id, semver, key, logical_id): :param string key: The dictionary key consisting of (ApplicationId, SemanticVersion) :param string logical_id: the logical_id of this application resource """ - create_cfn_template = (lambda app_id, semver: self._sar_client.create_cloud_formation_template( - ApplicationId=self._sanitize_sar_str_param(app_id), - SemanticVersion=self._sanitize_sar_str_param(semver) - )) + create_cfn_template = lambda app_id, semver: self._sar_client.create_cloud_formation_template( + ApplicationId=self._sanitize_sar_str_param(app_id), SemanticVersion=self._sanitize_sar_str_param(semver) + ) response = self._sar_service_call(create_cfn_template, logical_id, app_id, semver) self._applications[key] = response[self.TEMPLATE_URL_KEY] - if response['Status'] != "ACTIVE": - self._in_progress_templates.append((response[self.APPLICATION_ID_KEY], response['TemplateId'])) + if response["Status"] != "ACTIVE": + self._in_progress_templates.append((response[self.APPLICATION_ID_KEY], response["TemplateId"])) def _sanitize_sar_str_param(self, param): """ @@ -219,8 +224,9 @@ def on_before_transform_resource(self, logical_id, resource_type, resource_prope return # If it is a dictionary, check for other required parameters - self._check_for_dictionary_key(logical_id, resource_properties[self.LOCATION_KEY], - [self.APPLICATION_ID_KEY, self.SEMANTIC_VERSION_KEY]) + self._check_for_dictionary_key( + logical_id, resource_properties[self.LOCATION_KEY], [self.APPLICATION_ID_KEY, self.SEMANTIC_VERSION_KEY] + ) app_id = resource_properties[self.LOCATION_KEY].get(self.APPLICATION_ID_KEY) @@ -228,8 +234,11 @@ def on_before_transform_resource(self, logical_id, resource_type, resource_prope raise InvalidResourceException(logical_id, "Property 'ApplicationId' cannot be blank.") if isinstance(app_id, dict): - raise InvalidResourceException(logical_id, "Property 'ApplicationId' cannot be resolved. Only FindInMap " - "and Ref intrinsic functions are supported.") + raise InvalidResourceException( + logical_id, + "Property 'ApplicationId' cannot be resolved. Only FindInMap " + "and Ref intrinsic functions are supported.", + ) semver = resource_properties[self.LOCATION_KEY].get(self.SEMANTIC_VERSION_KEY) @@ -237,8 +246,11 @@ def on_before_transform_resource(self, logical_id, resource_type, resource_prope raise InvalidResourceException(logical_id, "Property 'SemanticVersion' cannot be blank.") if isinstance(semver, dict): - raise InvalidResourceException(logical_id, "Property 'SemanticVersion' cannot be resolved. Only FindInMap " - "and Ref intrinsic functions are supported.") + raise InvalidResourceException( + logical_id, + "Property 'SemanticVersion' cannot be resolved. 
Only FindInMap " + "and Ref intrinsic functions are supported.", + ) key = (app_id, semver) @@ -261,8 +273,9 @@ def _check_for_dictionary_key(self, logical_id, dictionary, keys): """ for key in keys: if key not in dictionary: - raise InvalidResourceException(logical_id, 'Resource is missing the required [{}] ' - 'property.'.format(key)) + raise InvalidResourceException( + logical_id, "Resource is missing the required [{}] " "property.".format(key) + ) def on_after_transform_template(self, template): """ @@ -281,10 +294,10 @@ def on_after_transform_template(self, template): # Check each resource to make sure it's active for application_id, template_id in temp: - get_cfn_template = (lambda application_id, template_id: - self._sar_client.get_cloud_formation_template( - ApplicationId=self._sanitize_sar_str_param(application_id), - TemplateId=self._sanitize_sar_str_param(template_id))) + get_cfn_template = lambda application_id, template_id: self._sar_client.get_cloud_formation_template( + ApplicationId=self._sanitize_sar_str_param(application_id), + TemplateId=self._sanitize_sar_str_param(template_id), + ) response = self._sar_service_call(get_cfn_template, application_id, application_id, template_id) self._handle_get_cfn_template_response(response, application_id, template_id) @@ -298,8 +311,9 @@ def on_after_transform_template(self, template): # Not all templates reached active status if len(self._in_progress_templates) != 0: application_ids = [items[0] for items in self._in_progress_templates] - raise InvalidResourceException(application_ids, "Timed out waiting for nested stack templates " - "to reach ACTIVE status.") + raise InvalidResourceException( + application_ids, "Timed out waiting for nested stack templates " "to reach ACTIVE status." + ) def _handle_get_cfn_template_response(self, response, application_id, template_id): """ @@ -309,12 +323,14 @@ def _handle_get_cfn_template_response(self, response, application_id, template_i :param string application_id: the ApplicationId :param string template_id: the unique TemplateId for this application """ - status = response['Status'] + status = response["Status"] if status != "ACTIVE": # Other options are PREPARING and EXPIRED. - if status == 'EXPIRED': - message = ("Template for {} with id {} returned status: {}. Cannot access an expired " - "template.".format(application_id, template_id, status)) + if status == "EXPIRED": + message = ( + "Template for {} with id {} returned status: {}. 
Cannot access an expired " + "template.".format(application_id, template_id, status) + ) raise InvalidResourceException(application_id, message) self._in_progress_templates.append((application_id, template_id)) @@ -332,9 +348,9 @@ def _sar_service_call(self, service_call_lambda, logical_id, *args): LOG.info(response) return response except ClientError as e: - error_code = e.response['Error']['Code'] - if error_code in ('AccessDeniedException', 'NotFoundException'): - raise InvalidResourceException(logical_id, e.response['Error']['Message']) + error_code = e.response["Error"]["Code"] + if error_code in ("AccessDeniedException", "NotFoundException"): + raise InvalidResourceException(logical_id, e.response["Error"]["Message"]) # 'ForbiddenException'- SAR rejects connection LOG.exception(e) diff --git a/samtranslator/plugins/exceptions.py b/samtranslator/plugins/exceptions.py index 0bb65c893e..c147fcd311 100644 --- a/samtranslator/plugins/exceptions.py +++ b/samtranslator/plugins/exceptions.py @@ -5,10 +5,11 @@ class InvalidPluginException(Exception): plugin_name -- name of the plugin that caused this error message -- explanation of the error """ + def __init__(self, plugin_name, message): self._plugin_name = plugin_name self._message = message @property def message(self): - return 'The {} plugin is invalid. {}'.format(self._plugin_name, self._message) + return "The {} plugin is invalid. {}".format(self._plugin_name, self._message) diff --git a/samtranslator/plugins/globals/globals.py b/samtranslator/plugins/globals/globals.py index dc3803caee..568620333f 100644 --- a/samtranslator/plugins/globals/globals.py +++ b/samtranslator/plugins/globals/globals.py @@ -39,9 +39,8 @@ class Globals(object): "ReservedConcurrentExecutions", "ProvisionedConcurrencyConfig", "AssumeRolePolicyDocument", - "EventInvokeConfig" + "EventInvokeConfig", ], - # Everything except # DefinitionBody: because its hard to reason about merge of Swagger dictionaries # StageName: Because StageName cannot be overridden for Implicit APIs because of the current plugin @@ -63,19 +62,10 @@ class Globals(object): "CanarySetting", "TracingEnabled", "OpenApiVersion", - "Domain" - ], - - SamResourceType.HttpApi.value: [ - "Auth", - "AccessLogSettings", - "StageVariables", - "Tags" + "Domain", ], - - SamResourceType.SimpleTable.value: [ - "SSESpecification" - ] + SamResourceType.HttpApi.value: ["Auth", "AccessLogSettings", "StageVariables", "Tags"], + SamResourceType.SimpleTable.value: ["SSESpecification"], } def __init__(self, template): @@ -84,8 +74,9 @@ def __init__(self, template): :param dict template: SAM template to be parsed """ - self.supported_resource_section_names = ([x.replace(self._RESOURCE_PREFIX, "") - for x in self.supported_properties.keys()]) + self.supported_resource_section_names = [ + x.replace(self._RESOURCE_PREFIX, "") for x in self.supported_properties.keys() + ] # Sort the names for stability in list ordering self.supported_resource_section_names.sort() @@ -145,17 +136,21 @@ def fix_openapi_definitions(cls, template): for _, resource in resources.items(): if ("Type" in resource) and (resource["Type"] == cls._API_TYPE): properties = resource["Properties"] - if (cls._OPENAPIVERSION in properties) and (cls._MANAGE_SWAGGER in properties) and \ - SwaggerEditor.safe_compare_regex_with_string( - SwaggerEditor.get_openapi_version_3_regex(), properties[cls._OPENAPIVERSION]): + if ( + (cls._OPENAPIVERSION in properties) + and (cls._MANAGE_SWAGGER in properties) + and SwaggerEditor.safe_compare_regex_with_string( + 
SwaggerEditor.get_openapi_version_3_regex(), properties[cls._OPENAPIVERSION] + ) + ): if not isinstance(properties[cls._OPENAPIVERSION], string_types): properties[cls._OPENAPIVERSION] = str(properties[cls._OPENAPIVERSION]) resource["Properties"] = properties if "DefinitionBody" in properties: - definition_body = properties['DefinitionBody'] - definition_body['openapi'] = properties[cls._OPENAPIVERSION] - if definition_body.get('swagger'): - del definition_body['swagger'] + definition_body = properties["DefinitionBody"] + definition_body["openapi"] = properties[cls._OPENAPIVERSION] + if definition_body.get("swagger"): + del definition_body["swagger"] def _parse(self, globals_dict): """ @@ -168,18 +163,21 @@ def _parse(self, globals_dict): globals = {} if not isinstance(globals_dict, dict): - raise InvalidGlobalsSectionException(self._KEYWORD, - "It must be a non-empty dictionary".format(self._KEYWORD)) + raise InvalidGlobalsSectionException( + self._KEYWORD, "It must be a non-empty dictionary".format(self._KEYWORD) + ) for section_name, properties in globals_dict.items(): resource_type = self._make_resource_type(section_name) if resource_type not in self.supported_properties: - raise InvalidGlobalsSectionException(self._KEYWORD, - "'{section}' is not supported. " - "Must be one of the following values - {supported}" - .format(section=section_name, - supported=self.supported_resource_section_names)) + raise InvalidGlobalsSectionException( + self._KEYWORD, + "'{section}' is not supported. " + "Must be one of the following values - {supported}".format( + section=section_name, supported=self.supported_resource_section_names + ), + ) if not isinstance(properties, dict): raise InvalidGlobalsSectionException(self._KEYWORD, "Value of ${section} must be a dictionary") @@ -187,10 +185,13 @@ def _parse(self, globals_dict): for key, value in properties.items(): supported = self.supported_properties[resource_type] if key not in supported: - raise InvalidGlobalsSectionException(self._KEYWORD, - "'{key}' is not a supported property of '{section}'. " - "Must be one of the following values - {supported}" - .format(key=key, section=section_name, supported=supported)) + raise InvalidGlobalsSectionException( + self._KEYWORD, + "'{key}' is not a supported property of '{section}'. " + "Must be one of the following values - {supported}".format( + key=key, section=section_name, supported=supported + ), + ) # Store all Global properties in a map with key being the AWS::Serverless::* resource type globals[resource_type] = GlobalProperties(properties) @@ -360,7 +361,8 @@ def _do_merge(self, global_value, local_value): else: raise TypeError( - "Unsupported type of objects. GlobalType={}, LocalType={}".format(token_global, token_local)) + "Unsupported type of objects. 
GlobalType={}, LocalType={}".format(token_global, token_local) + ) def _merge_lists(self, global_list, local_list): """ @@ -436,6 +438,7 @@ class TOKEN: """ Enum of tokens used in the merging """ + PRIMITIVE = "primitive" DICT = "dict" LIST = "list" diff --git a/samtranslator/plugins/globals/globals_plugin.py b/samtranslator/plugins/globals/globals_plugin.py index 17f42e3e75..274db917e9 100644 --- a/samtranslator/plugins/globals/globals_plugin.py +++ b/samtranslator/plugins/globals/globals_plugin.py @@ -1,4 +1,3 @@ - from samtranslator.public.sdk.template import SamTemplate from samtranslator.public.plugins import BasePlugin from samtranslator.public.exceptions import InvalidDocumentException diff --git a/samtranslator/plugins/policies/policy_templates_plugin.py b/samtranslator/plugins/policies/policy_templates_plugin.py index 50aded6478..041291715a 100644 --- a/samtranslator/plugins/policies/policy_templates_plugin.py +++ b/samtranslator/plugins/policies/policy_templates_plugin.py @@ -73,20 +73,20 @@ def _process_intrinsic_if_policy_template(self, logical_id, policy_entry): then_statement = intrinsic_if["Fn::If"][1] else_statement = intrinsic_if["Fn::If"][2] - processed_then_statement = then_statement \ - if is_intrinsic_no_value(then_statement) \ + processed_then_statement = ( + then_statement + if is_intrinsic_no_value(then_statement) else self._process_policy_template(logical_id, then_statement) + ) - processed_else_statement = else_statement \ - if is_intrinsic_no_value(else_statement) \ + processed_else_statement = ( + else_statement + if is_intrinsic_no_value(else_statement) else self._process_policy_template(logical_id, else_statement) + ) processed_intrinsic_if = { - "Fn::If": [ - policy_entry.data["Fn::If"][0], - processed_then_statement, - processed_else_statement - ] + "Fn::If": [policy_entry.data["Fn::If"][0], processed_then_statement, processed_else_statement] } return processed_intrinsic_if @@ -106,9 +106,9 @@ def _process_policy_template(self, logical_id, template_data): # Exception's message will give lot of specific details raise InvalidResourceException(logical_id, str(ex)) except InvalidParameterValues: - raise InvalidResourceException(logical_id, - "Must specify valid parameter values for policy template '{}'" - .format(template_name)) + raise InvalidResourceException( + logical_id, "Must specify valid parameter values for policy template '{}'".format(template_name) + ) def _is_supported(self, resource_type): """ diff --git a/samtranslator/policy_template_processor/exceptions.py b/samtranslator/policy_template_processor/exceptions.py index 8a3de49b84..ba5fdb7a49 100644 --- a/samtranslator/policy_template_processor/exceptions.py +++ b/samtranslator/policy_template_processor/exceptions.py @@ -1,9 +1,8 @@ - - class TemplateNotFoundException(Exception): """ Exception raised when a template with given name is not found """ + def __init__(self, template_name): super(TemplateNotFoundException, self).__init__("Template with name '{}' is not found".format(template_name)) @@ -12,6 +11,7 @@ class InsufficientParameterValues(Exception): """ Exception raised when not every parameter in the template is given a value. 
""" + def __init__(self, message): super(InsufficientParameterValues, self).__init__(message) @@ -20,5 +20,6 @@ class InvalidParameterValues(Exception): """ Exception raised when parameter values passed to this template is invalid """ + def __init__(self, message): super(InvalidParameterValues, self).__init__(message) diff --git a/samtranslator/policy_template_processor/template.py b/samtranslator/policy_template_processor/template.py index f84cf3445f..06d0f824c2 100644 --- a/samtranslator/policy_template_processor/template.py +++ b/samtranslator/policy_template_processor/template.py @@ -42,19 +42,21 @@ def to_statement(self, parameter_values): missing = self.missing_parameter_values(parameter_values) if len(missing) > 0: # str() of elements of list to prevent any `u` prefix from being displayed in user-facing error message - raise InsufficientParameterValues("Following required parameters of template '{}' don't have values: {}" - .format(self.name, [str(m) for m in missing])) + raise InsufficientParameterValues( + "Following required parameters of template '{}' don't have values: {}".format( + self.name, [str(m) for m in missing] + ) + ) # Select only necessary parameter_values. this is to prevent malicious or accidental # injection of values for parameters not intended in the template. This is important because "Ref" resolution # will substitute any references for which a value is provided. - necessary_parameter_values = {name: value for name, value in parameter_values.items() - if name in self.parameters} + necessary_parameter_values = { + name: value for name, value in parameter_values.items() if name in self.parameters + } # Only "Ref" is supported - supported_intrinsics = { - RefAction.intrinsic_name: RefAction() - } + supported_intrinsics = {RefAction.intrinsic_name: RefAction()} resolver = IntrinsicsResolver(necessary_parameter_values, supported_intrinsics) definition_copy = copy.deepcopy(self.definition) diff --git a/samtranslator/public/__init__.py b/samtranslator/public/__init__.py index 7035de4cda..32aa984b25 100644 --- a/samtranslator/public/__init__.py +++ b/samtranslator/public/__init__.py @@ -3,4 +3,3 @@ # Root of the SAM package where we expose public classes & methods for other consumers of this SAM Translator to use. 
# This is essentially our Public API # - diff --git a/samtranslator/public/sdk/parameter.py b/samtranslator/public/sdk/parameter.py index 442e90b404..c881fd8e03 100644 --- a/samtranslator/public/sdk/parameter.py +++ b/samtranslator/public/sdk/parameter.py @@ -1,3 +1,3 @@ # flake8: noqa -from samtranslator.sdk.parameter import SamParameterValues \ No newline at end of file +from samtranslator.sdk.parameter import SamParameterValues diff --git a/samtranslator/public/sdk/resource.py b/samtranslator/public/sdk/resource.py index 8f090c7166..4d24c0c090 100644 --- a/samtranslator/public/sdk/resource.py +++ b/samtranslator/public/sdk/resource.py @@ -1,3 +1,3 @@ # flake8: noqa -from samtranslator.sdk.resource import SamResource, SamResourceType \ No newline at end of file +from samtranslator.sdk.resource import SamResource, SamResourceType diff --git a/samtranslator/public/sdk/template.py b/samtranslator/public/sdk/template.py index 753b237ec3..ec8398b03d 100644 --- a/samtranslator/public/sdk/template.py +++ b/samtranslator/public/sdk/template.py @@ -1,3 +1,3 @@ # flake8: noqa -from samtranslator.sdk.template import SamTemplate \ No newline at end of file +from samtranslator.sdk.template import SamTemplate diff --git a/samtranslator/public/translator.py b/samtranslator/public/translator.py index e98b4608b7..01973961ab 100644 --- a/samtranslator/public/translator.py +++ b/samtranslator/public/translator.py @@ -6,4 +6,3 @@ from samtranslator.translator.translator import Translator from samtranslator.translator.managed_policy_translator import ManagedPolicyLoader - diff --git a/samtranslator/region_configuration.py b/samtranslator/region_configuration.py index 0d474b2e00..e8f7da533e 100644 --- a/samtranslator/region_configuration.py +++ b/samtranslator/region_configuration.py @@ -7,10 +7,7 @@ class RegionConfiguration(object): class abstracts all region/partition specific configuration. 
""" - partitions = { - "govcloud": "aws-us-gov", - "china": "aws-cn" - } + partitions = {"govcloud": "aws-us-gov", "china": "aws-cn"} @classmethod def is_apigw_edge_configuration_supported(cls): @@ -21,7 +18,4 @@ def is_apigw_edge_configuration_supported(cls): :return: True, if API Gateway does not support Edge configuration """ - return ArnGenerator.get_partition_name() not in [ - cls.partitions["govcloud"], - cls.partitions["china"] - ] + return ArnGenerator.get_partition_name() not in [cls.partitions["govcloud"], cls.partitions["china"]] diff --git a/samtranslator/sdk/parameter.py b/samtranslator/sdk/parameter.py index 57deee442d..161bd95b49 100644 --- a/samtranslator/sdk/parameter.py +++ b/samtranslator/sdk/parameter.py @@ -63,16 +63,16 @@ def add_pseudo_parameter_values(self): Add pseudo parameter values :return: parameter values that have pseudo parameter in it """ - if 'AWS::Region' not in self.parameter_values: - self.parameter_values['AWS::Region'] = boto3.session.Session().region_name + if "AWS::Region" not in self.parameter_values: + self.parameter_values["AWS::Region"] = boto3.session.Session().region_name - if 'AWS::Partition' not in self.parameter_values: + if "AWS::Partition" not in self.parameter_values: region = boto3.session.Session().region_name # neither boto nor botocore has any way of returning the partition value yet - if region.startswith('cn-'): - self.parameter_values['AWS::Partition'] = 'aws-cn' - elif region.startswith('us-gov-'): - self.parameter_values['AWS::Partition'] = 'aws-us-gov' + if region.startswith("cn-"): + self.parameter_values["AWS::Partition"] = "aws-cn" + elif region.startswith("us-gov-"): + self.parameter_values["AWS::Partition"] = "aws-us-gov" else: - self.parameter_values['AWS::Partition'] = 'aws' + self.parameter_values["AWS::Partition"] = "aws" diff --git a/samtranslator/sdk/resource.py b/samtranslator/sdk/resource.py index 8d9d6b4c7b..dc5a04ad04 100644 --- a/samtranslator/sdk/resource.py +++ b/samtranslator/sdk/resource.py @@ -9,6 +9,7 @@ class SamResource(object): Any mutating methods also touch only "Properties" and "Type" attributes of the resource. This allows compatibility with any CloudFormation constructs, like DependsOn, Conditions etc. 
""" + type = None properties = {} @@ -38,8 +39,7 @@ def valid(self): if self.condition: if not is_str()(self.condition, should_raise=False): - raise InvalidDocumentException([ - InvalidTemplateException("Every Condition member must be a string.")]) + raise InvalidDocumentException([InvalidTemplateException("Every Condition member must be a string.")]) return SamResourceType.has_value(self.type) @@ -58,6 +58,7 @@ class SamResourceType(Enum): """ Enum of supported SAM types """ + Api = "AWS::Serverless::Api" Function = "AWS::Serverless::Function" SimpleTable = "AWS::Serverless::SimpleTable" diff --git a/samtranslator/swagger/swagger.py b/samtranslator/swagger/swagger.py index b8d67dd4a7..9c8d0575ec 100644 --- a/samtranslator/swagger/swagger.py +++ b/samtranslator/swagger/swagger.py @@ -17,13 +17,13 @@ class SwaggerEditor(object): """ _OPTIONS_METHOD = "options" - _X_APIGW_INTEGRATION = 'x-amazon-apigateway-integration' - _X_APIGW_BINARY_MEDIA_TYPES = 'x-amazon-apigateway-binary-media-types' + _X_APIGW_INTEGRATION = "x-amazon-apigateway-integration" + _X_APIGW_BINARY_MEDIA_TYPES = "x-amazon-apigateway-binary-media-types" _CONDITIONAL_IF = "Fn::If" - _X_APIGW_GATEWAY_RESPONSES = 'x-amazon-apigateway-gateway-responses' - _X_APIGW_POLICY = 'x-amazon-apigateway-policy' - _X_ANY_METHOD = 'x-amazon-apigateway-any-method' - _CACHE_KEY_PARAMETERS = 'cacheKeyParameters' + _X_APIGW_GATEWAY_RESPONSES = "x-amazon-apigateway-gateway-responses" + _X_APIGW_POLICY = "x-amazon-apigateway-policy" + _X_ANY_METHOD = "x-amazon-apigateway-any-method" + _CACHE_KEY_PARAMETERS = "cacheKeyParameters" # https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html _ALL_HTTP_METHODS = ["OPTIONS", "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"] _POLICY_TYPE_IAM = "Iam" @@ -47,7 +47,7 @@ def __init__(self, doc): self.security_definitions = self._doc.get("securityDefinitions", {}) self.gateway_responses = self._doc.get(self._X_APIGW_GATEWAY_RESPONSES, {}) self.resource_policy = self._doc.get(self._X_APIGW_POLICY, {}) - self.definitions = self._doc.get('definitions', {}) + self.definitions = self._doc.get("definitions", {}) def get_path(self, path): path_dict = self.paths.get(path) @@ -119,9 +119,11 @@ def has_integration(self, path, method): method = self._normalize_method_name(method) path_dict = self.get_path(path) - return self.has_path(path, method) and \ - isinstance(path_dict[method], dict) and \ - self.method_has_integration(path_dict[method]) # Integration present and non-empty + return ( + self.has_path(path, method) + and isinstance(path_dict[method], dict) + and self.method_has_integration(path_dict[method]) + ) # Integration present and non-empty def add_path(self, path, method=None): """ @@ -138,16 +140,21 @@ def add_path(self, path, method=None): if not isinstance(path_dict, dict): # Either customers has provided us an invalid Swagger, or this class has messed it somehow raise InvalidDocumentException( - [InvalidTemplateException("Value of '{}' path must be a dictionary according to Swagger spec." 
- .format(path))]) + [ + InvalidTemplateException( + "Value of '{}' path must be a dictionary according to Swagger spec.".format(path) + ) + ] + ) if self._CONDITIONAL_IF in path_dict: path_dict = path_dict[self._CONDITIONAL_IF][1] path_dict.setdefault(method, {}) - def add_lambda_integration(self, path, method, integration_uri, - method_auth_config=None, api_auth_config=None, condition=None): + def add_lambda_integration( + self, path, method, integration_uri, method_auth_config=None, api_auth_config=None, condition=None + ): """ Adds aws_proxy APIGW integration to the given path+method. @@ -169,30 +176,32 @@ def add_lambda_integration(self, path, method, integration_uri, path_dict = self.get_path(path) path_dict[method][self._X_APIGW_INTEGRATION] = { - 'type': 'aws_proxy', - 'httpMethod': 'POST', - 'uri': integration_uri + "type": "aws_proxy", + "httpMethod": "POST", + "uri": integration_uri, } method_auth_config = method_auth_config or {} api_auth_config = api_auth_config or {} - if method_auth_config.get('Authorizer') == 'AWS_IAM' \ - or api_auth_config.get('DefaultAuthorizer') == 'AWS_IAM' and not method_auth_config: - method_invoke_role = method_auth_config.get('InvokeRole') - if not method_invoke_role and 'InvokeRole' in method_auth_config: - method_invoke_role = 'NONE' - api_invoke_role = api_auth_config.get('InvokeRole') - if not api_invoke_role and 'InvokeRole' in api_auth_config: - api_invoke_role = 'NONE' + if ( + method_auth_config.get("Authorizer") == "AWS_IAM" + or api_auth_config.get("DefaultAuthorizer") == "AWS_IAM" + and not method_auth_config + ): + method_invoke_role = method_auth_config.get("InvokeRole") + if not method_invoke_role and "InvokeRole" in method_auth_config: + method_invoke_role = "NONE" + api_invoke_role = api_auth_config.get("InvokeRole") + if not api_invoke_role and "InvokeRole" in api_auth_config: + api_invoke_role = "NONE" credentials = self._generate_integration_credentials( - method_invoke_role=method_invoke_role, - api_invoke_role=api_invoke_role + method_invoke_role=method_invoke_role, api_invoke_role=api_invoke_role ) - if credentials and credentials != 'NONE': - self.paths[path][method][self._X_APIGW_INTEGRATION]['credentials'] = credentials + if credentials and credentials != "NONE": + self.paths[path][method][self._X_APIGW_INTEGRATION]["credentials"] = credentials # If 'responses' key is *not* present, add it with an empty dict as value - path_dict[method].setdefault('responses', {}) + path_dict[method].setdefault("responses", {}) # If a condition is present, wrap all method contents up into the condition if condition: @@ -208,8 +217,8 @@ def _generate_integration_credentials(self, method_invoke_role=None, api_invoke_ return self._get_invoke_role(method_invoke_role or api_invoke_role) def _get_invoke_role(self, invoke_role): - CALLER_CREDENTIALS_ARN = 'arn:aws:iam::*:user/*' - return invoke_role if invoke_role and invoke_role != 'CALLER_CREDENTIALS' else CALLER_CREDENTIALS_ARN + CALLER_CREDENTIALS_ARN = "arn:aws:iam::*:user/*" + return invoke_role if invoke_role and invoke_role != "CALLER_CREDENTIALS" else CALLER_CREDENTIALS_ARN def iter_on_path(self): """ @@ -222,8 +231,9 @@ def iter_on_path(self): for path, value in self.paths.items(): yield path - def add_cors(self, path, allowed_origins, allowed_headers=None, allowed_methods=None, max_age=None, - allow_credentials=None): + def add_cors( + self, path, allowed_origins, allowed_headers=None, allowed_methods=None, max_age=None, allow_credentials=None + ): """ Add CORS configuration to this 
path. Specifically, we will add a OPTIONS response config to the Swagger that will return headers required for CORS. Since SAM uses aws_proxy integration, we cannot inject the headers @@ -267,18 +277,17 @@ def add_cors(self, path, allowed_origins, allowed_headers=None, allowed_methods= # Add the Options method and the CORS response self.add_path(path, self._OPTIONS_METHOD) - self.get_path(path)[self._OPTIONS_METHOD] = self._options_method_response_for_cors(allowed_origins, - allowed_headers, - allowed_methods, - max_age, - allow_credentials) + self.get_path(path)[self._OPTIONS_METHOD] = self._options_method_response_for_cors( + allowed_origins, allowed_headers, allowed_methods, max_age, allow_credentials + ) def add_binary_media_types(self, binary_media_types): - bmt = json.loads(json.dumps(binary_media_types).replace('~1', '/')) + bmt = json.loads(json.dumps(binary_media_types).replace("~1", "/")) self._doc[self._X_APIGW_BINARY_MEDIA_TYPES] = bmt - def _options_method_response_for_cors(self, allowed_origins, allowed_headers=None, allowed_methods=None, - max_age=None, allow_credentials=None): + def _options_method_response_for_cors( + self, allowed_origins, allowed_headers=None, allowed_methods=None, max_age=None, allow_credentials=None + ): """ Returns a Swagger snippet containing configuration for OPTIONS HTTP Method to configure CORS. @@ -303,7 +312,7 @@ def _options_method_response_for_cors(self, allowed_origins, allowed_headers=Non ALLOW_METHODS = "Access-Control-Allow-Methods" MAX_AGE = "Access-Control-Max-Age" ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials" - HEADER_RESPONSE = (lambda x: "method.response.header." + x) + HEADER_RESPONSE = lambda x: "method.response.header." + x response_parameters = { # AllowedOrigin is always required @@ -312,9 +321,7 @@ def _options_method_response_for_cors(self, allowed_origins, allowed_headers=Non response_headers = { # Allow Origin is always required - ALLOW_ORIGIN: { - "type": "string" - } + ALLOW_ORIGIN: {"type": "string"} } # Optional values. 
Skip the header if value is empty @@ -345,25 +352,16 @@ def _options_method_response_for_cors(self, allowed_origins, allowed_headers=Non "produces": ["application/json"], self._X_APIGW_INTEGRATION: { "type": "mock", - "requestTemplates": { - "application/json": "{\n \"statusCode\" : 200\n}\n" - }, + "requestTemplates": {"application/json": '{\n "statusCode" : 200\n}\n'}, "responses": { "default": { "statusCode": "200", "responseParameters": response_parameters, - "responseTemplates": { - "application/json": "{}\n" - } + "responseTemplates": {"application/json": "{}\n"}, } - } + }, }, - "responses": { - "200": { - "description": "Default response for CORS method", - "headers": response_headers - } - } + "responses": {"200": {"description": "Default response for CORS method", "headers": response_headers}}, } def _make_cors_allowed_methods_for_path(self, path): @@ -402,7 +400,7 @@ def _make_cors_allowed_methods_for_path(self, path): allow_methods.sort() # Allow-Methods is comma separated string - return ','.join(allow_methods) + return ",".join(allow_methods) def add_authorizers_security_definitions(self, authorizers): """ @@ -422,11 +420,11 @@ def add_awsiam_security_definition(self): """ aws_iam_security_definition = { - 'AWS_IAM': { - 'x-amazon-apigateway-authtype': 'awsSigv4', - 'type': 'apiKey', - 'name': 'Authorization', - 'in': 'header' + "AWS_IAM": { + "x-amazon-apigateway-authtype": "awsSigv4", + "type": "apiKey", + "name": "Authorization", + "in": "header", } } @@ -434,7 +432,7 @@ def add_awsiam_security_definition(self): # Only add the security definition if it doesn't exist. This helps ensure # that we minimize changes to the swagger in the case of user defined swagger - if 'AWS_IAM' not in self.security_definitions: + if "AWS_IAM" not in self.security_definitions: self.security_definitions.update(aws_iam_security_definition) def add_apikey_security_definition(self): @@ -443,23 +441,18 @@ def add_apikey_security_definition(self): Note: this method is idempotent """ - api_key_security_definition = { - 'api_key': { - "type": "apiKey", - "name": "x-api-key", - "in": "header" - } - } + api_key_security_definition = {"api_key": {"type": "apiKey", "name": "x-api-key", "in": "header"}} self.security_definitions = self.security_definitions or {} # Only add the security definition if it doesn't exist. This helps ensure # that we minimize changes to the swagger in the case of user defined swagger - if 'api_key' not in self.security_definitions: + if "api_key" not in self.security_definitions: self.security_definitions.update(api_key_security_definition) - def set_path_default_authorizer(self, path, default_authorizer, authorizers, - add_default_auth_to_preflight=True, api_authorizers=None): + def set_path_default_authorizer( + self, path, default_authorizer, authorizers, add_default_auth_to_preflight=True, api_authorizers=None + ): """ Adds the default_authorizer to the security block for each method on this path unless an Authorizer was defined at the Function/Path/Method level. 
This is intended to be used to set the @@ -488,8 +481,8 @@ def set_path_default_authorizer(self, path, default_authorizer, authorizers, # If no integration given, then we don't need to process this definition (could be AWS::NoValue) if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) - authorizer_list = ['AWS_IAM'] + existing_security = method_definition.get("security", []) + authorizer_list = ["AWS_IAM"] if authorizers: authorizer_list.extend(authorizers.keys()) authorizer_names = set(authorizer_list) @@ -514,7 +507,7 @@ def set_path_default_authorizer(self, path, default_authorizer, authorizers, # applied (Function Api Events first; then Api Resource) complicates it. # Check if Function/Path/Method specified 'NONE' for Authorizer for idx, security in enumerate(existing_non_authorizer_security): - is_none = any(key == 'NONE' for key in security.keys()) + is_none = any(key == "NONE" for key in security.keys()) if is_none: none_idx = idx @@ -531,18 +524,19 @@ def set_path_default_authorizer(self, path, default_authorizer, authorizers, # No existing Authorizer found; use default else: security_dict = {} - security_dict[default_authorizer] = self._get_authorization_scopes(api_authorizers, - default_authorizer) + security_dict[default_authorizer] = self._get_authorization_scopes( + api_authorizers, default_authorizer + ) authorizer_security = [security_dict] security = existing_non_authorizer_security + authorizer_security if security: - method_definition['security'] = security + method_definition["security"] = security # The first element of the method_definition['security'] should be AWS_IAM # because authorizer_list = ['AWS_IAM'] is hardcoded above - if 'AWS_IAM' in method_definition['security'][0]: + if "AWS_IAM" in method_definition["security"][0]: self.add_awsiam_security_definition() def set_path_default_apikey_required(self, path): @@ -568,8 +562,8 @@ def set_path_default_apikey_required(self, path): if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) - apikey_security_names = set(['api_key', 'api_key_false']) + existing_security = method_definition.get("security", []) + apikey_security_names = set(["api_key", "api_key_false"]) existing_non_apikey_security = [] existing_apikey_security = [] apikey_security = [] @@ -590,7 +584,7 @@ def set_path_default_apikey_required(self, path): # Check if Function/Path/Method specified 'False' for ApiKeyRequired apikeyfalse_idx = -1 for idx, security in enumerate(existing_apikey_security): - is_none = any(key == 'api_key_false' for key in security.keys()) + is_none = any(key == "api_key_false" for key in security.keys()) if is_none: apikeyfalse_idx = idx @@ -603,13 +597,13 @@ def set_path_default_apikey_required(self, path): # No existing ApiKey setting found or it's already set to the default else: security_dict = {} - security_dict['api_key'] = [] + security_dict["api_key"] = [] apikey_security = [security_dict] security = existing_non_apikey_security + apikey_security if security != existing_security: - method_definition['security'] = security + method_definition["security"] = security def add_auth_to_method(self, path, method_name, auth, api): """ @@ -622,14 +616,14 @@ def add_auth_to_method(self, path, method_name, auth, api): :param dict auth: Auth configuration such as Authorizers, ApiKeyRequired, ResourcePolicy :param dict api: Reference to the related Api's properties as defined 
in the template. """ - method_authorizer = auth and auth.get('Authorizer') - method_scopes = auth and auth.get('AuthorizationScopes') - api_auth = api and api.get('Auth') - authorizers = api_auth and api_auth.get('Authorizers') + method_authorizer = auth and auth.get("Authorizer") + method_scopes = auth and auth.get("AuthorizationScopes") + api_auth = api and api.get("Auth") + authorizers = api_auth and api_auth.get("Authorizers") if method_authorizer: self._set_method_authorizer(path, method_name, method_authorizer, authorizers, method_scopes) - method_apikey_required = auth and auth.get('ApiKeyRequired') + method_apikey_required = auth and auth.get("ApiKeyRequired") if method_apikey_required is not None: self._set_method_apikey_handling(path, method_name, method_apikey_required) @@ -651,7 +645,7 @@ def _set_method_authorizer(self, path, method_name, authorizer_name, authorizers if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) + existing_security = method_definition.get("security", []) security_dict = {} security_dict[authorizer_name] = [] @@ -660,7 +654,7 @@ def _set_method_authorizer(self, path, method_name, authorizer_name, authorizers # This assumes there are no autorizers already configured in the existing security block security = existing_security + authorizer_security - if authorizer_name != 'NONE' and authorizers: + if authorizer_name != "NONE" and authorizers: method_auth_scopes = authorizers.get(authorizer_name, {}).get("AuthorizationScopes") if method_scopes is not None: method_auth_scopes = method_scopes @@ -668,11 +662,11 @@ def _set_method_authorizer(self, path, method_name, authorizer_name, authorizers security_dict[authorizer_name] = method_auth_scopes if security: - method_definition['security'] = security + method_definition["security"] = security # The first element of the method_definition['security'] should be AWS_IAM # because authorizer_list = ['AWS_IAM'] is hardcoded above - if 'AWS_IAM' in method_definition['security'][0]: + if "AWS_IAM" in method_definition["security"][0]: self.add_awsiam_security_definition() def _set_method_apikey_handling(self, path, method_name, apikey_required): @@ -692,12 +686,12 @@ def _set_method_apikey_handling(self, path, method_name, apikey_required): if not self.method_definition_has_integration(method_definition): continue - existing_security = method_definition.get('security', []) + existing_security = method_definition.get("security", []) if apikey_required: # We want to enable apikey required security security_dict = {} - security_dict['api_key'] = [] + security_dict["api_key"] = [] apikey_security = [security_dict] self.add_apikey_security_definition() else: @@ -705,14 +699,14 @@ def _set_method_apikey_handling(self, path, method_name, apikey_required): # so let's add a marker 'api_key_false' so that we don't incorrectly override # with the api default security_dict = {} - security_dict['api_key_false'] = [] + security_dict["api_key_false"] = [] apikey_security = [security_dict] # This assumes there are no autorizers already configured in the existing security block security = existing_security + apikey_security if security != existing_security: - method_definition['security'] = security + method_definition["security"] = security def add_request_model_to_method(self, path, method_name, request_model): """ @@ -722,8 +716,8 @@ def add_request_model_to_method(self, path, method_name, request_model): :param string method_name: Method name :param 
dict request_model: Model name """ - model_name = request_model and request_model.get('Model').lower() - model_required = request_model and request_model.get('Required') + model_name = request_model and request_model.get("Model").lower() + model_required = request_model and request_model.get("Required") normalized_method_name = self._normalize_method_name(method_name) # It is possible that the method could have two definitions in a Fn::If block. @@ -733,41 +727,32 @@ def add_request_model_to_method(self, path, method_name, request_model): if not self.method_definition_has_integration(method_definition): continue - if self._doc.get('swagger') is not None: + if self._doc.get("swagger") is not None: - existing_parameters = method_definition.get('parameters', []) + existing_parameters = method_definition.get("parameters", []) parameter = { - 'in': 'body', - 'name': model_name, - 'schema': { - '$ref': '#/definitions/{}'.format(model_name) - } + "in": "body", + "name": model_name, + "schema": {"$ref": "#/definitions/{}".format(model_name)}, } if model_required is not None: - parameter['required'] = model_required + parameter["required"] = model_required existing_parameters.append(parameter) - method_definition['parameters'] = existing_parameters - - elif self._doc.get("openapi") and \ - SwaggerEditor.safe_compare_regex_with_string( - SwaggerEditor.get_openapi_version_3_regex(), self._doc["openapi"]): - method_definition['requestBody'] = { - 'content': { - "application/json": { - "schema": { - "$ref": "#/components/schemas/{}".format(model_name) - } - } + method_definition["parameters"] = existing_parameters - } + elif self._doc.get("openapi") and SwaggerEditor.safe_compare_regex_with_string( + SwaggerEditor.get_openapi_version_3_regex(), self._doc["openapi"] + ): + method_definition["requestBody"] = { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/{}".format(model_name)}}} } if model_required is not None: - method_definition['requestBody']['required'] = model_required + method_definition["requestBody"]["required"] = model_required def add_gateway_responses(self, gateway_responses): """ @@ -792,8 +777,8 @@ def add_models(self, models): for model_name, schema in models.items(): - model_type = schema.get('type') - model_properties = schema.get('properties') + model_type = schema.get("type") + model_properties = schema.get("properties") if not model_type: raise ValueError("Invalid input. 
Value for type is required") @@ -813,13 +798,13 @@ def add_resource_policy(self, resource_policy, path, api_id, stage): if resource_policy is None: return - aws_account_whitelist = resource_policy.get('AwsAccountWhitelist') - aws_account_blacklist = resource_policy.get('AwsAccountBlacklist') - ip_range_whitelist = resource_policy.get('IpRangeWhitelist') - ip_range_blacklist = resource_policy.get('IpRangeBlacklist') - source_vpc_whitelist = resource_policy.get('SourceVpcWhitelist') - source_vpc_blacklist = resource_policy.get('SourceVpcBlacklist') - custom_statements = resource_policy.get('CustomStatements') + aws_account_whitelist = resource_policy.get("AwsAccountWhitelist") + aws_account_blacklist = resource_policy.get("AwsAccountBlacklist") + ip_range_whitelist = resource_policy.get("IpRangeWhitelist") + ip_range_blacklist = resource_policy.get("IpRangeBlacklist") + source_vpc_whitelist = resource_policy.get("SourceVpcWhitelist") + source_vpc_blacklist = resource_policy.get("SourceVpcBlacklist") + custom_statements = resource_policy.get("CustomStatements") if aws_account_whitelist is not None: resource_list = self._get_method_path_uri_list(path, api_id, stage) @@ -862,31 +847,31 @@ def _add_iam_resource_policy_for_method(self, policy_list, effect, resource_list return if effect not in ["Allow", "Deny"]: - raise ValueError('Effect must be one of {}'.format(['Allow', 'Deny'])) + raise ValueError("Effect must be one of {}".format(["Allow", "Deny"])) if not isinstance(policy_list, (dict, list)): raise InvalidDocumentException( - [InvalidTemplateException("Type of '{}' must be a list or dictionary" - .format(policy_list))]) + [InvalidTemplateException("Type of '{}' must be a list or dictionary".format(policy_list))] + ) if not isinstance(policy_list, list): policy_list = [policy_list] - self.resource_policy['Version'] = '2012-10-17' + self.resource_policy["Version"] = "2012-10-17" policy_statement = {} - policy_statement['Effect'] = effect - policy_statement['Action'] = "execute-api:Invoke" - policy_statement['Resource'] = resource_list - policy_statement['Principal'] = {"AWS": policy_list} + policy_statement["Effect"] = effect + policy_statement["Action"] = "execute-api:Invoke" + policy_statement["Resource"] = resource_list + policy_statement["Principal"] = {"AWS": policy_list} - if self.resource_policy.get('Statement') is None: - self.resource_policy['Statement'] = policy_statement + if self.resource_policy.get("Statement") is None: + self.resource_policy["Statement"] = policy_statement else: - statement = self.resource_policy['Statement'] + statement = self.resource_policy["Statement"] if not isinstance(statement, list): statement = [statement] statement.extend([policy_statement]) - self.resource_policy['Statement'] = statement + self.resource_policy["Statement"] = statement def _get_method_path_uri_list(self, path, api_id, stage): """ @@ -900,7 +885,7 @@ def _get_method_path_uri_list(self, path, api_id, stage): path = SwaggerEditor.get_path_without_trailing_slash(path) for m in methods: - method = '*' if (m.lower() == self._X_ANY_METHOD or m.lower() == 'any') else m.upper() + method = "*" if (m.lower() == self._X_ANY_METHOD or m.lower() == "any") else m.upper() resource = "execute-api:/${__Stage__}/" + method + path resource = fnSub(resource, {"__Stage__": stage}) uri_list.extend([resource]) @@ -919,33 +904,33 @@ def _add_ip_resource_policy_for_method(self, ip_list, conditional, resource_list ip_list = [ip_list] if conditional not in ["IpAddress", "NotIpAddress"]: - raise 
ValueError('Conditional must be one of {}'.format(["IpAddress", "NotIpAddress"])) + raise ValueError("Conditional must be one of {}".format(["IpAddress", "NotIpAddress"])) - self.resource_policy['Version'] = '2012-10-17' + self.resource_policy["Version"] = "2012-10-17" allow_statement = {} - allow_statement['Effect'] = "Allow" - allow_statement['Action'] = "execute-api:Invoke" - allow_statement['Resource'] = resource_list - allow_statement['Principal'] = "*" + allow_statement["Effect"] = "Allow" + allow_statement["Action"] = "execute-api:Invoke" + allow_statement["Resource"] = resource_list + allow_statement["Principal"] = "*" deny_statement = {} - deny_statement['Effect'] = "Deny" - deny_statement['Action'] = "execute-api:Invoke" - deny_statement['Resource'] = resource_list - deny_statement['Principal'] = "*" - deny_statement['Condition'] = {conditional: {"aws:SourceIp": ip_list}} - - if self.resource_policy.get('Statement') is None: - self.resource_policy['Statement'] = [allow_statement, deny_statement] + deny_statement["Effect"] = "Deny" + deny_statement["Action"] = "execute-api:Invoke" + deny_statement["Resource"] = resource_list + deny_statement["Principal"] = "*" + deny_statement["Condition"] = {conditional: {"aws:SourceIp": ip_list}} + + if self.resource_policy.get("Statement") is None: + self.resource_policy["Statement"] = [allow_statement, deny_statement] else: - statement = self.resource_policy['Statement'] + statement = self.resource_policy["Statement"] if not isinstance(statement, list): statement = [statement] if allow_statement not in statement: statement.extend([allow_statement]) if deny_statement not in statement: statement.extend([deny_statement]) - self.resource_policy['Statement'] = statement + self.resource_policy["Statement"] = statement def _add_vpc_resource_policy_for_method(self, vpc, conditional, resource_list): """ @@ -957,7 +942,7 @@ def _add_vpc_resource_policy_for_method(self, vpc, conditional, resource_list): return if conditional not in ["StringNotEquals", "StringEquals"]: - raise ValueError('Conditional must be one of {}'.format(["StringNotEquals", "StringEquals"])) + raise ValueError("Conditional must be one of {}".format(["StringNotEquals", "StringEquals"])) vpce_regex = r"^vpce-" if not re.match(vpce_regex, vpc): @@ -965,31 +950,31 @@ def _add_vpc_resource_policy_for_method(self, vpc, conditional, resource_list): else: endpoint = "aws:SourceVpce" - self.resource_policy['Version'] = '2012-10-17' + self.resource_policy["Version"] = "2012-10-17" allow_statement = {} - allow_statement['Effect'] = "Allow" - allow_statement['Action'] = "execute-api:Invoke" - allow_statement['Resource'] = resource_list - allow_statement['Principal'] = "*" + allow_statement["Effect"] = "Allow" + allow_statement["Action"] = "execute-api:Invoke" + allow_statement["Resource"] = resource_list + allow_statement["Principal"] = "*" deny_statement = {} - deny_statement['Effect'] = "Deny" - deny_statement['Action'] = "execute-api:Invoke" - deny_statement['Resource'] = resource_list - deny_statement['Principal'] = "*" - deny_statement['Condition'] = {conditional: {endpoint: vpc}} - - if self.resource_policy.get('Statement') is None: - self.resource_policy['Statement'] = [allow_statement, deny_statement] + deny_statement["Effect"] = "Deny" + deny_statement["Action"] = "execute-api:Invoke" + deny_statement["Resource"] = resource_list + deny_statement["Principal"] = "*" + deny_statement["Condition"] = {conditional: {endpoint: vpc}} + + if self.resource_policy.get("Statement") is None: + 
self.resource_policy["Statement"] = [allow_statement, deny_statement] else: - statement = self.resource_policy['Statement'] + statement = self.resource_policy["Statement"] if not isinstance(statement, list): statement = [statement] if allow_statement not in statement: statement.extend([allow_statement]) if deny_statement not in statement: statement.extend([deny_statement]) - self.resource_policy['Statement'] = statement + self.resource_policy["Statement"] = statement def _add_custom_statement(self, custom_statements): if custom_statements is None: @@ -998,17 +983,17 @@ def _add_custom_statement(self, custom_statements): if not isinstance(custom_statements, list): custom_statements = [custom_statements] - self.resource_policy['Version'] = '2012-10-17' - if self.resource_policy.get('Statement') is None: - self.resource_policy['Statement'] = custom_statements + self.resource_policy["Version"] = "2012-10-17" + if self.resource_policy.get("Statement") is None: + self.resource_policy["Statement"] = custom_statements else: - statement = self.resource_policy['Statement'] + statement = self.resource_policy["Statement"] if not isinstance(statement, list): statement = [statement] for s in custom_statements: if s not in statement: statement.append(s) - self.resource_policy['Statement'] = statement + self.resource_policy["Statement"] = statement def add_request_parameters_to_method(self, path, method_name, request_parameters): """ @@ -1028,34 +1013,29 @@ def add_request_parameters_to_method(self, path, method_name, request_parameters if not self.method_definition_has_integration(method_definition): continue - existing_parameters = method_definition.get('parameters', []) + existing_parameters = method_definition.get("parameters", []) for request_parameter in request_parameters: - parameter_name = request_parameter['Name'] - location_name = parameter_name.replace('method.request.', '') - location, name = location_name.split('.') + parameter_name = request_parameter["Name"] + location_name = parameter_name.replace("method.request.", "") + location, name = location_name.split(".") - if location == 'querystring': - location = 'query' + if location == "querystring": + location = "query" - parameter = { - 'in': location, - 'name': name, - 'required': request_parameter['Required'], - 'type': 'string' - } + parameter = {"in": location, "name": name, "required": request_parameter["Required"], "type": "string"} existing_parameters.append(parameter) - if request_parameter['Caching']: + if request_parameter["Caching"]: integration = method_definition[self._X_APIGW_INTEGRATION] cache_parameters = integration.get(self._CACHE_KEY_PARAMETERS, []) cache_parameters.append(parameter_name) integration[self._CACHE_KEY_PARAMETERS] = cache_parameters - method_definition['parameters'] = existing_parameters + method_definition["parameters"] = existing_parameters @property def swagger(self): @@ -1073,7 +1053,7 @@ def swagger(self): if self.gateway_responses: self._doc[self._X_APIGW_GATEWAY_RESPONSES] = self.gateway_responses if self.definitions: - self._doc['definitions'] = self.definitions + self._doc["definitions"] = self.definitions return copy.deepcopy(self._doc) @@ -1086,12 +1066,13 @@ def is_valid(data): :return: True, if data is a Swagger """ - if bool(data) and isinstance(data, dict) and isinstance(data.get('paths'), dict): + if bool(data) and isinstance(data, dict) and isinstance(data.get("paths"), dict): if bool(data.get("swagger")): return True elif bool(data.get("openapi")): return 
SwaggerEditor.safe_compare_regex_with_string( - SwaggerEditor.get_openapi_version_3_regex(), data["openapi"]) + SwaggerEditor.get_openapi_version_3_regex(), data["openapi"] + ) return False @staticmethod @@ -1101,15 +1082,7 @@ def gen_skeleton(): :return dict: Dictionary of a skeleton swagger document """ - return { - 'swagger': '2.0', - 'info': { - 'version': '1.0', - 'title': ref('AWS::StackName') - }, - 'paths': { - } - } + return {"swagger": "2.0", "info": {"version": "1.0", "title": ref("AWS::StackName")}, "paths": {}} @staticmethod def _get_authorization_scopes(authorizers, default_authorizer): @@ -1119,8 +1092,10 @@ def _get_authorization_scopes(authorizers, default_authorizer): :param default_authorizer: name of the default authorizer """ if authorizers is not None: - if authorizers.get(default_authorizer) \ - and authorizers[default_authorizer].get("AuthorizationScopes") is not None: + if ( + authorizers.get(default_authorizer) + and authorizers[default_authorizer].get("AuthorizationScopes") is not None + ): return authorizers[default_authorizer].get("AuthorizationScopes") return [] @@ -1139,7 +1114,7 @@ def _normalize_method_name(method): return method method = method.lower() - if method == 'any': + if method == "any": return SwaggerEditor._X_ANY_METHOD else: return method @@ -1160,4 +1135,4 @@ def safe_compare_regex_with_string(regex, data): @staticmethod def get_path_without_trailing_slash(path): - return re.sub(r'{([a-zA-Z0-9._-]+|proxy\+)}', '*', path) + return re.sub(r"{([a-zA-Z0-9._-]+|proxy\+)}", "*", path) diff --git a/samtranslator/translator/arn_generator.py b/samtranslator/translator/arn_generator.py index f15ca0e333..633a7ba7a6 100644 --- a/samtranslator/translator/arn_generator.py +++ b/samtranslator/translator/arn_generator.py @@ -2,18 +2,17 @@ class ArnGenerator(object): - @classmethod def generate_arn(cls, partition, service, resource, include_account_id=True): if not service or not resource: raise RuntimeError("Could not construct ARN for resource.") - arn = 'arn:{0}:{1}:${{AWS::Region}}:' + arn = "arn:{0}:{1}:${{AWS::Region}}:" if include_account_id: - arn += '${{AWS::AccountId}}:' + arn += "${{AWS::AccountId}}:" - arn += '{2}' + arn += "{2}" return arn.format(partition, service, resource) @@ -26,8 +25,7 @@ def generate_aws_managed_policy_arn(cls, policy_name): :param policy_name: Name of the policy :return: ARN Of the managed policy """ - return 'arn:{}:iam::aws:policy/{}'.format(ArnGenerator.get_partition_name(), - policy_name) + return "arn:{}:iam::aws:policy/{}".format(ArnGenerator.get_partition_name(), policy_name) @classmethod def get_partition_name(cls, region=None): diff --git a/samtranslator/translator/logical_id_generator.py b/samtranslator/translator/logical_id_generator.py index 40d87c98f6..bbebdb7376 100644 --- a/samtranslator/translator/logical_id_generator.py +++ b/samtranslator/translator/logical_id_generator.py @@ -62,10 +62,10 @@ def get_hash(self, length=HASH_LENGTH): if sys.version_info.major == 2: # In Py2, only unicode needs to be encoded. 
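The logical_id_generator.py hunks here are quote-style changes only, but the helper they touch boils down to: stringify the input deterministically, then take a sha1 digest. A minimal standalone sketch of that idea (not part of the patch; the truncation length of 10 is a hypothetical stand-in for the class's HASH_LENGTH constant, and the final truncation step is inferred rather than shown in this hunk):

    import hashlib
    import json

    def stable_hash(data, length=10):
        # Same idea as _stringify: strings pass through, everything else is
        # serialized with the most compact separators and sorted keys so the
        # same dict always produces the same string.
        if not isinstance(data, str):
            data = json.dumps(data, separators=(",", ":"), sort_keys=True)
        # Same idea as get_hash: utf-8 encode, sha1 hexdigest, then keep a prefix.
        return hashlib.sha1(data.encode("utf-8")).hexdigest()[:length]

    # Key order does not change the result:
    assert stable_hash({"a": 1, "b": 2}) == stable_hash({"b": 2, "a": 1})
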
if isinstance(self.data_str, unicode): - encoded_data_str = self.data_str.encode('utf-8') + encoded_data_str = self.data_str.encode("utf-8") else: # data_str should always be unicode on python 3 - encoded_data_str = self.data_str.encode('utf-8') + encoded_data_str = self.data_str.encode("utf-8") data_hash = hashlib.sha1(encoded_data_str).hexdigest() @@ -87,4 +87,4 @@ def _stringify(self, data): return data # Get the most compact dictionary (separators) and sort the keys recursively to get a stable output - return json.dumps(data, separators=(',', ':'), sort_keys=True) + return json.dumps(data, separators=(",", ":"), sort_keys=True) diff --git a/samtranslator/translator/managed_policy_translator.py b/samtranslator/translator/managed_policy_translator.py index 53031664b1..8621ec21d1 100644 --- a/samtranslator/translator/managed_policy_translator.py +++ b/samtranslator/translator/managed_policy_translator.py @@ -1,20 +1,19 @@ class ManagedPolicyLoader(object): - def __init__(self, iam_client): self._iam_client = iam_client self._policy_map = None def load(self): if self._policy_map is None: - paginator = self._iam_client.get_paginator('list_policies') + paginator = self._iam_client.get_paginator("list_policies") # Setting the scope to AWS limits the returned values to only AWS Managed Policies and will # not returned policies owned by any specific account. # http://docs.aws.amazon.com/IAM/latest/APIReference/API_ListPolicies.html#API_ListPolicies_RequestParameters - page_iterator = paginator.paginate(Scope='AWS') + page_iterator = paginator.paginate(Scope="AWS") name_to_arn_map = {} for page in page_iterator: - name_to_arn_map.update(map(lambda x: (x['PolicyName'], x['Arn']), page['Policies'])) + name_to_arn_map.update(map(lambda x: (x["PolicyName"], x["Arn"]), page["Policies"])) self._policy_map = name_to_arn_map return self._policy_map diff --git a/samtranslator/translator/translator.py b/samtranslator/translator/translator.py index 379b5559d0..c47e1095cc 100644 --- a/samtranslator/translator/translator.py +++ b/samtranslator/translator/translator.py @@ -2,8 +2,12 @@ from samtranslator.model import ResourceTypeResolver, sam_resources from samtranslator.translator.verify_logical_id import verify_unique_logical_id from samtranslator.model.preferences.deployment_preference_collection import DeploymentPreferenceCollection -from samtranslator.model.exceptions import (InvalidDocumentException, InvalidResourceException, - DuplicateLogicalIdException, InvalidEventException) +from samtranslator.model.exceptions import ( + InvalidDocumentException, + InvalidResourceException, + DuplicateLogicalIdException, + InvalidEventException, +) from samtranslator.intrinsics.resolver import IntrinsicsResolver from samtranslator.intrinsics.actions import FindInMapAction from samtranslator.intrinsics.resource_refs import SupportedResourceReferences @@ -20,6 +24,7 @@ class Translator: """Translates SAM templates into CloudFormation templates """ + def __init__(self, managed_policy_map, sam_parser, plugins=None): """ :param dict managed_policy_map: Map of managed policy names to the ARNs @@ -53,17 +58,14 @@ def translate(self, sam_template, parameter_values): # Create & Install plugins sam_plugins = prepare_plugins(self.plugins, parameter_values) - self.sam_parser.parse( - sam_template=sam_template, - parameter_values=parameter_values, - sam_plugins=sam_plugins - ) + self.sam_parser.parse(sam_template=sam_template, parameter_values=parameter_values, sam_plugins=sam_plugins) template = copy.deepcopy(sam_template) 
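The translator.py hunks around this point wire up IntrinsicsResolver together with intrinsic actions such as FindInMapAction, and the test changes further down in this patch pin the behaviour of the Ref, Fn::Sub, Fn::GetAtt and Fn::FindInMap actions. A minimal usage sketch of that behaviour, illustrative only and assuming RefAction and SubAction come from the same samtranslator.intrinsics.actions module that translator.py imports FindInMapAction from:

    # Illustrative sketch, not part of the patch.  Inputs and expected outputs
    # mirror the assertions in tests/intrinsics/test_actions.py below.
    from samtranslator.intrinsics.actions import RefAction, SubAction

    parameters = {"key1": "value1"}

    # {"Ref": "key1"} collapses to the parameter's value.
    assert RefAction().resolve_parameter_refs({"Ref": "key1"}, parameters) == "value1"

    # Fn::Sub replaces ${key1}; ${!literal} escapes and unknown names are left alone.
    resolved = SubAction().resolve_parameter_refs(
        {"Fn::Sub": "Hello ${key1} ${!key1} ${unknown}"}, parameters
    )
    assert resolved == {"Fn::Sub": "Hello value1 ${!key1} ${unknown}"}
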
macro_resolver = ResourceTypeResolver(sam_resources) intrinsics_resolver = IntrinsicsResolver(parameter_values) - mappings_resolver = IntrinsicsResolver(template.get('Mappings', {}), - {FindInMapAction.intrinsic_name: FindInMapAction()}) + mappings_resolver = IntrinsicsResolver( + template.get("Mappings", {}), {FindInMapAction.intrinsic_name: FindInMapAction()} + ) deployment_preference_collection = DeploymentPreferenceCollection() supported_resource_refs = SupportedResourceReferences() document_errors = [] @@ -71,16 +73,16 @@ def translate(self, sam_template, parameter_values): for logical_id, resource_dict in self._get_resources_to_iterate(sam_template, macro_resolver): try: - macro = macro_resolver\ - .resolve_resource_type(resource_dict)\ - .from_dict(logical_id, resource_dict, sam_plugins=sam_plugins) + macro = macro_resolver.resolve_resource_type(resource_dict).from_dict( + logical_id, resource_dict, sam_plugins=sam_plugins + ) - kwargs = macro.resources_to_link(sam_template['Resources']) - kwargs['managed_policy_map'] = self.managed_policy_map - kwargs['intrinsics_resolver'] = intrinsics_resolver - kwargs['mappings_resolver'] = mappings_resolver - kwargs['deployment_preference_collection'] = deployment_preference_collection - kwargs['conditions'] = template.get('Conditions') + kwargs = macro.resources_to_link(sam_template["Resources"]) + kwargs["managed_policy_map"] = self.managed_policy_map + kwargs["intrinsics_resolver"] = intrinsics_resolver + kwargs["mappings_resolver"] = mappings_resolver + kwargs["deployment_preference_collection"] = deployment_preference_collection + kwargs["conditions"] = template.get("Conditions") translated = macro.to_cloudformation(**kwargs) @@ -90,24 +92,25 @@ def translate(self, sam_template, parameter_values): if logical_id != macro.logical_id: changed_logical_ids[logical_id] = macro.logical_id - del template['Resources'][logical_id] + del template["Resources"][logical_id] for resource in translated: - if verify_unique_logical_id(resource, sam_template['Resources']): - template['Resources'].update(resource.to_dict()) + if verify_unique_logical_id(resource, sam_template["Resources"]): + template["Resources"].update(resource.to_dict()) else: - document_errors.append(DuplicateLogicalIdException( - logical_id, resource.logical_id, resource.resource_type)) + document_errors.append( + DuplicateLogicalIdException(logical_id, resource.logical_id, resource.resource_type) + ) except (InvalidResourceException, InvalidEventException) as e: document_errors.append(e) if deployment_preference_collection.any_enabled(): - template['Resources'].update(deployment_preference_collection.codedeploy_application.to_dict()) + template["Resources"].update(deployment_preference_collection.codedeploy_application.to_dict()) if not deployment_preference_collection.can_skip_service_role(): - template['Resources'].update(deployment_preference_collection.codedeploy_iam_role.to_dict()) + template["Resources"].update(deployment_preference_collection.codedeploy_iam_role.to_dict()) for logical_id in deployment_preference_collection.enabled_logical_ids(): - template['Resources'].update(deployment_preference_collection.deployment_group(logical_id).to_dict()) + template["Resources"].update(deployment_preference_collection.deployment_group(logical_id).to_dict()) # Run the after-transform plugin target try: @@ -116,8 +119,8 @@ def translate(self, sam_template, parameter_values): document_errors.append(e) # Cleanup - if 'Transform' in template: - del template['Transform'] + if "Transform" in 
template: + del template["Transform"] if len(document_errors) == 0: template = intrinsics_resolver.resolve_sam_resource_id_refs(template, changed_logical_ids) @@ -197,12 +200,14 @@ def prepare_plugins(plugins, parameters={}): def make_implicit_rest_api_plugin(): # This is necessary to prevent a circular dependency on imports when loading package from samtranslator.plugins.api.implicit_rest_api_plugin import ImplicitRestApiPlugin + return ImplicitRestApiPlugin() def make_implicit_http_api_plugin(): # This is necessary to prevent a circular dependency on imports when loading package from samtranslator.plugins.api.implicit_http_api_plugin import ImplicitHttpApiPlugin + return ImplicitHttpApiPlugin() diff --git a/samtranslator/translator/verify_logical_id.py b/samtranslator/translator/verify_logical_id.py index 8557a4a3f4..da8b726f70 100644 --- a/samtranslator/translator/verify_logical_id.py +++ b/samtranslator/translator/verify_logical_id.py @@ -1,16 +1,16 @@ do_not_verify = { # type_after_transform: type_before_transform - 'AWS::Lambda::Function': 'AWS::Serverless::Function', - 'AWS::Lambda::LayerVersion': 'AWS::Serverless::LayerVersion', - 'AWS::ApiGateway::RestApi': 'AWS::Serverless::Api', - 'AWS::ApiGatewayV2::Api': 'AWS::Serverless::HttpApi', - 'AWS::S3::Bucket': 'AWS::S3::Bucket', - 'AWS::SNS::Topic': 'AWS::SNS::Topic', - 'AWS::DynamoDB::Table': 'AWS::Serverless::SimpleTable', - 'AWS::CloudFormation::Stack': 'AWS::Serverless::Application', - 'AWS::Cognito::UserPool': 'AWS::Cognito::UserPool', - 'AWS::ApiGateway::DomainName': 'AWS::ApiGateway::DomainName', - 'AWS::ApiGateway::BasePathMapping': 'AWS::ApiGateway::BasePathMapping' + "AWS::Lambda::Function": "AWS::Serverless::Function", + "AWS::Lambda::LayerVersion": "AWS::Serverless::LayerVersion", + "AWS::ApiGateway::RestApi": "AWS::Serverless::Api", + "AWS::ApiGatewayV2::Api": "AWS::Serverless::HttpApi", + "AWS::S3::Bucket": "AWS::S3::Bucket", + "AWS::SNS::Topic": "AWS::SNS::Topic", + "AWS::DynamoDB::Table": "AWS::Serverless::SimpleTable", + "AWS::CloudFormation::Stack": "AWS::Serverless::Application", + "AWS::Cognito::UserPool": "AWS::Cognito::UserPool", + "AWS::ApiGateway::DomainName": "AWS::ApiGateway::DomainName", + "AWS::ApiGateway::BasePathMapping": "AWS::ApiGateway::BasePathMapping", } @@ -18,7 +18,9 @@ def verify_unique_logical_id(resource, existing_resources): # new resource logicalid exists in the template before transform if resource.logical_id is not None and resource.logical_id in existing_resources: # new resource logicalid is in the do_not_resolve list - if resource.resource_type not in do_not_verify or existing_resources[resource.logical_id]['Type'] \ - not in do_not_verify[resource.resource_type]: + if ( + resource.resource_type not in do_not_verify + or existing_resources[resource.logical_id]["Type"] not in do_not_verify[resource.resource_type] + ): return False return True diff --git a/samtranslator/validator/validator.py b/samtranslator/validator/validator.py index 14d36cbd23..97c5257cc7 100644 --- a/samtranslator/validator/validator.py +++ b/samtranslator/validator/validator.py @@ -7,7 +7,6 @@ class SamTemplateValidator(object): - @staticmethod def validate(template_dict, schema=None): """ diff --git a/samtranslator/yaml_helper.py b/samtranslator/yaml_helper.py index 67c958b4d3..67fc33f09b 100644 --- a/samtranslator/yaml_helper.py +++ b/samtranslator/yaml_helper.py @@ -8,8 +8,7 @@ def yaml_parse(yamlstr): """Parse a yaml string""" - yaml.SafeLoader.add_multi_constructor( - "!", intrinsics_multi_constructor) + 
yaml.SafeLoader.add_multi_constructor("!", intrinsics_multi_constructor) return yaml.safe_load(yamlstr) diff --git a/setup.py b/setup.py index e90fe1be9e..24c6ff9759 100755 --- a/setup.py +++ b/setup.py @@ -28,8 +28,8 @@ def read(*filenames, **kwargs): - encoding = kwargs.get('encoding', 'utf-8') - sep = kwargs.get('sep', os.linesep) + encoding = kwargs.get("encoding", "utf-8") + sep = kwargs.get("sep", os.linesep) buf = [] for filename in filenames: with io.open(filename, encoding=encoding) as f: @@ -38,50 +38,46 @@ def read(*filenames, **kwargs): def read_version(): - content = read(os.path.join( - os.path.dirname(__file__), 'samtranslator', '__init__.py')) - return re.search(r"__version__ = '([^']+)'", content).group(1) + content = read(os.path.join(os.path.dirname(__file__), "samtranslator", "__init__.py")) + return re.search(r"__version__ = \"([^']+)\"", content).group(1) -def read_requirements(req='base.txt'): - content = read(os.path.join('requirements', req)) - return [line for line in content.split(os.linesep) - if not line.strip().startswith('#')] +def read_requirements(req="base.txt"): + content = read(os.path.join("requirements", req)) + return [line for line in content.split(os.linesep) if not line.strip().startswith("#")] setup( - name='aws-sam-translator', + name="aws-sam-translator", version=read_version(), - description='AWS SAM Translator is a library that transform SAM templates into AWS CloudFormation templates', - long_description=read('README.md'), - long_description_content_type='text/markdown', - author='Amazon Web Services', - author_email='aws-sam-developers@amazon.com', - url='https://github.com/awslabs/serverless-application-model', - license='Apache License 2.0', + description="AWS SAM Translator is a library that transform SAM templates into AWS CloudFormation templates", + long_description=read("README.md"), + long_description_content_type="text/markdown", + author="Amazon Web Services", + author_email="aws-sam-developers@amazon.com", + url="https://github.com/awslabs/serverless-application-model", + license="Apache License 2.0", # Exclude all but the code folders - packages=find_packages(exclude=('tests*', 'docs', 'examples', 'versions')), - install_requires=read_requirements('base.txt'), + packages=find_packages(exclude=("tests", "docs", "examples", "versions")), + install_requires=read_requirements("base.txt"), include_package_data=True, - extras_require={ - 'dev': read_requirements('dev.txt') - }, + extras_require={"dev": read_requirements("dev.txt")}, keywords="AWS SAM Serverless Application Model", classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Environment :: Other Environment', - 'Intended Audience :: Developers', - 'Intended Audience :: Information Technology', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Topic :: Internet', - 'Topic :: Software Development :: Build Tools', - 'Topic :: Utilities' - ] + "Development Status :: 4 - Beta", + "Environment :: Console", + "Environment :: Other Environment", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 
2.7", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Internet", + "Topic :: Software Development :: Build Tools", + "Topic :: Utilities", + ], ) diff --git a/tests/README b/tests/README deleted file mode 100644 index 5bf37b7190..0000000000 --- a/tests/README +++ /dev/null @@ -1 +0,0 @@ -See [Development Guide](../DEVELOPMENT_GUIDE.rst) for details on how to setup and run the tests diff --git a/tests/intrinsics/test_actions.py b/tests/intrinsics/test_actions.py index 81dad2abe2..8119a71b4f 100644 --- a/tests/intrinsics/test_actions.py +++ b/tests/intrinsics/test_actions.py @@ -4,8 +4,8 @@ from samtranslator.intrinsics.resource_refs import SupportedResourceReferences from samtranslator.model.exceptions import InvalidTemplateException, InvalidDocumentException -class TestAction(TestCase): +class TestAction(TestCase): def test_subclass_must_override_type(self): # Subclass must override the intrinsic_name @@ -32,9 +32,7 @@ def test_can_handle_input(self): class MyAction(Action): intrinsic_name = "foo" - input = { - "foo": ["something"] - } + input = {"foo": ["something"]} action = MyAction() self.assertTrue(action.can_handle(input)) @@ -43,9 +41,7 @@ def test_can_handle_invalid_type(self): class MyAction(Action): intrinsic_name = "foo" - input = { - "bar": "something" - } + input = {"bar": "something"} action = MyAction() self.assertFalse(action.can_handle(input)) @@ -62,10 +58,7 @@ class MyAction(Action): intrinsic_name = "foo" # Intrinsic functions can be only in dict of length 1 - input = { - "foo": "some value", - "bar": "some other value" - } + input = {"foo": "some value", "bar": "some other value"} action = MyAction() self.assertFalse(action.can_handle(input)) @@ -122,66 +115,43 @@ def test_parse_resource_references_not_string(self): self.assertEqual(expected, Action._parse_resource_reference(input)) -class TestRefCanResolveParameterRefs(TestCase): +class TestRefCanResolveParameterRefs(TestCase): def test_can_resolve_ref(self): - parameters = { - "key": "value" - } - input = { - "Ref": "key" - } + parameters = {"key": "value"} + input = {"Ref": "key"} ref = RefAction() self.assertEqual(parameters["key"], ref.resolve_parameter_refs(input, parameters)) def test_unknown_ref(self): - parameters = { - "key": "value" - } - input = { - "Ref": "someotherkey" - } - expected = { - "Ref": "someotherkey" - } + parameters = {"key": "value"} + input = {"Ref": "someotherkey"} + expected = {"Ref": "someotherkey"} ref = RefAction() self.assertEqual(expected, ref.resolve_parameter_refs(input, parameters)) def test_must_ignore_invalid_value(self): - parameters = { - "key": "value" - } - input = { - "Ref": ["invalid value"] - } - expected = { - "Ref": ["invalid value"] - } + parameters = {"key": "value"} + input = {"Ref": ["invalid value"]} + expected = {"Ref": ["invalid value"]} ref = RefAction() self.assertEqual(expected, ref.resolve_parameter_refs(input, parameters)) @patch.object(RefAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - parameters = { - "key": "value" - } - input = { - "Ref": "key" - } - expected = { - "Ref": "key" - } + parameters = {"key": "value"} + input = {"Ref": "key"} + expected = {"Ref": "key"} ref = RefAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input self.assertEqual(expected, ref.resolve_parameter_refs(input, parameters)) class TestRefCanResolveResourceRefs(TestCase): - def setUp(self): self.supported_resource_refs_mock = Mock() self.ref = RefAction() @@ -189,12 +159,8 @@ def setUp(self): @patch.object(RefAction, "_parse_resource_reference") def test_must_replace_refs(self, _parse_resource_reference_mock): resolved_value = "SomeOtherValue" - input = { - "Ref": "LogicalId.Property" - } - expected = { - "Ref": resolved_value - } + input = {"Ref": "LogicalId.Property"} + expected = {"Ref": resolved_value} _parse_resource_reference_mock.return_value = ("LogicalId", "Property") self.supported_resource_refs_mock.get.return_value = resolved_value @@ -206,12 +172,8 @@ def test_must_replace_refs(self, _parse_resource_reference_mock): @patch.object(RefAction, "_parse_resource_reference") def test_handle_unsupported_references(self, _parse_resource_reference_mock): - input = { - "Ref": "LogicalId.Property" - } - expected = { - "Ref": "LogicalId.Property" - } + input = {"Ref": "LogicalId.Property"} + expected = {"Ref": "LogicalId.Property"} _parse_resource_reference_mock.return_value = ("LogicalId", "Property") self.supported_resource_refs_mock.get.return_value = None @@ -223,12 +185,8 @@ def test_handle_unsupported_references(self, _parse_resource_reference_mock): @patch.object(RefAction, "_parse_resource_reference") def test_handle_unparsable_reference_value(self, _parse_resource_reference_mock): - input = { - "Ref": "some value" - } - expected = { - "Ref": "some value" - } + input = {"Ref": "some value"} + expected = {"Ref": "some value"} _parse_resource_reference_mock.return_value = (None, None) @@ -240,32 +198,23 @@ def test_handle_unparsable_reference_value(self, _parse_resource_reference_mock) @patch.object(RefAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - input = { - "Ref": "key" - } - expected = { - "Ref": "key" - } + input = {"Ref": "key"} + expected = {"Ref": "key"} ref = RefAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input self.assertEqual(expected, ref.resolve_resource_refs(input, self.supported_resource_refs_mock)) class TestRefCanResolveResourceIdRefs(TestCase): - def setUp(self): self.supported_resource_id_refs_mock = Mock() self.ref = RefAction() def test_must_replace_refs(self): resolved_value = "NewLogicalId" - input = { - "Ref": "LogicalId" - } - expected = { - "Ref": resolved_value - } + input = {"Ref": "LogicalId"} + expected = {"Ref": resolved_value} self.supported_resource_id_refs_mock.get.return_value = resolved_value output = self.ref.resolve_resource_id_refs(input, self.supported_resource_id_refs_mock) @@ -274,12 +223,8 @@ def test_must_replace_refs(self): self.supported_resource_id_refs_mock.get.assert_called_once_with("LogicalId") def test_handle_unsupported_references(self): - input = { - "Ref": "OtherLogicalId.Property" - } - expected = { - "Ref": "OtherLogicalId.Property" - } + input = {"Ref": "OtherLogicalId.Property"} + expected = {"Ref": "OtherLogicalId.Property"} self.supported_resource_id_refs_mock.get.return_value = None @@ -289,29 +234,19 @@ def test_handle_unsupported_references(self): @patch.object(RefAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - input = { - "Ref": "key" - } - expected = { - "Ref": "key" - } + input = {"Ref": "key"} + expected = {"Ref": "key"} ref = RefAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input self.assertEqual(expected, ref.resolve_resource_id_refs(input, self.supported_resource_id_refs_mock)) -class TestSubCanResolveParameterRefs(TestCase): +class TestSubCanResolveParameterRefs(TestCase): def test_must_resolve_string_value(self): - parameters = { - "key1": "value1" - } - input = { - "Fn::Sub": "Hello ${key1}" - } - expected = { - "Fn::Sub": "Hello value1" - } + parameters = {"key1": "value1"} + input = {"Fn::Sub": "Hello ${key1}"} + expected = {"Fn::Sub": "Hello value1"} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) @@ -319,16 +254,10 @@ def test_must_resolve_string_value(self): self.assertEqual(expected, result) def test_must_resolve_array_value(self): - parameters = { - "key1": "value1" - } - input = { - "Fn::Sub": ["Hello ${key1} ${a}", {"a":"b"}] - } + parameters = {"key1": "value1"} + input = {"Fn::Sub": ["Hello ${key1} ${a}", {"a": "b"}]} - expected = { - "Fn::Sub": ["Hello value1 ${a}", {"a": "b"}] - } + expected = {"Fn::Sub": ["Hello value1 ${a}", {"a": "b"}]} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) @@ -337,31 +266,18 @@ def test_must_resolve_array_value(self): @patch.object(SubAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - parameters = { - "key": "value" - } - input = { - "Fn::Sub": "${key}" - } - expected = { - "Fn::Sub": "${key}" - } + parameters = {"key": "value"} + input = {"Fn::Sub": "${key}"} + expected = {"Fn::Sub": "${key}"} sub = SubAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input self.assertEqual(expected, sub.resolve_parameter_refs(input, parameters)) def test_sub_all_refs_multiple_references(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": "hello ${key1} ${key2} ${key1}${key2} ${unknown} ${key1.attr}" - } - expected = { - "Fn::Sub": "hello value1 value2 value1value2 ${unknown} ${key1.attr}" - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": "hello ${key1} ${key2} ${key1}${key2} ${unknown} ${key1.attr}"} + expected = {"Fn::Sub": "hello value1 value2 value1value2 ${unknown} ${key1.attr}"} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) @@ -369,13 +285,8 @@ def test_sub_all_refs_multiple_references(self): self.assertEqual(expected, result) def test_sub_all_refs_with_literals(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": "hello ${key1} ${key2} ${!key1} ${!key2}" - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": "hello ${key1} ${key2} ${!key1} ${!key2}"} expected = { # ${! is the prefix for literal. These won't be substituted "Fn::Sub": "hello value1 value2 ${!key1} ${!key2}" @@ -387,16 +298,9 @@ def test_sub_all_refs_with_literals(self): self.assertEqual(expected, result) def test_sub_all_refs_with_list_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": ["key1", "key2"] - } - expected = { - "Fn::Sub": ["key1", "key2"] - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": ["key1", "key2"]} + expected = {"Fn::Sub": ["key1", "key2"]} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) @@ -404,16 +308,9 @@ def test_sub_all_refs_with_list_input(self): self.assertEqual(expected, result) def test_sub_all_refs_with_dict_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } - expected = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": {"a": "key1", "b": "key2"}} + expected = {"Fn::Sub": {"a": "key1", "b": "key2"}} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) @@ -421,24 +318,17 @@ def test_sub_all_refs_with_dict_input(self): self.assertEqual(expected, result) def test_sub_all_refs_with_pseudo_parameters(self): - parameters = { - "key1": "value1", - "AWS::Region": "ap-southeast-1" - } - input = { - "Fn::Sub": "hello ${AWS::Region} ${key1}" - } - expected = { - "Fn::Sub": "hello ap-southeast-1 value1" - } + parameters = {"key1": "value1", "AWS::Region": "ap-southeast-1"} + input = {"Fn::Sub": "hello ${AWS::Region} ${key1}"} + expected = {"Fn::Sub": "hello ap-southeast-1 value1"} sub = SubAction() result = sub.resolve_parameter_refs(input, parameters) self.assertEqual(expected, result) -class TestSubInternalMethods(TestCase): +class TestSubInternalMethods(TestCase): @patch.object(SubAction, "_sub_all_refs") def test_handle_sub_value_must_call_handler_on_string(self, sub_all_refs_mock): input = "sub string" @@ -467,8 +357,8 @@ def test_handle_sub_value_must_call_handler_on_array(self, sub_all_refs_mock): @patch.object(SubAction, "_sub_all_refs") def test_handle_sub_value_must_skip_no_string(self, sub_all_refs_mock): - input = [{"a":"b"}] - expected = [{"a":"b"}] + input = [{"a": "b"}] + expected = [{"a": "b"}] handler_mock = Mock() sub = SubAction() @@ -504,8 +394,8 @@ def test_must_skip_invalid_input_dict(self, 
sub_all_refs_mock): handler_mock.assert_not_called() sub_all_refs_mock.assert_not_called() -class TestSubCanResolveResourceRefs(TestCase): +class TestSubCanResolveResourceRefs(TestCase): def setUp(self): self.supported_resource_refs = SupportedResourceReferences() self.supported_resource_refs.add("id1", "prop1", "value1") @@ -517,12 +407,8 @@ def setUp(self): def test_must_resolve_string_value(self): - input = { - "Fn::Sub": self.input_sub_value - } - expected = { - "Fn::Sub": self.expected_output_sub_value - } + input = {"Fn::Sub": self.input_sub_value} + expected = {"Fn::Sub": self.expected_output_sub_value} sub = SubAction() result = sub.resolve_resource_refs(input, self.supported_resource_refs) @@ -530,13 +416,9 @@ def test_must_resolve_string_value(self): self.assertEqual(expected, result) def test_must_resolve_array_value(self): - input = { - "Fn::Sub": [self.input_sub_value, {"unknown":"a"}] - } + input = {"Fn::Sub": [self.input_sub_value, {"unknown": "a"}]} - expected = { - "Fn::Sub": [self.expected_output_sub_value, {"unknown": "a"}] - } + expected = {"Fn::Sub": [self.expected_output_sub_value, {"unknown": "a"}]} sub = SubAction() result = sub.resolve_resource_refs(input, self.supported_resource_refs) @@ -544,16 +426,9 @@ def test_must_resolve_array_value(self): self.assertEqual(expected, result) def test_sub_all_refs_with_list_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": ["key1", "key2"] - } - expected = { - "Fn::Sub": ["key1", "key2"] - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": ["key1", "key2"]} + expected = {"Fn::Sub": ["key1", "key2"]} sub = SubAction() result = sub.resolve_resource_refs(input, parameters) @@ -561,16 +436,9 @@ def test_sub_all_refs_with_list_input(self): self.assertEqual(expected, result) def test_sub_all_refs_with_dict_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } - expected = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": {"a": "key1", "b": "key2"}} + expected = {"Fn::Sub": {"a": "key1", "b": "key2"}} sub = SubAction() result = sub.resolve_resource_refs(input, parameters) @@ -579,39 +447,31 @@ def test_sub_all_refs_with_dict_input(self): @patch.object(SubAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - parameters = { - "key": "value" - } - input = { - "Fn::Sub": "${key}" - } - expected = { - "Fn::Sub": "${key}" - } + parameters = {"key": "value"} + input = {"Fn::Sub": "${key}"} + expected = {"Fn::Sub": "${key}"} sub = SubAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input self.assertEqual(expected, sub.resolve_resource_refs(input, parameters)) -class TestSubCanResolveResourceIdRefs(TestCase): +class TestSubCanResolveResourceIdRefs(TestCase): def setUp(self): self.supported_resource_id_refs = {} self.supported_resource_id_refs["id1"] = "newid1" self.supported_resource_id_refs["id2"] = "newid2" self.supported_resource_id_refs["id3"] = "newid3" - self.input_sub_value = "Hello ${id1} ${id2}${id3} ${id1.arn} ${id2.arn.name.foo} ${!id1.prop1} ${unknown} ${some.arn} World" + self.input_sub_value = ( + "Hello ${id1} ${id2}${id3} ${id1.arn} ${id2.arn.name.foo} ${!id1.prop1} ${unknown} ${some.arn} World" + ) self.expected_output_sub_value = "Hello ${newid1} ${newid2}${newid3} ${newid1.arn} ${newid2.arn.name.foo} ${!id1.prop1} ${unknown} ${some.arn} World" def test_must_resolve_string_value(self): - input = { - "Fn::Sub": self.input_sub_value - } - expected = { - "Fn::Sub": self.expected_output_sub_value - } + input = {"Fn::Sub": self.input_sub_value} + expected = {"Fn::Sub": self.expected_output_sub_value} sub = SubAction() result = sub.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -619,13 +479,9 @@ def test_must_resolve_string_value(self): self.assertEqual(expected, result) def test_must_resolve_array_value(self): - input = { - "Fn::Sub": [self.input_sub_value, {"unknown":"a"}] - } + input = {"Fn::Sub": [self.input_sub_value, {"unknown": "a"}]} - expected = { - "Fn::Sub": [self.expected_output_sub_value, {"unknown": "a"}] - } + expected = {"Fn::Sub": [self.expected_output_sub_value, {"unknown": "a"}]} sub = SubAction() result = sub.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -633,16 +489,9 @@ def test_must_resolve_array_value(self): self.assertEqual(expected, result) def test_sub_all_refs_with_list_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": ["key1", "key2"] - } - expected = { - "Fn::Sub": ["key1", "key2"] - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": ["key1", "key2"]} + expected = {"Fn::Sub": ["key1", "key2"]} sub = SubAction() result = sub.resolve_resource_id_refs(input, parameters) @@ -650,16 +499,9 @@ def test_sub_all_refs_with_list_input(self): self.assertEqual(expected, result) def test_sub_all_refs_with_dict_input(self): - parameters = { - "key1": "value1", - "key2": "value2" - } - input = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } - expected = { - "Fn::Sub": {"a": "key1", "b": "key2"} - } + parameters = {"key1": "value1", "key2": "value2"} + input = {"Fn::Sub": {"a": "key1", "b": "key2"}} + expected = {"Fn::Sub": {"a": "key1", "b": "key2"}} sub = SubAction() result = sub.resolve_resource_id_refs(input, parameters) @@ -668,23 +510,16 @@ def test_sub_all_refs_with_dict_input(self): @patch.object(SubAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - parameters = { - "key": "value" - } - input = { - "Fn::Sub": "${key}" - } - expected = { - "Fn::Sub": "${key}" - } + parameters = {"key": "value"} + input = {"Fn::Sub": "${key}"} + expected = {"Fn::Sub": "${key}"} sub = SubAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input self.assertEqual(expected, sub.resolve_resource_id_refs(input, parameters)) class TestGetAttCanResolveParameterRefs(TestCase): - def test_must_do_nothing(self): # Parameter references are not resolved by GetAtt input = "foo" @@ -696,19 +531,14 @@ def test_must_do_nothing(self): class TestGetAttCanResolveResourceRefs(TestCase): - def setUp(self): self.supported_resource_refs = SupportedResourceReferences() self.supported_resource_refs.add("id1", "prop1", "value1") def test_must_resolve_simple_refs(self): - input = { - "Fn::GetAtt": ["id1.prop1", "Arn"] - } + input = {"Fn::GetAtt": ["id1.prop1", "Arn"]} - expected = { - "Fn::GetAtt": ["value1", "Arn"] - } + expected = {"Fn::GetAtt": ["value1", "Arn"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -716,13 +546,9 @@ def test_must_resolve_simple_refs(self): self.assertEqual(expected, output) def test_must_resolve_refs_with_many_attributes(self): - input = { - "Fn::GetAtt": ["id1.prop1", "Arn1", "Arn2", "Arn3"] - } + input = {"Fn::GetAtt": ["id1.prop1", "Arn1", "Arn2", "Arn3"]} - expected = { - "Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"] - } + expected = {"Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -735,9 +561,7 @@ def test_must_resolve_with_splitted_resource_refs(self): "Fn::GetAtt": ["id1", "prop1", "Arn1", "Arn2", "Arn3"] } - expected = { - "Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"] - } + expected = {"Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -750,9 +574,7 @@ def test_must_ignore_refs_without_attributes(self): "Fn::GetAtt": ["id1", "prop1"] } - expected = { - "Fn::GetAtt": ["value1"] - } + expected = {"Fn::GetAtt": ["value1"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -765,9 +587,7 @@ def test_must_ignore_refs_without_attributes_in_concatenated_form(self): "Fn::GetAtt": ["id1.prop1"] } - expected = { - "Fn::GetAtt": ["id1.prop1"] - } + expected = {"Fn::GetAtt": ["id1.prop1"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -780,9 +600,7 @@ def test_must_ignore_invalid_value_array(self): "Fn::GetAtt": ["id1"] } - expected = { - "Fn::GetAtt": ["id1"] - } + expected = {"Fn::GetAtt": ["id1"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -795,9 +613,7 @@ def test_must_ignore_invalid_value_type(self): "Fn::GetAtt": {"a": "b"} } - expected = { - "Fn::GetAtt": {"a": "b"} - } + expected = {"Fn::GetAtt": {"a": "b"}} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -805,12 +621,8 @@ def test_must_ignore_invalid_value_type(self): self.assertEqual(expected, output) def test_must_ignore_missing_properties_with_dot_after(self): - input = { - "Fn::GetAtt": ["id1.", "foo"] - } - expected = { - "Fn::GetAtt": ["id1.", "foo"] - } + input = {"Fn::GetAtt": ["id1.", "foo"]} + expected = {"Fn::GetAtt": ["id1.", "foo"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -818,12 +630,8 @@ def test_must_ignore_missing_properties_with_dot_after(self): self.assertEqual(expected, output) def test_must_ignore_missing_properties_with_dot_before(self): - input = { - "Fn::GetAtt": [".id1", "foo"] 
- } - expected = { - "Fn::GetAtt": [".id1", "foo"] - } + input = {"Fn::GetAtt": [".id1", "foo"]} + expected = {"Fn::GetAtt": [".id1", "foo"]} getatt = GetAttAction() output = getatt.resolve_resource_refs(input, self.supported_resource_refs) @@ -832,32 +640,23 @@ def test_must_ignore_missing_properties_with_dot_before(self): @patch.object(GetAttAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - input = { - "Fn::GetAtt": ["id1", "prop1"] - } - expected = { - "Fn::GetAtt": ["id1", "prop1"] - } + input = {"Fn::GetAtt": ["id1", "prop1"]} + expected = {"Fn::GetAtt": ["id1", "prop1"]} getatt = GetAttAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input self.assertEqual(expected, getatt.resolve_resource_refs(input, self.supported_resource_refs)) class TestGetAttCanResolveResourceIdRefs(TestCase): - def setUp(self): self.supported_resource_id_refs = {} - self.supported_resource_id_refs['id1'] = "value1" + self.supported_resource_id_refs["id1"] = "value1" def test_must_resolve_simple_refs(self): - input = { - "Fn::GetAtt": ["id1", "Arn"] - } + input = {"Fn::GetAtt": ["id1", "Arn"]} - expected = { - "Fn::GetAtt": ["value1", "Arn"] - } + expected = {"Fn::GetAtt": ["value1", "Arn"]} getatt = GetAttAction() output = getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -865,13 +664,9 @@ def test_must_resolve_simple_refs(self): self.assertEqual(expected, output) def test_must_resolve_refs_with_many_attributes(self): - input = { - "Fn::GetAtt": ["id1", "Arn1", "Arn2", "Arn3"] - } + input = {"Fn::GetAtt": ["id1", "Arn1", "Arn2", "Arn3"]} - expected = { - "Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"] - } + expected = {"Fn::GetAtt": ["value1", "Arn1", "Arn2", "Arn3"]} getatt = GetAttAction() output = getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -884,9 +679,7 @@ def test_must_ignore_invalid_value_array(self): "Fn::GetAtt": ["id1"] } - expected = { - "Fn::GetAtt": ["id1"] - } + expected = {"Fn::GetAtt": ["id1"]} getatt = GetAttAction() output = getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -899,9 +692,7 @@ def test_must_ignore_invalid_value_type(self): "Fn::GetAtt": {"a": "b"} } - expected = { - "Fn::GetAtt": {"a": "b"} - } + expected = {"Fn::GetAtt": {"a": "b"}} getatt = GetAttAction() output = getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -909,12 +700,8 @@ def test_must_ignore_invalid_value_type(self): self.assertEqual(expected, output) def test_must_ignore_missing_properties_with_dot_before(self): - input = { - "Fn::GetAtt": [".id1", "foo"] - } - expected = { - "Fn::GetAtt": [".id1", "foo"] - } + input = {"Fn::GetAtt": [".id1", "foo"]} + expected = {"Fn::GetAtt": [".id1", "foo"]} getatt = GetAttAction() output = getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs) @@ -923,210 +710,98 @@ def test_must_ignore_missing_properties_with_dot_before(self): @patch.object(GetAttAction, "can_handle") def test_return_value_if_cannot_handle(self, can_handle_mock): - input = { - "Fn::GetAtt": ["id1", "Arn"] - } - expected = { - "Fn::GetAtt": ["id1", "Arn"] - } + input = {"Fn::GetAtt": ["id1", "Arn"]} + expected = {"Fn::GetAtt": ["id1", "Arn"]} getatt = GetAttAction() - can_handle_mock.return_value = False # Simulate failure to handle the input. 
Result should be same as input + can_handle_mock.return_value = False # Simulate failure to handle the input. Result should be same as input self.assertEqual(expected, getatt.resolve_resource_id_refs(input, self.supported_resource_id_refs)) class TestFindInMapCanResolveParameterRefs(TestCase): - def setUp(self): self.ref = FindInMapAction() @patch.object(FindInMapAction, "can_handle") def test_cannot_handle(self, can_handle_mock): - input = { - "Fn::FindInMap": ["a", "b", "c"] - } + input = {"Fn::FindInMap": ["a", "b", "c"]} can_handle_mock.return_value = False output = self.ref.resolve_parameter_refs(input, {}) self.assertEqual(input, output) def test_value_not_list(self): - input = { - "Fn::FindInMap": "a string" - } + input = {"Fn::FindInMap": "a string"} with self.assertRaises(InvalidDocumentException): self.ref.resolve_parameter_refs(input, {}) def test_value_not_list_of_length_three(self): - input = { - "Fn::FindInMap": ["a string"] - } + input = {"Fn::FindInMap": ["a string"]} with self.assertRaises(InvalidDocumentException): self.ref.resolve_parameter_refs(input, {}) def test_mapping_not_string(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": [["MapA"], "TopKey2", "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": [["MapA"], "TopKey2", "SecondKey1"]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def test_top_level_key_not_string(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapA", ["TopKey2"], "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapA", ["TopKey2"], "SecondKey1"]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def test_second_level_key_not_string(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapA", "TopKey1", ["SecondKey2"]] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapA", "TopKey1", ["SecondKey2"]]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def test_mapping_not_found(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapB", "TopKey2", "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapB", "TopKey2", "SecondKey1"]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def test_top_level_key_not_found(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapA", "TopKey3", "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapA", "TopKey3", "SecondKey1"]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def 
test_second_level_key_not_found(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapA", "TopKey1", "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapA", "TopKey1", "SecondKey1"]} output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(input, output) def test_one_level_find_in_mappings(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "value4" - } - } - } - input = { - "Fn::FindInMap": ["MapA", "TopKey2", "SecondKey1"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "value4"}}} + input = {"Fn::FindInMap": ["MapA", "TopKey2", "SecondKey1"]} expected = "value4" output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(expected, output) def test_nested_find_in_mappings(self): - mappings = { - "MapA":{ - "TopKey1": { - "SecondKey2": "value3" - }, - "TopKey2": { - "SecondKey1": "TopKey1" - } - } - } - input = { - "Fn::FindInMap": ["MapA", {"Fn::FindInMap": ["MapA", "TopKey2", "SecondKey1"]}, "SecondKey2"] - } + mappings = {"MapA": {"TopKey1": {"SecondKey2": "value3"}, "TopKey2": {"SecondKey1": "TopKey1"}}} + input = {"Fn::FindInMap": ["MapA", {"Fn::FindInMap": ["MapA", "TopKey2", "SecondKey1"]}, "SecondKey2"]} expected = "value3" output = self.ref.resolve_parameter_refs(input, mappings) self.assertEqual(expected, output) def test_nested_find_in_multiple_mappings(self): - mappings = { - "MapA":{ - "ATopKey1": { - "ASecondKey2": "value3" - } - }, - "MapB": { - "BTopKey1": { - "BSecondKey2": "ATopKey1" - } - } - } - input = { - "Fn::FindInMap": ["MapA", {"Fn::FindInMap": ["MapB", "BTopKey1", "BSecondKey2"]}, "ASecondKey2"] - } + mappings = {"MapA": {"ATopKey1": {"ASecondKey2": "value3"}}, "MapB": {"BTopKey1": {"BSecondKey2": "ATopKey1"}}} + input = {"Fn::FindInMap": ["MapA", {"Fn::FindInMap": ["MapB", "BTopKey1", "BSecondKey2"]}, "ASecondKey2"]} expected = "value3" output = self.ref.resolve_parameter_refs(input, mappings) diff --git a/tests/intrinsics/test_resolver.py b/tests/intrinsics/test_resolver.py index b41ede2ca1..29d2be39b0 100644 --- a/tests/intrinsics/test_resolver.py +++ b/tests/intrinsics/test_resolver.py @@ -3,62 +3,32 @@ from samtranslator.intrinsics.resolver import IntrinsicsResolver from samtranslator.intrinsics.actions import Action -class TestParameterReferenceResolution(TestCase): +class TestParameterReferenceResolution(TestCase): def setUp(self): - self.parameter_values = { - "param1": "value1", - "param2": "value2", - "param3": "value3" - } + self.parameter_values = {"param1": "value1", "param2": "value2", "param3": "value3"} self.resolver = IntrinsicsResolver(self.parameter_values) def test_must_resolve_top_level_direct_refs(self): - input = { - "key1": { - "Ref": "param1" - }, - "key2": { - "Ref": "param2" - }, - "key3": { - "a": "b" - } - } + input = {"key1": {"Ref": "param1"}, "key2": {"Ref": "param2"}, "key3": {"a": "b"}} expected = { "key1": self.parameter_values["param1"], "key2": self.parameter_values["param2"], - "key3": { - "a": "b" - } + "key3": {"a": "b"}, } output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) def test_must_resolve_nested_refs(self): - input = { - "key1": { - "sub1": { - "sub2": { - "sub3": { - "Ref": "param1" - }, - "list": [1, "b", {"Ref": "param2"}] - } - } - } - } + 
input = {"key1": {"sub1": {"sub2": {"sub3": {"Ref": "param1"}, "list": [1, "b", {"Ref": "param2"}]}}}} expected = { "key1": { "sub1": { - "sub2": { - "sub3": self.parameter_values["param1"], - "list": [1, "b", self.parameter_values["param2"]] - } + "sub2": {"sub3": self.parameter_values["param1"], "list": [1, "b", self.parameter_values["param2"]]} } } } @@ -81,73 +51,41 @@ def test_must_resolve_array_refs(self): self.assertEqual(output, expected) def test_must_skip_unknown_refs(self): - input = { - "key1": { - "Ref": "someresource" - }, - "key2": { - "Ref": "param1" - } - } + input = {"key1": {"Ref": "someresource"}, "key2": {"Ref": "param1"}} - expected = { - "key1": { - "Ref": "someresource" - }, - "key2": self.parameter_values["param1"] - } + expected = {"key1": {"Ref": "someresource"}, "key2": self.parameter_values["param1"]} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) def test_must_resolve_inside_sub_strings(self): - input = { - "Fn::Sub": "prefix ${param1} ${param2} ${param3} ${param1} suffix" - } + input = {"Fn::Sub": "prefix ${param1} ${param2} ${param3} ${param1} suffix"} - expected = { - "Fn::Sub": "prefix value1 value2 value3 value1 suffix" - } + expected = {"Fn::Sub": "prefix value1 value2 value3 value1 suffix"} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) def test_must_skip_over_sub_literals(self): - input = { - "Fn::Sub": "prefix ${!MustNotBeReplaced} suffix" - } + input = {"Fn::Sub": "prefix ${!MustNotBeReplaced} suffix"} - expected = { - "Fn::Sub": "prefix ${!MustNotBeReplaced} suffix" - } + expected = {"Fn::Sub": "prefix ${!MustNotBeReplaced} suffix"} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) def test_must_resolve_refs_inside_other_intrinsics(self): - input = { - "key1": { - "Fn::Join": ["-", [{"Ref": "param1"}, "some other value"]] - } - } + input = {"key1": {"Fn::Join": ["-", [{"Ref": "param1"}, "some other value"]]}} - expected = { - "key1": { - "Fn::Join": ["-", [self.parameter_values["param1"], "some other value"]] - } - } + expected = {"key1": {"Fn::Join": ["-", [self.parameter_values["param1"], "some other value"]]}} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) def test_skip_invalid_values_for_ref(self): - input = { - "Ref": ["ref cannot have list value"] - } + input = {"Ref": ["ref cannot have list value"]} - expected = { - "Ref": ["ref cannot have list value"] - } + expected = {"Ref": ["ref cannot have list value"]} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) @@ -157,9 +95,7 @@ def test_skip_invalid_values_for_sub(self): "Fn::Sub": [{"a": "b"}] } - expected = { - "Fn::Sub": [{"a": "b"}] - } + expected = {"Fn::Sub": [{"a": "b"}]} output = self.resolver.resolve_parameter_refs(input) self.assertEqual(output, expected) @@ -170,7 +106,7 @@ def test_throw_on_empty_parameters(self): def test_throw_on_non_dict_parameters(self): with self.assertRaises(TypeError): - IntrinsicsResolver([1,2,3]).resolve_parameter_refs({}) + IntrinsicsResolver([1, 2, 3]).resolve_parameter_refs({}) def test_short_circuit_on_empty_parameters(self): resolver = IntrinsicsResolver({}) @@ -181,8 +117,8 @@ def test_short_circuit_on_empty_parameters(self): self.assertEqual(resolver.resolve_parameter_refs(input), expected) resolver._try_resolve_parameter_refs.assert_not_called() -class TestResourceReferenceResolution(TestCase): +class TestResourceReferenceResolution(TestCase): def 
setUp(self): self.resolver = IntrinsicsResolver({}) @@ -197,9 +133,7 @@ def test_resolve_sam_resource_refs(self, traverse_mock, try_resolve_mock): def test_resolve_sam_resource_refs_on_supported_intrinsic(self): action_mock = Mock() - self.resolver.supported_intrinsics = { - "foo": action_mock - } + self.resolver.supported_intrinsics = {"foo": action_mock} input = {"foo": "bar"} supported_refs = Mock() @@ -208,9 +142,7 @@ def test_resolve_sam_resource_refs_on_supported_intrinsic(self): def test_resolve_sam_resource_refs_on_unknown_intrinsic(self): action_mock = Mock() - self.resolver.supported_intrinsics = { - "foo": action_mock - } + self.resolver.supported_intrinsics = {"foo": action_mock} input = {"a": "b"} expected = {"a": "b"} supported_refs = Mock() @@ -228,6 +160,7 @@ def test_short_circuit_on_empty_parameters(self): self.assertEqual(resolver.resolve_sam_resource_refs(input, {}), expected) resolver._try_resolve_sam_resource_refs.assert_not_called() + class TestSupportedIntrinsics(TestCase): def test_by_default_all_intrinsics_must_be_supported(self): # Just make sure we never remove support for some intrinsic @@ -244,25 +177,17 @@ class SomeAction(Action): intrinsic_name = "IntrinsicName" action = SomeAction() - supported_intrinsics = { - "ThisCanBeAnyIntrinsicName": action - } + supported_intrinsics = {"ThisCanBeAnyIntrinsicName": action} resolver = IntrinsicsResolver({}, supported_intrinsics) - self.assertEqual(resolver.supported_intrinsics, { - "ThisCanBeAnyIntrinsicName": action - }) + self.assertEqual(resolver.supported_intrinsics, {"ThisCanBeAnyIntrinsicName": action}) def test_configure_supported_intrinsics_must_error_for_non_action_value(self): - class SomeAction(Action): intrinsic_name = "Foo" # All intrinsics must have a value to be subclass of "Action" - supported_intrinsics = { - "A": "B", - "Foo": SomeAction() - } + supported_intrinsics = {"A": "B", "Foo": SomeAction()} with self.assertRaises(TypeError): IntrinsicsResolver({}, supported_intrinsics) @@ -274,4 +199,4 @@ def test_configure_supported_intrinsics_must_error_for_none_input(self): def test_configure_supported_intrinsics_must_error_for_non_dict_input(self): with self.assertRaises(TypeError): - IntrinsicsResolver({}, [1,2,3]) + IntrinsicsResolver({}, [1, 2, 3]) diff --git a/tests/intrinsics/test_resource_refs.py b/tests/intrinsics/test_resource_refs.py index 367b58d6c2..bd92d55164 100644 --- a/tests/intrinsics/test_resource_refs.py +++ b/tests/intrinsics/test_resource_refs.py @@ -1,8 +1,8 @@ from unittest import TestCase from samtranslator.intrinsics.resource_refs import SupportedResourceReferences -class TestSupportedResourceReferences(TestCase): +class TestSupportedResourceReferences(TestCase): def test_add_multiple_properties_to_one_logicalId(self): resource_refs = SupportedResourceReferences() @@ -11,11 +11,7 @@ def test_add_multiple_properties_to_one_logicalId(self): resource_refs.add("logicalId", "property2", "value2") resource_refs.add("logicalId", "property3", "value3") - expected = { - "property1": "value1", - "property2": "value2", - "property3": "value3" - } + expected = {"property1": "value1", "property2": "value2", "property3": "value3"} self.assertEqual(expected, resource_refs.get_all("logicalId")) @@ -84,7 +80,6 @@ def test_get_on_non_existent_property(self): self.assertEqual(None, resource_refs.get("logicalId1", "SomeProperty")) self.assertEqual(None, resource_refs.get("SomeLogicalId", "property1")) - def test_len_single_resource(self): resource_refs = SupportedResourceReferences() diff --git 
a/tests/model/api/test_htttp_api_generator.py b/tests/model/api/test_htttp_api_generator.py index b1b4ccb54d..cee1621610 100644 --- a/tests/model/api/test_htttp_api_generator.py +++ b/tests/model/api/test_htttp_api_generator.py @@ -9,24 +9,24 @@ class TestHttpApiGenerator(TestCase): kwargs = { - 'logical_id': "HttpApiId", - 'stage_variables': None, - 'depends_on': None, - 'definition_body': None, - 'definition_uri': None, - 'stage_name': None, - 'tags': None, - 'auth': None, - 'access_log_settings': None, - 'resource_attributes': None, - 'passthrough_resource_attributes': None + "logical_id": "HttpApiId", + "stage_variables": None, + "depends_on": None, + "definition_body": None, + "definition_uri": None, + "stage_name": None, + "tags": None, + "auth": None, + "access_log_settings": None, + "resource_attributes": None, + "passthrough_resource_attributes": None, } authorizers = { "Authorizers": { "OAuth2": { "AuthorizationScopes": ["scope"], "JwtConfiguration": {"config": "value"}, - "IdentitySource": "https://example.com" + "IdentitySource": "https://example.com", } } } diff --git a/tests/model/eventsources/test_api_event_source.py b/tests/model/eventsources/test_api_event_source.py index cf7f0c9681..040e0aba0d 100644 --- a/tests/model/eventsources/test_api_event_source.py +++ b/tests/model/eventsources/test_api_event_source.py @@ -4,8 +4,8 @@ from samtranslator.model.eventsources.push import Api from samtranslator.model.lambda_ import LambdaFunction, LambdaPermission -class ApiEventSource(TestCase): +class ApiEventSource(TestCase): def setUp(self): self.logical_id = "Api" @@ -64,7 +64,9 @@ def test_get_permission_with_path_parameter(self): except AttributeError: self.fail("Permission class isn't valid") - self.assertEqual(arn, "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${__ApiId__}/${__Stage__}/GET/foo/*/bar") + self.assertEqual( + arn, "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${__ApiId__}/${__Stage__}/GET/foo/*/bar" + ) @patch("boto3.session.Session.region_name", "eu-west-2") def test_get_permission_with_proxy_resource(self): @@ -79,7 +81,9 @@ def test_get_permission_with_proxy_resource(self): except AttributeError: self.fail("Permission class isn't valid") - self.assertEqual(arn, "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${__ApiId__}/${__Stage__}/GET/foo/*") + self.assertEqual( + arn, "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${__ApiId__}/${__Stage__}/GET/foo/*" + ) @patch("boto3.session.Session.region_name", "eu-west-2") def test_get_permission_with_just_slash(self): @@ -97,10 +101,7 @@ def test_get_permission_with_just_slash(self): self.assertEqual(arn, "arn:aws:execute-api:${AWS::Region}:${AWS::AccountId}:${__ApiId__}/${__Stage__}/GET/") def _extract_path_from_arn(self, logical_id, perm): - arn = perm.to_dict().get(logical_id, {}) \ - .get("Properties", {}) \ - .get("SourceArn", {}) \ - .get("Fn::Sub", [])[0] + arn = perm.to_dict().get(logical_id, {}).get("Properties", {}).get("SourceArn", {}).get("Fn::Sub", [])[0] if arn is None: raise AttributeError("Arn not found") diff --git a/tests/model/eventsources/test_cloudwatchlogs_event_source.py b/tests/model/eventsources/test_cloudwatchlogs_event_source.py index 623824d0eb..4f8f8b7079 100644 --- a/tests/model/eventsources/test_cloudwatchlogs_event_source.py +++ b/tests/model/eventsources/test_cloudwatchlogs_event_source.py @@ -2,48 +2,48 @@ from unittest import TestCase from samtranslator.model.eventsources.cloudwatchlogs import CloudWatchLogs -class CloudWatchLogsEventSource(TestCase): 
+class CloudWatchLogsEventSource(TestCase): def setUp(self): - self.logical_id = 'LogProcessor' + self.logical_id = "LogProcessor" self.cloudwatch_logs_event_source = CloudWatchLogs(self.logical_id) - self.cloudwatch_logs_event_source.LogGroupName = 'MyLogGroup' - self.cloudwatch_logs_event_source.FilterPattern = 'Fizbo' + self.cloudwatch_logs_event_source.LogGroupName = "MyLogGroup" + self.cloudwatch_logs_event_source.FilterPattern = "Fizbo" self.function = Mock() self.function.get_runtime_attr = Mock() - self.function.get_runtime_attr.return_value = 'arn:aws:mock' + self.function.get_runtime_attr.return_value = "arn:aws:mock" self.function.resource_attributes = {} self.function.get_passthrough_resource_attributes = Mock() self.function.get_passthrough_resource_attributes.return_value = {} self.permission = Mock() - self.permission.logical_id = 'LogProcessorPermission' + self.permission.logical_id = "LogProcessorPermission" - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_get_source_arn(self): source_arn = self.cloudwatch_logs_event_source.get_source_arn() - expected_source_arn = {'Fn::Sub': [ - 'arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:${__LogGroupName__}:*', {'__LogGroupName__': 'MyLogGroup'}]} + expected_source_arn = { + "Fn::Sub": [ + "arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:${__LogGroupName__}:*", + {"__LogGroupName__": "MyLogGroup"}, + ] + } self.assertEqual(source_arn, expected_source_arn) def test_get_subscription_filter(self): - subscription_filter = self.cloudwatch_logs_event_source.get_subscription_filter( - self.function, self.permission) - self.assertEqual(subscription_filter.LogGroupName, 'MyLogGroup') - self.assertEqual(subscription_filter.FilterPattern, 'Fizbo') - self.assertEqual(subscription_filter.DestinationArn, 'arn:aws:mock') + subscription_filter = self.cloudwatch_logs_event_source.get_subscription_filter(self.function, self.permission) + self.assertEqual(subscription_filter.LogGroupName, "MyLogGroup") + self.assertEqual(subscription_filter.FilterPattern, "Fizbo") + self.assertEqual(subscription_filter.DestinationArn, "arn:aws:mock") - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_to_cloudformation_returns_permission_and_subscription_filter_resources(self): - resources = self.cloudwatch_logs_event_source.to_cloudformation( - function=self.function) + resources = self.cloudwatch_logs_event_source.to_cloudformation(function=self.function) self.assertEqual(len(resources), 2) - self.assertEqual(resources[0].resource_type, - 'AWS::Lambda::Permission') - self.assertEqual(resources[1].resource_type, - 'AWS::Logs::SubscriptionFilter') + self.assertEqual(resources[0].resource_type, "AWS::Lambda::Permission") + self.assertEqual(resources[1].resource_type, "AWS::Logs::SubscriptionFilter") def test_to_cloudformation_throws_when_no_function(self): self.assertRaises(TypeError, self.cloudwatch_logs_event_source.to_cloudformation) diff --git a/tests/model/eventsources/test_sns_event_source.py b/tests/model/eventsources/test_sns_event_source.py index f877fe9bd7..711f2ce221 100644 --- a/tests/model/eventsources/test_sns_event_source.py +++ b/tests/model/eventsources/test_sns_event_source.py @@ -4,61 +4,53 @@ class SnsEventSource(TestCase): - def setUp(self): - self.logical_id = 'NotificationsProcessor' + self.logical_id = "NotificationsProcessor" self.sns_event_source = 
SNS(self.logical_id) - self.sns_event_source.Topic = 'arn:aws:sns:MyTopic' + self.sns_event_source.Topic = "arn:aws:sns:MyTopic" self.function = Mock() self.function.get_runtime_attr = Mock() - self.function.get_runtime_attr.return_value = 'arn:aws:lambda:mock' + self.function.get_runtime_attr.return_value = "arn:aws:lambda:mock" self.function.resource_attributes = {} self.function.get_passthrough_resource_attributes = Mock() self.function.get_passthrough_resource_attributes.return_value = {} def test_to_cloudformation_returns_permission_and_subscription_resources(self): - resources = self.sns_event_source.to_cloudformation( - function=self.function) + resources = self.sns_event_source.to_cloudformation(function=self.function) self.assertEqual(len(resources), 2) - self.assertEqual(resources[0].resource_type, - 'AWS::Lambda::Permission') - self.assertEqual(resources[1].resource_type, - 'AWS::SNS::Subscription') + self.assertEqual(resources[0].resource_type, "AWS::Lambda::Permission") + self.assertEqual(resources[1].resource_type, "AWS::SNS::Subscription") subscription = resources[1] - self.assertEqual(subscription.TopicArn, 'arn:aws:sns:MyTopic') - self.assertEqual(subscription.Protocol, 'lambda') - self.assertEqual(subscription.Endpoint, 'arn:aws:lambda:mock') + self.assertEqual(subscription.TopicArn, "arn:aws:sns:MyTopic") + self.assertEqual(subscription.Protocol, "lambda") + self.assertEqual(subscription.Endpoint, "arn:aws:lambda:mock") self.assertIsNone(subscription.Region) self.assertIsNone(subscription.FilterPolicy) def test_to_cloudformation_passes_the_region(self): - region = 'us-west-2' + region = "us-west-2" self.sns_event_source.Region = region - resources = self.sns_event_source.to_cloudformation( - function=self.function) + resources = self.sns_event_source.to_cloudformation(function=self.function) self.assertEqual(len(resources), 2) - self.assertEqual(resources[1].resource_type, - 'AWS::SNS::Subscription') + self.assertEqual(resources[1].resource_type, "AWS::SNS::Subscription") subscription = resources[1] self.assertEqual(subscription.Region, region) def test_to_cloudformation_passes_the_filter_policy(self): filterPolicy = { - 'attribute1': ['value1'], - 'attribute2': ['value2', 'value3'], - 'attribute3': {'numeric': ['>=', '100']} + "attribute1": ["value1"], + "attribute2": ["value2", "value3"], + "attribute3": {"numeric": [">=", "100"]}, } self.sns_event_source.FilterPolicy = filterPolicy - resources = self.sns_event_source.to_cloudformation( - function=self.function) + resources = self.sns_event_source.to_cloudformation(function=self.function) self.assertEqual(len(resources), 2) - self.assertEqual(resources[1].resource_type, - 'AWS::SNS::Subscription') + self.assertEqual(resources[1].resource_type, "AWS::SNS::Subscription") subscription = resources[1] self.assertEqual(subscription.FilterPolicy, filterPolicy) @@ -66,9 +58,7 @@ def test_to_cloudformation_throws_when_no_function(self): self.assertRaises(TypeError, self.sns_event_source.to_cloudformation) def test_to_cloudformation_throws_when_queue_url_or_queue_arn_not_given(self): - sqsSubscription = { - 'BatchSize': 5 - } + sqsSubscription = {"BatchSize": 5} self.sns_event_source.SqsSubscription = sqsSubscription self.assertRaises(TypeError, self.sns_event_source.to_cloudformation) @@ -76,17 +66,14 @@ def test_to_cloudformation_when_sqs_subscription_disable(self): sqsSubscription = False self.sns_event_source.SqsSubscription = sqsSubscription - resources = self.sns_event_source.to_cloudformation( - function=self.function) + 
resources = self.sns_event_source.to_cloudformation(function=self.function) self.assertEqual(len(resources), 2) - self.assertEqual(resources[0].resource_type, - 'AWS::Lambda::Permission') - self.assertEqual(resources[1].resource_type, - 'AWS::SNS::Subscription') + self.assertEqual(resources[0].resource_type, "AWS::Lambda::Permission") + self.assertEqual(resources[1].resource_type, "AWS::SNS::Subscription") subscription = resources[1] - self.assertEqual(subscription.TopicArn, 'arn:aws:sns:MyTopic') - self.assertEqual(subscription.Protocol, 'lambda') - self.assertEqual(subscription.Endpoint, 'arn:aws:lambda:mock') + self.assertEqual(subscription.TopicArn, "arn:aws:sns:MyTopic") + self.assertEqual(subscription.Protocol, "lambda") + self.assertEqual(subscription.Endpoint, "arn:aws:lambda:mock") self.assertIsNone(subscription.Region) self.assertIsNone(subscription.FilterPolicy) diff --git a/tests/model/tags/test_resource_tagging.py b/tests/model/tags/test_resource_tagging.py index 9ffac62a86..54e7e72220 100644 --- a/tests/model/tags/test_resource_tagging.py +++ b/tests/model/tags/test_resource_tagging.py @@ -4,24 +4,20 @@ class TestResourceTagging(TestCase): - def test_get_tag_list_returns_default_tag_list_values(self): tag_list = get_tag_list(None) expected_tag_list = [] self.assertEqual(tag_list, expected_tag_list) - def test_get_tag_list_with_tag_dictionary_with_key_only(self): tag_list = get_tag_list({"key": None}) - expected_tag_list = [{"Key": "key", - "Value": ""}] + expected_tag_list = [{"Key": "key", "Value": ""}] self.assertEqual(tag_list, expected_tag_list) def test_get_tag_list_with_tag_dictionary(self): tag_list = get_tag_list({"AnotherKey": "This time with a value"}) - expected_tag_list = [{"Key": "AnotherKey", - "Value": "This time with a value"}] + expected_tag_list = [{"Key": "AnotherKey", "Value": "This time with a value"}] self.assertEqual(tag_list, expected_tag_list) diff --git a/tests/model/test_api_v2.py b/tests/model/test_api_v2.py index fc9116c84a..36d87b5e1e 100644 --- a/tests/model/test_api_v2.py +++ b/tests/model/test_api_v2.py @@ -6,20 +6,30 @@ class TestApiGatewayV2Authorizer(TestCase): - def test_create_oauth2_auth(self): - auth = ApiGatewayV2Authorizer(api_logical_id="logicalId", name="authName", - jwt_configuration={"config": "value"}, id_source="https://example.com") + auth = ApiGatewayV2Authorizer( + api_logical_id="logicalId", + name="authName", + jwt_configuration={"config": "value"}, + id_source="https://example.com", + ) self.assertEquals(auth.auth_type, "oauth2") def test_create_oidc_auth(self): - auth = ApiGatewayV2Authorizer(api_logical_id="logicalId", name="authName", open_id_connect_url="https://example.com", - jwt_configuration={"config": "value"}, id_source="https://example.com") + auth = ApiGatewayV2Authorizer( + api_logical_id="logicalId", + name="authName", + open_id_connect_url="https://example.com", + jwt_configuration={"config": "value"}, + id_source="https://example.com", + ) self.assertEquals(auth.auth_type, "openIdConnect") def test_create_authorizer_no_id_source(self): with pytest.raises(InvalidResourceException): - auth = ApiGatewayV2Authorizer(api_logical_id="logicalId", name="authName", jwt_configuration={"config": "value"}) + auth = ApiGatewayV2Authorizer( + api_logical_id="logicalId", name="authName", jwt_configuration={"config": "value"} + ) def test_create_authorizer_no_jwt_config(self): with pytest.raises(InvalidResourceException): diff --git a/tests/model/test_function_policies.py b/tests/model/test_function_policies.py index 
a0a82b4bd9..4e39d05814 100644 --- a/tests/model/test_function_policies.py +++ b/tests/model/test_function_policies.py @@ -5,8 +5,8 @@ from samtranslator.model.exceptions import InvalidTemplateException from samtranslator.model.intrinsics import is_intrinsic_if, is_intrinsic_no_value -class TestFunctionPolicies(TestCase): +class TestFunctionPolicies(TestCase): def setUp(self): self.policy_template_processor_mock = Mock() self.is_policy_template_mock = Mock() @@ -14,7 +14,6 @@ def setUp(self): self.function_policies = FunctionPolicies({}, self.policy_template_processor_mock) self.function_policies._is_policy_template = self.is_policy_template_mock - @patch.object(FunctionPolicies, "_get_policies") def test_initialization_must_ingest_policies_from_resource_properties(self, get_policies_mock): resource_properties = {} @@ -27,10 +26,9 @@ def test_initialization_must_ingest_policies_from_resource_properties(self, get_ get_policies_mock.assert_called_once_with(resource_properties) self.assertEqual(expected_length, len(function_policies)) - @patch.object(FunctionPolicies, "_get_policies") def test_get_must_yield_results_on_every_call(self, get_policies_mock): - resource_properties = {} # Just some input + resource_properties = {} # Just some input dummy_policy_results = ["some", "policy", "statements"] expected_results = ["some", "policy", "statements"] @@ -41,10 +39,9 @@ def test_get_must_yield_results_on_every_call(self, get_policies_mock): # `list()` will implicitly call the `get()` repeatedly because it is a generator self.assertEqual(list(function_policies.get()), expected_results) - @patch.object(FunctionPolicies, "_get_policies") def test_get_must_yield_no_results_with_no_policies(self, get_policies_mock): - resource_properties = {} # Just some input + resource_properties = {} # Just some input dummy_policy_results = [] expected_result = [] @@ -56,16 +53,12 @@ def test_get_must_yield_no_results_with_no_policies(self, get_policies_mock): self.assertEqual(list(function_policies.get()), expected_result) def test_contains_policies_must_work_for_valid_input(self): - resource_properties = { - "Policies": "some managed policy" - } + resource_properties = {"Policies": "some managed policy"} self.assertTrue(self.function_policies._contains_policies(resource_properties)) def test_contains_policies_must_ignore_resources_without_policies(self): - resource_properties = { - "some key": "value" - } + resource_properties = {"some key": "value"} self.assertFalse(self.function_policies._contains_policies(resource_properties)) @@ -81,9 +74,7 @@ def test_contains_policies_must_ignore_none_resources(self): def test_contains_policies_must_ignore_lowercase_property_name(self): # Property names are case sensitive - resource_properties = { - "policies": "some managed policy" - } + resource_properties = {"policies": "some managed policy"} self.assertFalse(self.function_policies._contains_policies(resource_properties)) @@ -96,9 +87,7 @@ def test_get_type_must_work_for_managed_policy(self): @patch("samtranslator.model.function_policies.is_instrinsic") def test_get_type_must_work_for_managed_policy_with_intrinsics(self, is_intrinsic_mock): - policy = { - "Ref": "somevalue" - } + policy = {"Ref": "somevalue"} expected = PolicyTypes.MANAGED_POLICY is_intrinsic_mock.return_value = True @@ -106,18 +95,14 @@ def test_get_type_must_work_for_managed_policy_with_intrinsics(self, is_intrinsi self.assertEqual(result, expected) def test_get_type_must_work_for_policy_statements(self): - policy = { - "Statement": "policy statements 
have a 'Statement' key" - } + policy = {"Statement": "policy statements have a 'Statement' key"} expected = PolicyTypes.POLICY_STATEMENT result = self.function_policies._get_type(policy) self.assertEqual(result, expected) def test_get_type_must_work_for_policy_templates(self): - policy = { - "PolicyTemplate": "some template" - } + policy = {"PolicyTemplate": "some template"} self.is_policy_template_mock.return_value = True expected = PolicyTypes.POLICY_TEMPLATE @@ -125,9 +110,7 @@ def test_get_type_must_work_for_policy_templates(self): self.assertEqual(result, expected) def test_get_type_must_ignore_invalid_policy(self): - policy = { - "not-sure-what-this-is": "value" - } + policy = {"not-sure-what-this-is": "value"} # This is also not a policy template self.is_policy_template_mock.return_value = False expected = PolicyTypes.UNKNOWN @@ -151,12 +134,10 @@ def test_get_policies_must_return_all_policies(self): {"Ref": "some managed policy"}, {"Statement": "policy statement"}, {"PolicyTemplate": "some value"}, - ["unknown", "policy"] + ["unknown", "policy"], ] - resource_properties = { - "Policies": policies - } - self.is_policy_template_mock.side_effect = [True, False] # Return True for policy template, False for the list + resource_properties = {"Policies": policies} + self.is_policy_template_mock.side_effect = [True, False] # Return True for policy template, False for the list expected = [ PolicyEntry(data="managed policy 1", type=PolicyTypes.MANAGED_POLICY), @@ -170,95 +151,60 @@ def test_get_policies_must_return_all_policies(self): self.assertEqual(result, expected) def test_get_policies_must_ignore_if_resource_does_not_contain_policy(self): - resource_properties = { - } + resource_properties = {} expected = [] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) def test_get_policies_must_ignore_if_policies_is_empty(self): - resource_properties = { - "Policies": [] - } + resource_properties = {"Policies": []} expected = [] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) def test_get_policies_must_work_for_single_policy_string(self): - resource_properties = { - "Policies": "single managed policy" - } - expected = [ - PolicyEntry(data="single managed policy", type=PolicyTypes.MANAGED_POLICY) - ] + resource_properties = {"Policies": "single managed policy"} + expected = [PolicyEntry(data="single managed policy", type=PolicyTypes.MANAGED_POLICY)] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) def test_get_policies_must_work_for_single_dict_with_managed_policy_intrinsic(self): - resource_properties = { - "Policies": { - "Ref": "some managed policy" - } - } - expected = [ - PolicyEntry(data={"Ref": "some managed policy"}, type=PolicyTypes.MANAGED_POLICY) - ] + resource_properties = {"Policies": {"Ref": "some managed policy"}} + expected = [PolicyEntry(data={"Ref": "some managed policy"}, type=PolicyTypes.MANAGED_POLICY)] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) def test_get_policies_must_work_for_single_dict_with_policy_statement(self): - resource_properties = { - "Policies": { - "Statement": "some policy statement" - } - } - expected = [ - PolicyEntry(data={"Statement": "some policy statement"}, type=PolicyTypes.POLICY_STATEMENT) - ] + resource_properties = {"Policies": {"Statement": "some policy statement"}} + expected = [PolicyEntry(data={"Statement": "some policy 
statement"}, type=PolicyTypes.POLICY_STATEMENT)] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) def test_get_policies_must_work_for_single_dict_of_policy_template(self): - resource_properties = { - "Policies": { - "PolicyTemplate": "some template" - } - } + resource_properties = {"Policies": {"PolicyTemplate": "some template"}} self.is_policy_template_mock.return_value = True - expected = [ - PolicyEntry(data={"PolicyTemplate": "some template"}, type=PolicyTypes.POLICY_TEMPLATE) - ] + expected = [PolicyEntry(data={"PolicyTemplate": "some template"}, type=PolicyTypes.POLICY_TEMPLATE)] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) self.is_policy_template_mock.assert_called_once_with(resource_properties["Policies"]) def test_get_policies_must_work_for_single_dict_of_invalid_policy_template(self): - resource_properties = { - "Policies": { - "InvalidPolicyTemplate": "some template" - } - } - self.is_policy_template_mock.return_value = False # Invalid policy template - expected = [ - PolicyEntry(data={"InvalidPolicyTemplate": "some template"}, type=PolicyTypes.UNKNOWN) - ] + resource_properties = {"Policies": {"InvalidPolicyTemplate": "some template"}} + self.is_policy_template_mock.return_value = False # Invalid policy template + expected = [PolicyEntry(data={"InvalidPolicyTemplate": "some template"}, type=PolicyTypes.UNKNOWN)] result = self.function_policies._get_policies(resource_properties) self.assertEqual(result, expected) self.is_policy_template_mock.assert_called_once_with({"InvalidPolicyTemplate": "some template"}) def test_get_policies_must_work_for_unknown_policy_types(self): - resource_properties = { - "Policies": [ - 1, 2, 3 - ] - } + resource_properties = {"Policies": [1, 2, 3]} expected = [ PolicyEntry(data=1, type=PolicyTypes.UNKNOWN), PolicyEntry(data=2, type=PolicyTypes.UNKNOWN), @@ -272,11 +218,7 @@ def test_get_policies_must_work_for_unknown_policy_types(self): def test_is_policy_template_must_detect_valid_policy_templates(self): template_name = "template_name" - policy = { - template_name: { - "Param1": "foo" - } - } + policy = {template_name: {"Param1": "foo"}} self.policy_template_processor_mock.has.return_value = True function_policies = FunctionPolicies({}, self.policy_template_processor_mock) @@ -286,7 +228,7 @@ def test_is_policy_template_must_detect_valid_policy_templates(self): self.policy_template_processor_mock.has.assert_called_once_with(template_name) def test_is_policy_template_must_ignore_non_dict_policies(self): - policy = [1,2,3] + policy = [1, 2, 3] self.policy_template_processor_mock.has.return_value = True function_policies = FunctionPolicies({}, self.policy_template_processor_mock) @@ -303,10 +245,7 @@ def test_is_policy_template_must_ignore_none_policies(self): def test_is_policy_template_must_ignore_dict_with_two_keys(self): template_name = "template_name" - policy = { - template_name: {"param1": "foo"}, - "A": "B" - } + policy = {template_name: {"param1": "foo"}, "A": "B"} self.policy_template_processor_mock.has.return_value = True @@ -315,9 +254,7 @@ def test_is_policy_template_must_ignore_dict_with_two_keys(self): def test_is_policy_template_must_ignore_non_policy_templates(self): template_name = "template_name" - policy = { - template_name: {"param1": "foo"} - } + policy = {template_name: {"param1": "foo"}} self.policy_template_processor_mock.has.return_value = False @@ -327,56 +264,38 @@ def 
test_is_policy_template_must_ignore_non_policy_templates(self): self.policy_template_processor_mock.has.assert_called_once_with(template_name) def test_is_policy_template_must_return_false_without_the_processor(self): - policy = { - "template_name": {"param1": "foo"} - } + policy = {"template_name": {"param1": "foo"}} - function_policies_obj = FunctionPolicies({}, None) # No policy template processor + function_policies_obj = FunctionPolicies({}, None) # No policy template processor self.assertFalse(function_policies_obj._is_policy_template(policy)) self.policy_template_processor_mock.has.assert_not_called() def test_is_intrinsic_if_must_return_true_for_if(self): - policy = { - "Fn::If": "some value" - } + policy = {"Fn::If": "some value"} self.assertTrue(is_intrinsic_if(policy)) def test_is_intrinsic_if_must_return_false_for_others(self): - too_many_keys = { - "Fn::If": "some value", - "Fn::And": "other value" - } + too_many_keys = {"Fn::If": "some value", "Fn::And": "other value"} - not_if = { - "Fn::Or": "some value" - } + not_if = {"Fn::Or": "some value"} self.assertFalse(is_intrinsic_if(too_many_keys)) self.assertFalse(is_intrinsic_if(not_if)) self.assertFalse(is_intrinsic_if(None)) def test_is_intrinsic_no_value_must_return_true_for_no_value(self): - policy = { - "Ref": "AWS::NoValue" - } + policy = {"Ref": "AWS::NoValue"} self.assertTrue(is_intrinsic_no_value(policy)) def test_is_intrinsic_no_value_must_return_false_for_other_value(self): - bad_key = { - "sRefs": "AWS::NoValue" - } + bad_key = {"sRefs": "AWS::NoValue"} - bad_value = { - "Ref": "SWA::NoValue" - } + bad_value = {"Ref": "SWA::NoValue"} - too_many_keys = { - "Ref": "AWS::NoValue", - "feR": "SWA::NoValue" - } + too_many_keys = {"Ref": "AWS::NoValue", "feR": "SWA::NoValue"} self.assertFalse(is_intrinsic_no_value(bad_key)) self.assertFalse(is_intrinsic_no_value(bad_value)) @@ -384,17 +303,11 @@ def test_is_intrinsic_no_value_must_return_false_for_other_value(self): self.assertFalse(is_intrinsic_no_value(too_many_keys)) def test_get_type_with_intrinsic_if_must_return_managed_policy_type(self): - managed_policy = { - "Fn::If": ["SomeCondition", "some managed policy arn", "other managed policy arn"] - } + managed_policy = {"Fn::If": ["SomeCondition", "some managed policy arn", "other managed policy arn"]} - no_value_if = { - "Fn::If": ["SomeCondition", {"Ref": "AWS::NoValue"}, "other managed policy arn"] - } + no_value_if = {"Fn::If": ["SomeCondition", {"Ref": "AWS::NoValue"}, "other managed policy arn"]} - no_value_else = { - "Fn::If": ["SomeCondition", "other managed policy arn", {"Ref": "AWS::NoValue"}] - } + no_value_else = {"Fn::If": ["SomeCondition", "other managed policy arn", {"Ref": "AWS::NoValue"}]} expected_managed_policy = PolicyTypes.MANAGED_POLICY @@ -407,13 +320,9 @@ def test_get_type_with_intrinsic_if_must_return_policy_statement_type(self): "Fn::If": ["SomeCondition", {"Statement": "then statement"}, {"Statement": "else statement"}] } - no_value_if = { - "Fn::If": ["SomeCondition", {"Ref": "AWS::NoValue"}, {"Statement": "else statement"}] - } + no_value_if = {"Fn::If": ["SomeCondition", {"Ref": "AWS::NoValue"}, {"Statement": "else statement"}]} - no_value_else = { - "Fn::If": ["SomeCondition", {"Statement": "then statement"}, {"Ref": "AWS::NoValue"}] - } + no_value_else = {"Fn::If": ["SomeCondition", {"Statement": "then statement"}, {"Ref": "AWS::NoValue"}]} expected_managed_policy = PolicyTypes.POLICY_STATEMENT self.assertTrue(expected_managed_policy, self.function_policies._get_type(policy_statement)) @@ 
-422,23 +331,14 @@ def test_get_type_with_intrinsic_if_must_return_policy_statement_type(self): def test_get_type_with_intrinsic_if_must_return_policy_template_type(self): policy_template = { - "Fn::If": [ "SomeCondition", - {"template_name_one": { "Param1": "foo"}}, - {"template_name_one": { "Param1": "foo"}} - ] - } - no_value_if = { - "Fn::If": [ "SomeCondition", - {"Ref": "AWS::NoValue"}, - {"template_name_one": { "Param1": "foo"}} - ] - } - no_value_else = { - "Fn::If": [ "SomeCondition", - {"template_name_one": { "Param1": "foo"}}, - {"Ref": "AWS::NoValue"} - ] + "Fn::If": [ + "SomeCondition", + {"template_name_one": {"Param1": "foo"}}, + {"template_name_one": {"Param1": "foo"}}, + ] } + no_value_if = {"Fn::If": ["SomeCondition", {"Ref": "AWS::NoValue"}, {"template_name_one": {"Param1": "foo"}}]} + no_value_else = {"Fn::If": ["SomeCondition", {"template_name_one": {"Param1": "foo"}}, {"Ref": "AWS::NoValue"}]} expected_managed_policy = PolicyTypes.POLICY_TEMPLATE self.policy_template_processor_mock.has.return_value = True @@ -449,24 +349,16 @@ def test_get_type_with_intrinsic_if_must_return_policy_template_type(self): self.assertTrue(expected_managed_policy, function_policies._get_type(no_value_else)) def test_get_type_with_intrinsic_if_must_raise_exception_for_bad_policy(self): - policy_too_few_values = { - "Fn::If": ["condition", "then"] - } + policy_too_few_values = {"Fn::If": ["condition", "then"]} - policy_too_many_values = { - "Fn::If": ["condition", "then", "else", "extra"] - } + policy_too_many_values = {"Fn::If": ["condition", "then", "else", "extra"]} self.assertRaises(InvalidTemplateException, self.function_policies._get_type, policy_too_few_values) self.assertRaises(InvalidTemplateException, self.function_policies._get_type, policy_too_many_values) def test_get_type_with_intrinsic_if_must_raise_exception_for_different_policy_types(self): - policy_one = { - "Fn::If": ["condition", "then", {"Statement": "else"}] - } - policy_two = { - "Fn::If": ["condition", {"Statement": "then"}, "else"] - } + policy_one = {"Fn::If": ["condition", "then", {"Statement": "else"}]} + policy_two = {"Fn::If": ["condition", {"Statement": "then"}, "else"]} self.assertRaises(InvalidTemplateException, self.function_policies._get_type, policy_one) - self.assertRaises(InvalidTemplateException, self.function_policies._get_type, policy_two) \ No newline at end of file + self.assertRaises(InvalidTemplateException, self.function_policies._get_type, policy_two) diff --git a/tests/model/test_sam_resources.py b/tests/model/test_sam_resources.py index d699c65884..d42d844af6 100644 --- a/tests/model/test_sam_resources.py +++ b/tests/model/test_sam_resources.py @@ -15,29 +15,22 @@ class TestCodeUri(TestCase): kwargs = { - 'intrinsics_resolver': IntrinsicsResolver({}), - 'event_resources': [], - 'managed_policy_map': { - "foo": "bar" - } + "intrinsics_resolver": IntrinsicsResolver({}), + "event_resources": [], + "managed_policy_map": {"foo": "bar"}, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_code_uri(self): function = SamFunction("foo") function.CodeUri = "s3://foobar/foo.zip" - cfnResources = function.to_cloudformation(**self.kwargs) generatedFunctionList = [x for x in cfnResources if isinstance(x, LambdaFunction)] self.assertEqual(generatedFunctionList.__len__(), 1) - self.assertEqual(generatedFunctionList[0].Code, { - "S3Key": "foo.zip", - "S3Bucket": "foobar", - }) - + 
self.assertEqual(generatedFunctionList[0].Code, {"S3Key": "foo.zip", "S3Bucket": "foobar"}) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_zip_file(self): function = SamFunction("foo") function.InlineCode = "hello world" @@ -45,45 +38,35 @@ def test_with_zip_file(self): cfnResources = function.to_cloudformation(**self.kwargs) generatedFunctionList = [x for x in cfnResources if isinstance(x, LambdaFunction)] self.assertEqual(generatedFunctionList.__len__(), 1) - self.assertEqual(generatedFunctionList[0].Code, { - "ZipFile": "hello world" - }) + self.assertEqual(generatedFunctionList[0].Code, {"ZipFile": "hello world"}) def test_with_no_code_uri_or_zipfile(self): function = SamFunction("foo") with pytest.raises(InvalidResourceException): function.to_cloudformation(**self.kwargs) + class TestAssumeRolePolicyDocument(TestCase): kwargs = { - 'intrinsics_resolver': IntrinsicsResolver({}), - 'event_resources': [], - 'managed_policy_map': { - "foo": "bar" - } + "intrinsics_resolver": IntrinsicsResolver({}), + "event_resources": [], + "managed_policy_map": {"foo": "bar"}, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_assume_role_policy_document(self): function = SamFunction("foo") function.CodeUri = "s3://foobar/foo.zip" assume_role_policy_document = { - 'Version': '2012-10-17', - 'Statement': [ - { - 'Action': [ - 'sts:AssumeRole' - ], - 'Effect': 'Allow', - 'Principal': { - 'Service': [ - 'lambda.amazonaws.com', - 'edgelambda.amazonaws.com' - ] - } - } - ] + "Version": "2012-10-17", + "Statement": [ + { + "Action": ["sts:AssumeRole"], + "Effect": "Allow", + "Principal": {"Service": ["lambda.amazonaws.com", "edgelambda.amazonaws.com"]}, + } + ], } function.AssumeRolePolicyDocument = assume_role_policy_document @@ -92,42 +75,31 @@ def test_with_assume_role_policy_document(self): generateFunctionVersion = [x for x in cfnResources if isinstance(x, IAMRole)] self.assertEqual(generateFunctionVersion[0].AssumeRolePolicyDocument, assume_role_policy_document) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_without_assume_role_policy_document(self): function = SamFunction("foo") function.CodeUri = "s3://foobar/foo.zip" assume_role_policy_document = { - 'Version': '2012-10-17', - 'Statement': [ - { - 'Action': [ - 'sts:AssumeRole' - ], - 'Effect': 'Allow', - 'Principal': { - 'Service': [ - 'lambda.amazonaws.com' - ] - } - } - ] + "Version": "2012-10-17", + "Statement": [ + {"Action": ["sts:AssumeRole"], "Effect": "Allow", "Principal": {"Service": ["lambda.amazonaws.com"]}} + ], } cfnResources = function.to_cloudformation(**self.kwargs) generateFunctionVersion = [x for x in cfnResources if isinstance(x, IAMRole)] self.assertEqual(generateFunctionVersion[0].AssumeRolePolicyDocument, assume_role_policy_document) + class TestVersionDescription(TestCase): kwargs = { - 'intrinsics_resolver': IntrinsicsResolver({}), - 'event_resources': [], - 'managed_policy_map': { - "foo": "bar" - } + "intrinsics_resolver": IntrinsicsResolver({}), + "event_resources": [], + "managed_policy_map": {"foo": "bar"}, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_version_description(self): function = SamFunction("foo") test_description = "foobar" @@ -140,16 
+112,15 @@ def test_with_version_description(self): generateFunctionVersion = [x for x in cfnResources if isinstance(x, LambdaVersion)] self.assertEqual(generateFunctionVersion[0].Description, test_description) + class TestOpenApi(TestCase): kwargs = { - 'intrinsics_resolver': IntrinsicsResolver({}), - 'event_resources': [], - 'managed_policy_map': { - "foo": "bar" - } + "intrinsics_resolver": IntrinsicsResolver({}), + "event_resources": [], + "managed_policy_map": {"foo": "bar"}, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_open_api_3_no_stage(self): api = SamApi("foo") api.OpenApiVersion = "3.0" @@ -160,7 +131,7 @@ def test_with_open_api_3_no_stage(self): self.assertEqual(deployment.__len__(), 1) self.assertEqual(deployment[0].StageName, None) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_open_api_2_no_stage(self): api = SamApi("foo") api.OpenApiVersion = "3.0" @@ -171,14 +142,14 @@ def test_with_open_api_2_no_stage(self): self.assertEqual(deployment.__len__(), 1) self.assertEqual(deployment[0].StageName, None) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_open_api_bad_value(self): api = SamApi("foo") api.OpenApiVersion = "5.0" with pytest.raises(InvalidResourceException): api.to_cloudformation(**self.kwargs) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_swagger_no_stage(self): api = SamApi("foo") @@ -188,16 +159,15 @@ def test_with_swagger_no_stage(self): self.assertEqual(deployment.__len__(), 1) self.assertEqual(deployment[0].StageName, "Stage") + class TestApiTags(TestCase): kwargs = { - 'intrinsics_resolver': IntrinsicsResolver({}), - 'event_resources': [], - 'managed_policy_map': { - "foo": "bar" - } + "intrinsics_resolver": IntrinsicsResolver({}), + "event_resources": [], + "managed_policy_map": {"foo": "bar"}, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_no_tags(self): api = SamApi("foo") api.Tags = {} @@ -208,17 +178,13 @@ def test_with_no_tags(self): self.assertEqual(deployment.__len__(), 1) self.assertEqual(deployment[0].Tags, []) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_with_tags(self): api = SamApi("foo") - api.Tags = { - 'MyKey': 'MyValue' - } + api.Tags = {"MyKey": "MyValue"} resources = api.to_cloudformation(**self.kwargs) deployment = [x for x in resources if isinstance(x, ApiGatewayStage)] self.assertEqual(deployment.__len__(), 1) - self.assertEqual(deployment[0].Tags, [ - {'Key': 'MyKey', 'Value': 'MyValue'} - ]) + self.assertEqual(deployment[0].Tags, [{"Key": "MyKey", "Value": "MyValue"}]) diff --git a/tests/openapi/test_openapi.py b/tests/openapi/test_openapi.py index ffcbc20744..95eceb135a 100644 --- a/tests/openapi/test_openapi.py +++ b/tests/openapi/test_openapi.py @@ -7,7 +7,7 @@ from samtranslator.model.exceptions import InvalidDocumentException _X_INTEGRATION = "x-amazon-apigateway-integration" -_X_ANY_METHOD = 'x-amazon-apigateway-any-method' +_X_ANY_METHOD = "x-amazon-apigateway-any-method" # TODO: add a case for swagger and make sure it fails @@ -18,10 +18,7 @@ def 
test_must_raise_on_valid_swagger(self): valid_swagger = { "swagger": "2.0", # "openapi": "2.1.0" - "paths": { - "/foo": {}, - "/bar": {} - } + "paths": {"/foo": {}, "/bar": {}}, } # missing openapi key word with self.assertRaises(ValueError): OpenApiEditor(valid_swagger) @@ -33,13 +30,7 @@ def test_must_raise_on_invalid_openapi(self): OpenApiEditor(invalid_openapi) def test_must_succeed_on_valid_openapi(self): - valid_openapi = { - "openapi": "3.0.1", - "paths": { - "/foo": {}, - "/bar": {} - } - } + valid_openapi = {"openapi": "3.0.1", "paths": {"/foo": {}, "/bar": {}}} editor = OpenApiEditor(valid_openapi) self.assertIsNotNone(editor) @@ -47,37 +38,19 @@ def test_must_succeed_on_valid_openapi(self): self.assertEqual(editor.paths, {"/foo": {}, "/bar": {}}) def test_must_fail_on_invalid_openapi_version(self): - invalid_openapi = { - "openapi": "2.3.0", - "paths": { - "/foo": {}, - "/bar": {} - } - } + invalid_openapi = {"openapi": "2.3.0", "paths": {"/foo": {}, "/bar": {}}} with self.assertRaises(ValueError): OpenApiEditor(invalid_openapi) def test_must_fail_on_invalid_openapi_version_2(self): - invalid_openapi = { - "openapi": "3.1.1.1", - "paths": { - "/foo": {}, - "/bar": {} - } - } + invalid_openapi = {"openapi": "3.1.1.1", "paths": {"/foo": {}, "/bar": {}}} with self.assertRaises(ValueError): OpenApiEditor(invalid_openapi) def test_must_succeed_on_valid_openapi3(self): - valid_openapi = { - "openapi": "3.0.1", - "paths": { - "/foo": {}, - "/bar": {} - } - } + valid_openapi = {"openapi": "3.0.1", "paths": {"/foo": {}, "/bar": {}}} editor = OpenApiEditor(valid_openapi) self.assertIsNotNone(editor) @@ -86,21 +59,14 @@ def test_must_succeed_on_valid_openapi3(self): class TestOpenApiEditor_has_path(TestCase): - def setUp(self): self.openapi = { "openapi": "3.0.1", "paths": { - "/foo": { - "get": {}, - "somemethod": {} - }, - "/bar": { - "post": {}, - _X_ANY_METHOD: {} - }, - "badpath": "string value" - } + "/foo": {"get": {}, "somemethod": {}}, + "/bar": {"post": {}, _X_ANY_METHOD: {}}, + "badpath": "string value", + }, } self.editor = OpenApiEditor(self.openapi) @@ -146,48 +112,19 @@ def test_must_not_fail_on_bad_path(self): class TestOpenApiEditor_has_integration(TestCase): - def setUp(self): self.openapi = { "openapi": "3.0.1", "paths": { "/foo": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - }, - "post": { - "Fn::If": [ - "Condition", - { - _X_INTEGRATION: { - "a": "b" - } - }, - {"Ref": "AWS::NoValue"} - ] - }, - "delete": { - "Fn::If": [ - "Condition", - {"Ref": "AWS::NoValue"}, - { - _X_INTEGRATION: { - "a": "b" - } - } - ] - }, - "somemethod": { - "foo": "value", - }, - "emptyintegration": { - _X_INTEGRATION: {} - }, - "badmethod": "string value" - }, - } + "get": {_X_INTEGRATION: {"a": "b"}}, + "post": {"Fn::If": ["Condition", {_X_INTEGRATION: {"a": "b"}}, {"Ref": "AWS::NoValue"}]}, + "delete": {"Fn::If": ["Condition", {"Ref": "AWS::NoValue"}, {_X_INTEGRATION: {"a": "b"}}]}, + "somemethod": {"foo": "value"}, + "emptyintegration": {_X_INTEGRATION: {}}, + "badmethod": "string value", + } + }, } self.editor = OpenApiEditor(self.openapi) @@ -212,32 +149,27 @@ def test_must_handle_bad_value_for_method(self): class TestOpenApiEditor_add_path(TestCase): - def setUp(self): self.original_openapi = { "openapi": "3.0.1", - "paths": { - "/foo": { - "get": {"a": "b"} - }, - "/bar": {}, - "/badpath": "string value" - } + "paths": {"/foo": {"get": {"a": "b"}}, "/bar": {}, "/badpath": "string value"}, } self.editor = OpenApiEditor(self.original_openapi) - @parameterized.expand([ - 
param("/new", "get", "new path, new method"), - param("/foo", "new method", "existing path, new method"), - param("/bar", "get", "existing path, new method"), - ]) + @parameterized.expand( + [ + param("/new", "get", "new path, new method"), + param("/foo", "new method", "existing path, new method"), + param("/bar", "get", "existing path, new method"), + ] + ) def test_must_add_new_path_and_method(self, path, method, case): self.assertFalse(self.editor.has_path(path, method)) self.editor.add_path(path, method) - self.assertTrue(self.editor.has_path(path, method), "must add for "+case) + self.assertTrue(self.editor.has_path(path, method), "must add for " + case) self.assertEqual(self.editor.openapi["paths"][path][method], {}) def test_must_raise_non_dict_path_values(self): @@ -264,28 +196,14 @@ def test_must_skip_existing_path(self): class TestOpenApiEditor_add_lambda_integration(TestCase): - def setUp(self): self.original_openapi = { "openapi": "3.0.1", "paths": { - "/foo": { - "post": { - "a": [1, 2, "b"], - "responses": { - "something": "is already here" - } - } - }, - "/bar": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - } - }, - } + "/foo": {"post": {"a": [1, 2, "b"], "responses": {"something": "is already here"}}}, + "/bar": {"get": {_X_INTEGRATION: {"a": "b"}}}, + }, } self.editor = OpenApiEditor(self.original_openapi) @@ -300,8 +218,8 @@ def test_must_add_new_integration_to_new_path(self): "type": "aws_proxy", "httpMethod": "POST", "payloadFormatVersion": "1.0", - "uri": integration_uri - } + "uri": integration_uri, + }, } self.editor.add_lambda_integration(path, method, integration_uri) @@ -324,20 +242,10 @@ def test_must_add_new_integration_with_conditions_to_new_path(self): "type": "aws_proxy", "httpMethod": "POST", "payloadFormatVersion": "1.0", - "uri": { - "Fn::If": [ - "condition", - integration_uri, - { - "Ref": "AWS::NoValue" - } - ] - } - } + "uri": {"Fn::If": ["condition", integration_uri, {"Ref": "AWS::NoValue"}]}, + }, }, - { - "Ref": "AWS::NoValue" - } + {"Ref": "AWS::NoValue"}, ] } @@ -354,19 +262,15 @@ def test_must_add_new_integration_to_existing_path(self): expected = { # Current values present in the dictionary *MUST* be preserved "a": [1, 2, "b"], - # Responses key must be untouched - "responses": { - "something": "is already here" - }, - + "responses": {"something": "is already here"}, # New values must be added _X_INTEGRATION: { "type": "aws_proxy", "httpMethod": "POST", "payloadFormatVersion": "1.0", - "uri": integration_uri - } + "uri": integration_uri, + }, } # Just make sure test is working on an existing path @@ -379,17 +283,9 @@ def test_must_add_new_integration_to_existing_path(self): class TestOpenApiEditor_iter_on_path(TestCase): - def setUp(self): - self.original_openapi = { - "openapi": "3.0.1", - "paths": { - "/foo": {}, - "/bar": {}, - "/baz": "some value" - } - } + self.original_openapi = {"openapi": "3.0.1", "paths": {"/foo": {}, "/bar": {}, "/baz": "some value"}} self.editor = OpenApiEditor(self.original_openapi) @@ -402,27 +298,24 @@ def test_must_iterate_on_paths(self): class TestOpenApiEditor_normalize_method_name(TestCase): - - @parameterized.expand([ - param("GET", "get", "must lowercase"), - param("PoST", "post", "must lowercase"), - param("ANY", _X_ANY_METHOD, "must convert any method"), - param(None, None, "must skip empty values"), - param({"a": "b"}, {"a": "b"}, "must skip non-string values"), - param([1, 2], [1, 2], "must skip non-string values"), - ]) + @parameterized.expand( + [ + param("GET", "get", "must lowercase"), + 
param("PoST", "post", "must lowercase"), + param("ANY", _X_ANY_METHOD, "must convert any method"), + param(None, None, "must skip empty values"), + param({"a": "b"}, {"a": "b"}, "must skip non-string values"), + param([1, 2], [1, 2], "must skip non-string values"), + ] + ) def test_must_normalize(self, input, expected, msg): self.assertEqual(expected, OpenApiEditor._normalize_method_name(input), msg) class TestOpenApiEditor_openapi_property(TestCase): - def test_must_return_copy_of_openapi(self): - input = { - "openapi": "3.0.1", - "paths": {} - } + input = {"openapi": "3.0.1", "paths": {}} editor = OpenApiEditor(input) self.assertEqual(input, editor.openapi) # They are equal in content @@ -435,63 +328,47 @@ def test_must_return_copy_of_openapi(self): class TestOpenApiEditor_is_valid(TestCase): - - @parameterized.expand([ - param(OpenApiEditor.gen_skeleton()), - - # Dict can contain any other unrecognized properties - param({"openapi": "3.1.1", "paths": {}, "foo": "bar", "baz": "bar"}) - # TODO check and update the regex accordingly - # Fails for this: param({"openapi": "3.1.10", "paths": {}, "foo": "bar", "baz": "bar"}) - ]) + @parameterized.expand( + [ + param(OpenApiEditor.gen_skeleton()), + # Dict can contain any other unrecognized properties + param({"openapi": "3.1.1", "paths": {}, "foo": "bar", "baz": "bar"}) + # TODO check and update the regex accordingly + # Fails for this: param({"openapi": "3.1.10", "paths": {}, "foo": "bar", "baz": "bar"}) + ] + ) def test_must_work_on_valid_values(self, openapi): self.assertTrue(OpenApiEditor.is_valid(openapi)) - @parameterized.expand([ - ({}, "empty dictionary"), - ([1, 2, 3], "array data type"), - ({"paths": {}}, "missing openapi property"), - ({"openapi": "hello"}, "missing paths property"), - ({"openapi": "hello", "paths": [1, 2, 3]}, "array value for paths property"), - ]) + @parameterized.expand( + [ + ({}, "empty dictionary"), + ([1, 2, 3], "array data type"), + ({"paths": {}}, "missing openapi property"), + ({"openapi": "hello"}, "missing paths property"), + ({"openapi": "hello", "paths": [1, 2, 3]}, "array value for paths property"), + ] + ) def test_must_fail_for_invalid_values(self, data, case): self.assertFalse(OpenApiEditor.is_valid(data), "openapi dictionary with {} must not be valid".format(case)) # TODO this needs to be updated with OIDC auth - authorization scopes and anything else that needs testing the swagger class TestOpenApiEditor_add_auth(TestCase): - def setUp(self): self.original_openapi = { "openapi": "3.0.1", "paths": { - "/foo": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - }, - "post":{ - _X_INTEGRATION: { - "a": "b" - } - } - }, - "/bar": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - } - }, - } + "/foo": {"get": {_X_INTEGRATION: {"a": "b"}}, "post": {_X_INTEGRATION: {"a": "b"}}}, + "/bar": {"get": {_X_INTEGRATION: {"a": "b"}}}, + }, } self.editor = OpenApiEditor(self.original_openapi) -class TestOpenApiEditor_get_integration_function(TestCase): +class TestOpenApiEditor_get_integration_function(TestCase): def setUp(self): self.original_openapi = { @@ -500,43 +377,41 @@ def setUp(self): "$default": { "x-amazon-apigateway-any-method": { "Fn::If": [ - "condition", - { - "security": [ + "condition", { - "OpenIdAuth": [ - "scope1", - "scope2" - ] + "security": [{"OpenIdAuth": ["scope1", "scope2"]}], + "isDefaultRoute": True, + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + "uri": { + "Fn::If": [ + "condition", + { + "Fn::Sub": 
"arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HttpApiFunction.Arn}/invocations" + }, + {"Ref": "AWS::NoValue"}, + ] + }, + "payloadFormatVersion": "1.0", + }, + "responses": {}, }, - ], - "isDefaultRoute": True, - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::If": [ - "condition", - {"Fn::Sub": "arn:${AWS::Partition}:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${HttpApiFunction.Arn}/invocations"}, - {"Ref": "AWS::NoValue"} - ] - }, - "payloadFormatVersion": "1.0" - }, - "responses": {} - }, - {"Ref": "AWS::NoValue"} + {"Ref": "AWS::NoValue"}, ] } - }, + }, "/bar": {}, - "/badpath": "string value" - } + "/badpath": "string value", + }, } self.editor = OpenApiEditor(self.original_openapi) def test_must_get_integration_function_if_exists(self): - self.assertEqual(self.editor.get_integration_function_logical_id(OpenApiEditor._DEFAULT_PATH, OpenApiEditor._X_ANY_METHOD), "HttpApiFunction") + self.assertEqual( + self.editor.get_integration_function_logical_id(OpenApiEditor._DEFAULT_PATH, OpenApiEditor._X_ANY_METHOD), + "HttpApiFunction", + ) self.assertFalse(self.editor.get_integration_function_logical_id("/bar", "get")) diff --git a/tests/plugins/api/test_default_definition_body_plugin.py b/tests/plugins/api/test_default_definition_body_plugin.py index 67ad86c3f3..a6e519675b 100644 --- a/tests/plugins/api/test_default_definition_body_plugin.py +++ b/tests/plugins/api/test_default_definition_body_plugin.py @@ -8,7 +8,6 @@ class TestDefaultDefinitionBodyPlugin_init(TestCase): - def setUp(self): self.plugin = DefaultDefinitionBodyPlugin() @@ -23,7 +22,6 @@ def test_plugin_must_be_instance_of_base_plugin_class(self): class TestDefaultDefinitionBodyPlugin_on_before_transform_template(TestCase): - def setUp(self): self.plugin = DefaultDefinitionBodyPlugin() diff --git a/tests/plugins/api/test_implicit_api_plugin.py b/tests/plugins/api/test_implicit_api_plugin.py index f7acc39a11..d1f55ce596 100644 --- a/tests/plugins/api/test_implicit_api_plugin.py +++ b/tests/plugins/api/test_implicit_api_plugin.py @@ -8,8 +8,8 @@ IMPLICIT_API_LOGICAL_ID = "ServerlessRestApi" -class TestImplicitRestApiPluginEndtoEnd(TestCase): +class TestImplicitRestApiPluginEndtoEnd(TestCase): def test_must_work_for_single_function(self): """ Test the basic case of one function with a few API events @@ -45,7 +45,6 @@ def test_must_ignore_template_without_function(self): class TestImplicitRestApiPlugin_init(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() @@ -60,7 +59,6 @@ def test_plugin_must_be_instance_of_base_plugin_class(self): class TestImplicitRestApiPlugin_on_before_transform_template(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() @@ -96,15 +94,16 @@ def test_must_process_functions(self, SamTemplateMock): sam_template.iterate.assert_any_call("AWS::Serverless::Api") self.plugin._get_api_events.assert_has_calls([call(function1), call(function2), call(function3)]) - self.plugin._process_api_events.assert_has_calls([ - call(function1, ["event1", "event2"], sam_template, None), - call(function2, ["event1", "event2"], sam_template, None), - call(function3, ["event1", "event2"], sam_template, None), - ]) + self.plugin._process_api_events.assert_has_calls( + [ + call(function1, ["event1", "event2"], sam_template, None), + call(function2, ["event1", "event2"], sam_template, None), + call(function3, ["event1", "event2"], sam_template, None), + ] + ) 
self.plugin._maybe_remove_implicit_api.assert_called_with(sam_template) - @patch("samtranslator.plugins.api.implicit_api_plugin.SamTemplate") def test_must_skip_functions_without_events(self, SamTemplateMock): @@ -155,12 +154,16 @@ def test_must_skip_without_functions(self, SamTemplateMock): def test_must_collect_errors_and_raise_on_invalid_events(self, SamTemplateMock): template_dict = {"a": "b"} - function_resources = [("id1", SamResource({"Type": "AWS::Serverless::Function"})), - ("id2", SamResource({"Type": "AWS::Serverless::Function"})), - ("id3", SamResource({"Type": "AWS::Serverless::Function"}))] - api_event_errors = [InvalidEventException("eventid1", "msg"), - InvalidEventException("eventid3", "msg"), - InvalidEventException("eventid3", "msg")] + function_resources = [ + ("id1", SamResource({"Type": "AWS::Serverless::Function"})), + ("id2", SamResource({"Type": "AWS::Serverless::Function"})), + ("id3", SamResource({"Type": "AWS::Serverless::Function"})), + ] + api_event_errors = [ + InvalidEventException("eventid1", "msg"), + InvalidEventException("eventid3", "msg"), + InvalidEventException("eventid3", "msg"), + ] sam_template = Mock() SamTemplateMock.return_value = sam_template @@ -191,7 +194,6 @@ def test_must_collect_errors_and_raise_on_invalid_events(self, SamTemplateMock): class TestImplicitRestApiPlugin_get_api_events(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() @@ -199,42 +201,17 @@ def test_must_get_all_api_events_in_function(self): properties = { "Events": { - "Api1": { - "Type": "Api", - "Properties": { - "a": "b" - } - }, - - "Api2": { - "Type": "Api", - "Properties": {"c": "d"} - }, - - "Other": { - "Type": "Something", - "Properties": { - } - } + "Api1": {"Type": "Api", "Properties": {"a": "b"}}, + "Api2": {"Type": "Api", "Properties": {"c": "d"}}, + "Other": {"Type": "Something", "Properties": {}}, } } - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": properties - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": properties}) expected = { - "Api1": { - "Type": "Api", - "Properties": { - "a": "b" - } - }, - "Api2": { - "Type": "Api", - "Properties": {"c": "d"} - } + "Api1": {"Type": "Api", "Properties": {"a": "b"}}, + "Api2": {"Type": "Api", "Properties": {"c": "d"}}, } result = self.plugin._get_api_events(function) self.assertEqual(expected, result) @@ -243,29 +220,13 @@ def test_must_work_with_no_api_events(self): properties = { "Events": { - "Event1": { - "Type": "some", - "Properties": { - "a": "b" - } - }, - - "EventWithNoType": { - "Properties": {"c": "d"} - }, - - "Event3": { - "Type": "Something", - "Properties": { - } - } + "Event1": {"Type": "some", "Properties": {"a": "b"}}, + "EventWithNoType": {"Properties": {"c": "d"}}, + "Event3": {"Type": "Something", "Properties": {}}, } } - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": properties - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": properties}) expected = {} result = self.plugin._get_api_events(function) @@ -273,12 +234,7 @@ def test_must_work_with_no_api_events(self): def test_must_skip_with_bad_events_structure(self): - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": "must not be string" - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": "must not be string"}}) expected = {} result = self.plugin._get_api_events(function) @@ -286,12 +242,7 @@ def 
test_must_skip_with_bad_events_structure(self): def test_must_skip_if_no_events_property(self): - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "no": "events" - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"no": "events"}}) expected = {} result = self.plugin._get_api_events(function) @@ -299,29 +250,19 @@ def test_must_skip_if_no_events_property(self): def test_must_skip_if_no_property_dictionary(self): - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": "bad value" - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": "bad value"}) expected = {} result = self.plugin._get_api_events(function) self.assertEqual(expected, result) def test_must_return_reference_to_event_dict(self): - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": { - "Api1": { - "Type": "Api", - "Properties": { - "a": "b" - } - } - } + function = SamResource( + { + "Type": SamResourceType.Function.value, + "Properties": {"Events": {"Api1": {"Type": "Api", "Properties": {"a": "b"}}}}, } - }) + ) result = self.plugin._get_api_events(function) @@ -331,18 +272,13 @@ def test_must_return_reference_to_event_dict(self): def test_must_skip_if_function_is_not_valid(self): - function = SamResource({ - # NOT a SAM resource - "Type": "AWS::Lambda::Function", - "Properties": { - "Events": { - "Api1": { - "Type": "Api", - "Properties": {} - } - } + function = SamResource( + { + # NOT a SAM resource + "Type": "AWS::Lambda::Function", + "Properties": {"Events": {"Api1": {"Type": "Api", "Properties": {}}}}, } - }) + ) expected = {} result = self.plugin._get_api_events(function) @@ -350,7 +286,6 @@ def test_must_skip_if_function_is_not_valid(self): class TestImplicitRestApiPlugin_process_api_events(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() self.plugin._add_api_to_swagger = Mock() @@ -358,90 +293,48 @@ def setUp(self): def test_must_work_with_api_events(self): api_events = { - "Api1": { - "Type": "Api", - "Properties": { - "Path": "/", - "Method": "GET" - } - }, - "Api2": { - "Type": "Api", - "Properties": { - "Path": "/foo", - "Method": "POST" - } - } + "Api1": {"Type": "Api", "Properties": {"Path": "/", "Method": "GET"}}, + "Api2": {"Type": "Api", "Properties": {"Path": "/foo", "Method": "POST"}}, } template = Mock() function_events_mock = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": function_events_mock - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": function_events_mock}}) function_events_mock.update = Mock() self.plugin._process_api_events(function, api_events, template) - self.plugin._add_implicit_api_id_if_necessary.assert_has_calls([ - call({"Path": "/", "Method": "GET"}), - call({"Path": "/foo", "Method": "POST"}), - ]) + self.plugin._add_implicit_api_id_if_necessary.assert_has_calls( + [call({"Path": "/", "Method": "GET"}), call({"Path": "/foo", "Method": "POST"})] + ) - self.plugin._add_api_to_swagger.assert_has_calls([ - call("Api1", {"Path": "/", "Method": "GET"}, template), - call("Api2", {"Path": "/foo", "Method": "POST"}, template), - ]) + self.plugin._add_api_to_swagger.assert_has_calls( + [ + call("Api1", {"Path": "/", "Method": "GET"}, template), + call("Api2", {"Path": "/foo", "Method": "POST"}, template), + ] + ) function_events_mock.update.assert_called_with(api_events) def 
test_must_verify_expected_keys_exist(self): - api_events = { - "Api1": { - "Type": "Api", - "Properties": { - "Path": "/", - "Methid": "POST" - } - } - } + api_events = {"Api1": {"Type": "Api", "Properties": {"Path": "/", "Methid": "POST"}}} template = Mock() function_events_mock = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": function_events_mock - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": function_events_mock}}) function_events_mock.update = Mock() with self.assertRaises(InvalidEventException) as context: self.plugin._process_api_events(function, api_events, template) def test_must_verify_method_is_string(self): - api_events = { - "Api1": { - "Type": "Api", - "Properties": { - "Path": "/", - "Method": ["POST"] - } - } - } + api_events = {"Api1": {"Type": "Api", "Properties": {"Path": "/", "Method": ["POST"]}}} template = Mock() function_events_mock = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": function_events_mock - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": function_events_mock}}) function_events_mock.update = Mock() with self.assertRaises(InvalidEventException) as context: @@ -454,43 +347,25 @@ def test_must_verify_rest_api_id_is_string(self): "Properties": { "Path": "/", "Method": ["POST"], - "RestApiId": {"Fn::ImportValue": {"Fn::Sub": {"ApiName"}}} - } + "RestApiId": {"Fn::ImportValue": {"Fn::Sub": {"ApiName"}}}, + }, } } template = Mock() function_events_mock = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": function_events_mock - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": function_events_mock}}) function_events_mock.update = Mock() with self.assertRaises(InvalidEventException) as context: self.plugin._process_api_events(function, api_events, template) def test_must_verify_path_is_string(self): - api_events = { - "Api1": { - "Type": "Api", - "Properties": { - "Path": ["/"], - "Method": "POST" - } - } - } + api_events = {"Api1": {"Type": "Api", "Properties": {"Path": ["/"], "Method": "POST"}}} template = Mock() function_events_mock = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": function_events_mock - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": function_events_mock}}) function_events_mock.update = Mock() with self.assertRaises(InvalidEventException) as context: @@ -498,36 +373,16 @@ def test_must_verify_path_is_string(self): def test_must_skip_events_without_properties(self): - api_events = { - "Api1": { - "Type": "Api" - }, - "Api2": { - "Type": "Api", - "Properties": { - "Path": "/", - "Method": "GET" - } - } - } + api_events = {"Api1": {"Type": "Api"}, "Api2": {"Type": "Api", "Properties": {"Path": "/", "Method": "GET"}}} template = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": api_events - } - }) + function = SamResource({"Type": SamResourceType.Function.value, "Properties": {"Events": api_events}}) self.plugin._process_api_events(function, api_events, template) - self.plugin._add_implicit_api_id_if_necessary.assert_has_calls([ - call({"Path": "/", "Method": "GET"}), - ]) + self.plugin._add_implicit_api_id_if_necessary.assert_has_calls([call({"Path": "/", "Method": 
"GET"})]) - self.plugin._add_api_to_swagger.assert_has_calls([ - call("Api2", {"Path": "/", "Method": "GET"}, template), - ]) + self.plugin._add_api_to_swagger.assert_has_calls([call("Api2", {"Path": "/", "Method": "GET"}, template)]) def test_must_retain_side_effect_of_modifying_events(self): """ @@ -535,34 +390,23 @@ def test_must_retain_side_effect_of_modifying_events(self): """ api_events = { - "Api1": { - "Type": "Api", - "Properties": { - "Path": "/", - "Method": "get" - } - }, - "Api2": { - "Type": "Api", - "Properties": { - "Path": "/foo", - "Method": "post" - } - } + "Api1": {"Type": "Api", "Properties": {"Path": "/", "Method": "get"}}, + "Api2": {"Type": "Api", "Properties": {"Path": "/foo", "Method": "post"}}, } template = Mock() - function = SamResource({ - "Type": SamResourceType.Function.value, - "Properties": { - "Events": { - "Api1": "Intentionally setting this value to a string for testing. " - "This should be replaced by API Event after processing", - - "Api2": "must be replaced" - } + function = SamResource( + { + "Type": SamResourceType.Function.value, + "Properties": { + "Events": { + "Api1": "Intentionally setting this value to a string for testing. " + "This should be replaced by API Event after processing", + "Api2": "must be replaced", + } + }, } - }) + ) def add_key_to_event(event_properties): event_properties["Key"] = "Value" @@ -577,81 +421,71 @@ def add_key_to_event(event_properties): self.assertEqual(api_events["Api2"]["Properties"], {"Path": "/foo", "Method": "post", "Key": "Value"}) # Every Event object inside the SamResource class must be entirely replaced by input api_events with side effect - self.assertEqual(function.properties["Events"]["Api1"]["Properties"], {"Path": "/", "Method": "get", "Key": "Value"}) - self.assertEqual(function.properties["Events"]["Api2"]["Properties"], {"Path": "/foo", "Method": "post", "Key": "Value"}) + self.assertEqual( + function.properties["Events"]["Api1"]["Properties"], {"Path": "/", "Method": "get", "Key": "Value"} + ) + self.assertEqual( + function.properties["Events"]["Api2"]["Properties"], {"Path": "/foo", "Method": "post", "Key": "Value"} + ) # Subsequent calls must be made with the side effect. This is important. 
- self.plugin._add_api_to_swagger.assert_has_calls([ - call("Api1", - # Side effects should be visible here - {"Path": "/", "Method": "get", "Key": "Value"}, - template), - call("Api2", - # Side effects should be visible here - {"Path": "/foo", "Method": "post", "Key": "Value"}, - template), - ]) + self.plugin._add_api_to_swagger.assert_has_calls( + [ + call( + "Api1", + # Side effects should be visible here + {"Path": "/", "Method": "get", "Key": "Value"}, + template, + ), + call( + "Api2", + # Side effects should be visible here + {"Path": "/foo", "Method": "post", "Key": "Value"}, + template, + ), + ] + ) class TestImplicitRestApiPlugin_add_implicit_api_id_if_necessary(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() def test_must_add_if_not_present(self): - input = { - "a": "b" - } + input = {"a": "b"} - expected = { - "a": "b", - "RestApiId": {"Ref": IMPLICIT_API_LOGICAL_ID} - } + expected = {"a": "b", "RestApiId": {"Ref": IMPLICIT_API_LOGICAL_ID}} self.plugin._add_implicit_api_id_if_necessary(input) self.assertEqual(input, expected) - def test_must_skip_if_present(self): - input = { - "a": "b", - "RestApiId": "Something" - } + input = {"a": "b", "RestApiId": "Something"} - expected = { - "a": "b", - "RestApiId": "Something" - } + expected = {"a": "b", "RestApiId": "Something"} self.plugin._add_implicit_api_id_if_necessary(input) self.assertEqual(input, expected) class TestImplicitRestApiPlugin_add_api_to_swagger(TestCase): - def setUp(self): self.plugin = ImplicitRestApiPlugin() @patch("samtranslator.plugins.api.implicit_rest_api_plugin.SwaggerEditor") def test_must_add_path_method_to_swagger_of_api_resource(self, SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} original_swagger = {"this": "is", "valid": "swagger"} updated_swagger = "updated swagger" - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "__MANAGE_SWAGGER": True, - "DefinitionBody": original_swagger, - "a": "b" + mock_api = SamResource( + { + "Type": "AWS::Serverless::Api", + "Properties": {"__MANAGE_SWAGGER": True, "DefinitionBody": original_swagger, "a": "b"}, } - }) + ) SwaggerEditorMock.is_valid = Mock() SwaggerEditorMock.is_valid.return_value = True @@ -668,7 +502,7 @@ def test_must_add_path_method_to_swagger_of_api_resource(self, SwaggerEditorMock self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(original_swagger) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") editor_mock.add_path("/hello", "GET") template_mock.set.assert_called_with("restid", mock_api) self.assertEqual(mock_api.properties["DefinitionBody"], updated_swagger) @@ -680,18 +514,16 @@ def test_must_work_with_rest_api_id_as_string(self, SwaggerEditorMock): # THIS IS A STRING, not a {"Ref"} "RestApiId": "restid", "Path": "/hello", - "Method": "GET" + "Method": "GET", } original_swagger = {"this": "is", "valid": "swagger"} updated_swagger = "updated swagger" - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "__MANAGE_SWAGGER": True, - "DefinitionBody": original_swagger, - "a": "b" + mock_api = SamResource( + { + "Type": "AWS::Serverless::Api", + "Properties": {"__MANAGE_SWAGGER": True, "DefinitionBody": original_swagger, "a": "b"}, } - }) + ) SwaggerEditorMock.is_valid = Mock() 
SwaggerEditorMock.is_valid.return_value = True @@ -708,18 +540,14 @@ def test_must_work_with_rest_api_id_as_string(self, SwaggerEditorMock): self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(original_swagger) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") editor_mock.add_path("/hello", "GET") template_mock.set.assert_called_with("restid", mock_api) self.assertEqual(mock_api.properties["DefinitionBody"], updated_swagger) def test_must_raise_when_api_is_not_found(self): event_id = "id" - properties = { - "RestApiId": "unknown", - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": "unknown", "Path": "/hello", "Method": "GET"} template_mock = Mock() template_mock.get = Mock() @@ -732,11 +560,7 @@ def test_must_raise_when_api_is_not_found(self): def test_must_raise_when_api_id_is_intrinsic(self): event_id = "id" - properties = { - "RestApiId": {"Fn::GetAtt": "restapi"}, - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": {"Fn::GetAtt": "restapi"}, "Path": "/hello", "Method": "GET"} template_mock = Mock() template_mock.get = Mock() @@ -751,19 +575,11 @@ def test_must_raise_when_api_id_is_intrinsic(self): def test_must_skip_invalid_swagger(self, SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} original_swagger = {"this": "is", "valid": "swagger"} - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": original_swagger, - "a": "b" - } - }) + mock_api = SamResource( + {"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": original_swagger, "a": "b"}} + ) SwaggerEditorMock.is_valid = Mock() SwaggerEditorMock.is_valid.return_value = False @@ -777,7 +593,7 @@ def test_must_skip_invalid_swagger(self, SwaggerEditorMock): self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(original_swagger) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") SwaggerEditorMock.assert_not_called() template_mock.set.assert_not_called() @@ -785,17 +601,8 @@ def test_must_skip_invalid_swagger(self, SwaggerEditorMock): def test_must_skip_if_definition_body_is_not_present(self, SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionUri": "s3://bucket/key", - } - }) + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} + mock_api = SamResource({"Type": "AWS::Serverless::Api", "Properties": {"DefinitionUri": "s3://bucket/key"}}) SwaggerEditorMock.is_valid = Mock() SwaggerEditorMock.is_valid.return_value = False @@ -809,7 +616,7 @@ def test_must_skip_if_definition_body_is_not_present(self, SwaggerEditorMock): self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(None) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") SwaggerEditorMock.assert_not_called() template_mock.set.assert_not_called() @@ -817,15 +624,8 @@ def test_must_skip_if_definition_body_is_not_present(self, SwaggerEditorMock): def test_must_skip_if_api_resource_properties_are_invalid(self, 
SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": "this is not a valid property" - }) + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} + mock_api = SamResource({"Type": "AWS::Serverless::Api", "Properties": "this is not a valid property"}) SwaggerEditorMock.is_valid = Mock() self.plugin.editor = SwaggerEditorMock @@ -838,7 +638,7 @@ def test_must_skip_if_api_resource_properties_are_invalid(self, SwaggerEditorMoc self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_not_called() - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") SwaggerEditorMock.assert_not_called() template_mock.set.assert_not_called() @@ -846,22 +646,19 @@ def test_must_skip_if_api_resource_properties_are_invalid(self, SwaggerEditorMoc def test_must_skip_if_api_manage_swagger_flag_is_false(self, SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} original_swagger = {"this": "is a valid swagger"} - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": original_swagger, - "StageName": "prod", - - # Don't manage swagger - "__MANAGE_SWAGGER": False + mock_api = SamResource( + { + "Type": "AWS::Serverless::Api", + "Properties": { + "DefinitionBody": original_swagger, + "StageName": "prod", + # Don't manage swagger + "__MANAGE_SWAGGER": False, + }, } - }) + ) SwaggerEditorMock.is_valid = Mock() self.plugin.editor = SwaggerEditorMock @@ -874,7 +671,7 @@ def test_must_skip_if_api_manage_swagger_flag_is_false(self, SwaggerEditorMock): self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(original_swagger) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") SwaggerEditorMock.assert_not_called() template_mock.set.assert_not_called() @@ -882,21 +679,18 @@ def test_must_skip_if_api_manage_swagger_flag_is_false(self, SwaggerEditorMock): def test_must_skip_if_api_manage_swagger_flag_is_not_present(self, SwaggerEditorMock): event_id = "id" - properties = { - "RestApiId": {"Ref": "restid"}, - "Path": "/hello", - "Method": "GET" - } + properties = {"RestApiId": {"Ref": "restid"}, "Path": "/hello", "Method": "GET"} original_swagger = {"this": "is a valid swagger"} - mock_api = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": original_swagger, - "StageName": "prod", - - # __MANAGE_SWAGGER flag is *not* present + mock_api = SamResource( + { + "Type": "AWS::Serverless::Api", + "Properties": { + "DefinitionBody": original_swagger, + "StageName": "prod", + # __MANAGE_SWAGGER flag is *not* present + }, } - }) + ) SwaggerEditorMock.is_valid = Mock() self.plugin.editor = SwaggerEditorMock @@ -909,24 +703,17 @@ def test_must_skip_if_api_manage_swagger_flag_is_not_present(self, SwaggerEditor self.plugin._add_api_to_swagger(event_id, properties, template_mock) SwaggerEditorMock.is_valid.assert_called_with(original_swagger) - template_mock.get.assert_called_with('restid') + template_mock.get.assert_called_with("restid") SwaggerEditorMock.assert_not_called() template_mock.set.assert_not_called() -class 
TestImplicitRestApiPlugin_maybe_remove_implicit_api(TestCase): +class TestImplicitRestApiPlugin_maybe_remove_implicit_api(TestCase): def setUp(self): self.plugin = ImplicitRestApiPlugin() def test_must_remove_if_no_path_present(self): - resource = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": { - "paths": {} - } - } - }) + resource = SamResource({"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"paths": {}}}}) template = Mock() template.get = Mock() template.delete = Mock() @@ -937,14 +724,9 @@ def test_must_remove_if_no_path_present(self): template.delete.assert_called_with(IMPLICIT_API_LOGICAL_ID) def test_must_skip_if_path_present(self): - resource = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": { - "paths": {"a": "b"} - } - } - }) + resource = SamResource( + {"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"paths": {"a": "b"}}}} + ) template = Mock() template.get = Mock() template.delete = Mock() @@ -957,14 +739,7 @@ def test_must_skip_if_path_present(self): template.delete.assert_not_called() def test_must_restore_if_existing_resource_present(self): - resource = SamResource({ - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": { - "paths": {} - } - } - }) + resource = SamResource({"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"paths": {}}}}) template = Mock() template.get = Mock() template.set = Mock() diff --git a/tests/plugins/application/test_serverless_app_plugin.py b/tests/plugins/application/test_serverless_app_plugin.py index e8637095f0..3dd4449c47 100644 --- a/tests/plugins/application/test_serverless_app_plugin.py +++ b/tests/plugins/application/test_serverless_app_plugin.py @@ -10,56 +10,54 @@ # TODO: run tests when AWS CLI is not configured (so they can run in brazil) -MOCK_TEMPLATE_URL = 'https://awsserverlessrepo-changesets-xxx.s3.amazonaws.com/pre-signed-url' -MOCK_TEMPLATE_ID = 'id-xx-xx' -STATUS_ACTIVE = 'ACTIVE' -STATUS_PREPARING = 'PREPARING' -STATUS_EXPIRED = 'EXPIRED' +MOCK_TEMPLATE_URL = "https://awsserverlessrepo-changesets-xxx.s3.amazonaws.com/pre-signed-url" +MOCK_TEMPLATE_ID = "id-xx-xx" +STATUS_ACTIVE = "ACTIVE" +STATUS_PREPARING = "PREPARING" +STATUS_EXPIRED = "EXPIRED" + def mock_create_cloud_formation_template(ApplicationId=None, SemanticVersion=None): message = { - 'ApplicationId': ApplicationId, - 'SemanticVersion': SemanticVersion, - 'Status': STATUS_ACTIVE, - 'TemplateId': MOCK_TEMPLATE_ID, - 'TemplateUrl': MOCK_TEMPLATE_URL + "ApplicationId": ApplicationId, + "SemanticVersion": SemanticVersion, + "Status": STATUS_ACTIVE, + "TemplateId": MOCK_TEMPLATE_ID, + "TemplateUrl": MOCK_TEMPLATE_URL, } return message def mock_get_application(ApplicationId=None, SemanticVersion=None): message = { - 'ApplicationId': ApplicationId, - 'Author': 'AWS', - 'Description': 'Application description', - 'Name': 'application-name', - 'ParameterDefinitions': [{ - 'Name': 'Parameter1', - 'ReferencedByResources': ['resource1'], - 'Type': 'String' - }], - 'SemanticVersion': SemanticVersion + "ApplicationId": ApplicationId, + "Author": "AWS", + "Description": "Application description", + "Name": "application-name", + "ParameterDefinitions": [{"Name": "Parameter1", "ReferencedByResources": ["resource1"], "Type": "String"}], + "SemanticVersion": SemanticVersion, } return message def mock_get_cloud_formation_template(ApplicationId=None, TemplateId=None): message = { - 'ApplicationId': ApplicationId, - 'SemanticVersion': '1.0.0', 
- 'Status': STATUS_ACTIVE, - 'TemplateId': TemplateId, - 'TemplateUrl': MOCK_TEMPLATE_URL + "ApplicationId": ApplicationId, + "SemanticVersion": "1.0.0", + "Status": STATUS_ACTIVE, + "TemplateId": TemplateId, + "TemplateUrl": MOCK_TEMPLATE_URL, } return message + def mock_get_region(self, service_name, region_name): - return 'us-east-1' + return "us-east-1" -class TestServerlessAppPlugin_init(TestCase): +class TestServerlessAppPlugin_init(TestCase): def setUp(self): - client = boto3.client('serverlessrepo', region_name='us-east-1') + client = boto3.client("serverlessrepo", region_name="us-east-1") self.plugin = ServerlessAppPlugin(sar_client=client) def test_plugin_must_setup_correct_name(self): @@ -67,53 +65,55 @@ def test_plugin_must_setup_correct_name(self): expected_name = "ServerlessAppPlugin" self.assertEqual(self.plugin.name, expected_name) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_plugin_default_values(self): self.assertEqual(self.plugin._wait_for_template_active_status, False) self.assertEqual(self.plugin._validate_only, False) self.assertTrue(self.plugin._sar_client is not None) # For some reason, `isinstance` or comparing classes did not work here - self.assertEqual(str(self.plugin._sar_client.__class__), str(boto3.client('serverlessrepo').__class__)) + self.assertEqual(str(self.plugin._sar_client.__class__), str(boto3.client("serverlessrepo").__class__)) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_plugin_accepts_different_sar_client(self): - client = boto3.client('serverlessrepo', endpoint_url = 'https://example.com') + client = boto3.client("serverlessrepo", endpoint_url="https://example.com") self.plugin = ServerlessAppPlugin(sar_client=client) self.assertEqual(self.plugin._sar_client, client) self.assertEqual(self.plugin._sar_client._endpoint, client._endpoint) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_plugin_accepts_flags(self): self.plugin = ServerlessAppPlugin(wait_for_template_active_status=True) self.assertEqual(self.plugin._wait_for_template_active_status, True) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_plugin_invalid_configuration_raises_exception(self): with self.assertRaises(InvalidPluginException): plugin = ServerlessAppPlugin(wait_for_template_active_status=True, validate_only=True) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_plugin_accepts_parameters(self): - parameters = {"a":"b"} + parameters = {"a": "b"} self.plugin = ServerlessAppPlugin(parameters=parameters) self.assertEqual(self.plugin._parameters, parameters) class TestServerlessAppPlugin_on_before_transform_template_translate(TestCase): - - def setUp(self): - client = boto3.client('serverlessrepo', region_name='us-east-1') + client = boto3.client("serverlessrepo", region_name="us-east-1") self.plugin = ServerlessAppPlugin(sar_client=client) 
@patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") - @patch('botocore.client.BaseClient._make_api_call', mock_create_cloud_formation_template) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.BaseClient._make_api_call", mock_create_cloud_formation_template) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_must_process_applications(self, SamTemplateMock): - self.plugin = ServerlessAppPlugin(sar_client=boto3.client('serverlessrepo')) + self.plugin = ServerlessAppPlugin(sar_client=boto3.client("serverlessrepo")) template_dict = {"a": "b"} - app_resources = [("id1", ApplicationResource(app_id = 'id1')), ("id2", ApplicationResource(app_id='id2')), ("id3", ApplicationResource())] + app_resources = [ + ("id1", ApplicationResource(app_id="id1")), + ("id2", ApplicationResource(app_id="id2")), + ("id3", ApplicationResource()), + ] sam_template = Mock() SamTemplateMock.return_value = sam_template @@ -127,15 +127,18 @@ def test_must_process_applications(self, SamTemplateMock): # Make sure this is called only for Apis sam_template.iterate.assert_called_with("AWS::Serverless::Application") - @patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") - @patch('botocore.client.BaseClient._make_api_call', mock_get_application) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.BaseClient._make_api_call", mock_get_application) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_must_process_applications_validate(self, SamTemplateMock): self.plugin = ServerlessAppPlugin(validate_only=True) template_dict = {"a": "b"} - app_resources = [("id1", ApplicationResource(app_id = 'id1')), ("id2", ApplicationResource(app_id='id2')), ("id3", ApplicationResource())] + app_resources = [ + ("id1", ApplicationResource(app_id="id1")), + ("id2", ApplicationResource(app_id="id2")), + ("id3", ApplicationResource()), + ] sam_template = Mock() SamTemplateMock.return_value = sam_template @@ -148,15 +151,18 @@ def test_must_process_applications_validate(self, SamTemplateMock): # Make sure this is called only for Apis sam_template.iterate.assert_called_with("AWS::Serverless::Application") - @patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") - @patch('botocore.client.BaseClient._make_api_call', mock_create_cloud_formation_template) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.BaseClient._make_api_call", mock_create_cloud_formation_template) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_process_invalid_applications(self, SamTemplateMock): - self.plugin = ServerlessAppPlugin(sar_client=boto3.client('serverlessrepo', region_name='us-east-1')) + self.plugin = ServerlessAppPlugin(sar_client=boto3.client("serverlessrepo", region_name="us-east-1")) template_dict = {"a": "b"} - app_resources = [("id1", ApplicationResource(app_id = '')), ("id2", ApplicationResource(app_id=None)), ("id3", ApplicationResource(app_id='id3', semver=None))] + app_resources = [ + ("id1", ApplicationResource(app_id="")), + ("id2", ApplicationResource(app_id=None)), + ("id3", ApplicationResource(app_id="id3", semver=None)), + ] sam_template = Mock() SamTemplateMock.return_value = sam_template @@ -170,14 +176,17 @@ def test_process_invalid_applications(self, 
SamTemplateMock): # Make sure this is called only for Apis sam_template.iterate.assert_called_with("AWS::Serverless::Application") - @patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") - @patch('botocore.client.BaseClient._make_api_call', mock_get_application) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.BaseClient._make_api_call", mock_get_application) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_process_invalid_applications_validate(self, SamTemplateMock): self.plugin = ServerlessAppPlugin(validate_only=True) template_dict = {"a": "b"} - app_resources = [("id1", ApplicationResource(app_id = '')), ("id2", ApplicationResource(app_id=None)), ("id3", ApplicationResource(app_id='id3', semver=None))] + app_resources = [ + ("id1", ApplicationResource(app_id="")), + ("id2", ApplicationResource(app_id=None)), + ("id3", ApplicationResource(app_id="id3", semver=None)), + ] sam_template = Mock() SamTemplateMock.return_value = sam_template @@ -191,27 +200,19 @@ def test_process_invalid_applications_validate(self, SamTemplateMock): # Make sure this is called only for Apis sam_template.iterate.assert_called_with("AWS::Serverless::Application") - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_sar_service_calls(self): service_call_lambda = mock_get_application - logical_id = 'logical_id' - app_id = 'app_id' - semver = '1.0.0' + logical_id = "logical_id" + app_id = "app_id" + semver = "1.0.0" response = self.plugin._sar_service_call(service_call_lambda, logical_id, app_id, semver) - self.assertEqual(app_id, response['ApplicationId']) + self.assertEqual(app_id, response["ApplicationId"]) def test_resolve_intrinsics(self): self.plugin = ServerlessAppPlugin(parameters={"AWS::Region": "us-east-1"}) - mappings = { - "MapA":{ - "us-east-1": { - "SecondLevelKey1": "value1" - } - } - } - input = { - "Fn::FindInMap": ["MapA", {"Ref": "AWS::Region"}, "SecondLevelKey1"] - } + mappings = {"MapA": {"us-east-1": {"SecondLevelKey1": "value1"}}} + input = {"Fn::FindInMap": ["MapA", {"Ref": "AWS::Region"}, "SecondLevelKey1"]} intrinsic_resolvers = self.plugin._get_intrinsic_resolvers(mappings) output = self.plugin._resolve_location_value(input, intrinsic_resolvers) @@ -219,47 +220,38 @@ def test_resolve_intrinsics(self): class ApplicationResource(object): - def __init__(self, app_id='app_id', semver='1.3.5'): - self.properties = { - 'ApplicationId': app_id, - 'SemanticVersion': semver - } - - - - - - + def __init__(self, app_id="app_id", semver="1.3.5"): + self.properties = {"ApplicationId": app_id, "SemanticVersion": semver} -#class TestServerlessAppPlugin_on_before_transform_resource(TestCase): +# class TestServerlessAppPlugin_on_before_transform_resource(TestCase): # def setUp(self): # self.plugin = ServerlessAppPlugin() - # TODO: test this lifecycle event - - # @parameterized.expand( - # itertools.product([ - # ServerlessAppPlugin(), - # ServerlessAppPlugin(wait_for_template_active_status=True), - # ]) - # ) - # @patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") - # @patch('botocore.client.BaseClient._make_api_call', mock_create_cloud_formation_template) - # def test_process_invalid_applications(self, plugin, SamTemplateMock): - # self.plugin = plugin - # template_dict = {"a": "b"} - # app_resources = [("id1", 
ApplicationResource(app_id = '')), ("id2", ApplicationResource(app_id=None))] - - # sam_template = Mock() - # SamTemplateMock.return_value = sam_template - # sam_template.iterate = Mock() - # sam_template.iterate.return_value = app_resources - - # self.plugin.on_before_transform_template(template_dict) - - # self.plugin.on_before_transform_resource(app_resources[0][0], 'AWS::Serverless::Application', app_resources[0][1].properties) +# TODO: test this lifecycle event + +# @parameterized.expand( +# itertools.product([ +# ServerlessAppPlugin(), +# ServerlessAppPlugin(wait_for_template_active_status=True), +# ]) +# ) +# @patch("samtranslator.plugins.application.serverless_app_plugin.SamTemplate") +# @patch('botocore.client.BaseClient._make_api_call', mock_create_cloud_formation_template) +# def test_process_invalid_applications(self, plugin, SamTemplateMock): +# self.plugin = plugin +# template_dict = {"a": "b"} +# app_resources = [("id1", ApplicationResource(app_id = '')), ("id2", ApplicationResource(app_id=None))] + +# sam_template = Mock() +# SamTemplateMock.return_value = sam_template +# sam_template.iterate = Mock() +# sam_template.iterate.return_value = app_resources + +# self.plugin.on_before_transform_template(template_dict) + +# self.plugin.on_before_transform_resource(app_resources[0][0], 'AWS::Serverless::Application', app_resources[0][1].properties) # class TestServerlessAppPlugin_on_after_transform_template(TestCase): diff --git a/tests/plugins/globals/test_globals.py b/tests/plugins/globals/test_globals.py index 9846e82957..eef6123ec5 100644 --- a/tests/plugins/globals/test_globals.py +++ b/tests/plugins/globals/test_globals.py @@ -5,387 +5,178 @@ from samtranslator.plugins.globals.globals import GlobalProperties, Globals, InvalidGlobalsSectionException + class GlobalPropertiesTestCases(object): dict_with_single_level_should_be_merged = { - "global": { - "a": 1, - "b": 2 - }, - "local": { - "a": "foo", - "c": 3, - "d": 4 - }, - "expected_output": { - "a": "foo", - "b": 2, - "c": 3, - "d": 4 - } + "global": {"a": 1, "b": 2}, + "local": {"a": "foo", "c": 3, "d": 4}, + "expected_output": {"a": "foo", "b": 2, "c": 3, "d": 4}, } dict_keys_are_case_sensitive = { - "global": { - "banana": "is tasty" - }, - "local": { - "BaNaNa": "is not tasty" - }, - "expected_output": { - "banana": "is tasty", - "BaNaNa": "is not tasty" - } + "global": {"banana": "is tasty"}, + "local": {"BaNaNa": "is not tasty"}, + "expected_output": {"banana": "is tasty", "BaNaNa": "is not tasty"}, } - dict_properties_with_different_types_must_be_overridden_str_and_dict = { - "global": { - "key": "foo" - }, - "local": { - "key": {"a": "b"} - }, - "expected_output": { - "key": {"a": "b"} - } + "global": {"key": "foo"}, + "local": {"key": {"a": "b"}}, + "expected_output": {"key": {"a": "b"}}, } dict_properties_with_different_types_must_be_overridden_boolean_and_int = { - "global": { - "key": True - }, - "local": { - "key": 1 - }, - "expected_output": { - "key": 1 - } + "global": {"key": True}, + "local": {"key": 1}, + "expected_output": {"key": 1}, } dict_properties_with_different_types_must_be_overridden_dict_and_array = { - "global": { - "key": {"a": "b"} - }, - "local": { - "key": ["a"] - }, - "expected_output": { - "key": ["a"] - } + "global": {"key": {"a": "b"}}, + "local": {"key": ["a"]}, + "expected_output": {"key": ["a"]}, } - dict_with_empty_local_must_merge = { - "global": { - "a": "b" - }, - - "local": {}, - - "expected_output": { - "a": "b" - } - } + dict_with_empty_local_must_merge = {"global": {"a": "b"}, 
"local": {}, "expected_output": {"a": "b"}} nested_dict_keys_should_be_merged = { - "global": { - "key1": { - "key2": { - "key3": { - "key4": "value" - } - } - } - }, - "local": { - "key1": { - "key2": { - "key3": { - "key4": "local value" - }, - }, - } - }, - "expected_output": { - "key1": { - "key2": { - "key3": { - "key4": "local value" - }, - }, - } - } + "global": {"key1": {"key2": {"key3": {"key4": "value"}}}}, + "local": {"key1": {"key2": {"key3": {"key4": "local value"}}}}, + "expected_output": {"key1": {"key2": {"key3": {"key4": "local value"}}}}, } nested_dict_with_different_levels_should_be_merged = { - "global": { - "key1": { - "key2": { - "key3": "value3" - }, - "globalOnlyKey": "global value" - } - }, - "local": { - "key1": { - "key2": "foo", - "localOnlyKey": "local value" - } - }, + "global": {"key1": {"key2": {"key3": "value3"}, "globalOnlyKey": "global value"}}, + "local": {"key1": {"key2": "foo", "localOnlyKey": "local value"}}, "expected_output": { "key1": { # Key2 does not recurse any further "key2": "foo", "globalOnlyKey": "global value", - "localOnlyKey": "local value" + "localOnlyKey": "local value", } - } + }, } nested_dicts_with_non_overridden_keys_should_be_copied = { - "global": { - "key1": { - "key2": { - "key3": { - "key4": "value" - } - }, - "globalOnly": { - "globalOnlyKey": "globalOnlyValue" - } - } - }, + "global": {"key1": {"key2": {"key3": {"key4": "value"}}, "globalOnly": {"globalOnlyKey": "globalOnlyValue"}}}, "local": { "key1": { - "key2": { - "key3": { - "localkey4": "other value 4" - }, - "localkey3": "other value 3" - }, - + "key2": {"key3": {"localkey4": "other value 4"}, "localkey3": "other value 3"}, "localkey2": "other value 2", } }, "expected_output": { "key1": { - "key2": { - "key3": { - "key4": "value", - "localkey4": "other value 4" - }, - "localkey3": "other value 3" - }, - + "key2": {"key3": {"key4": "value", "localkey4": "other value 4"}, "localkey3": "other value 3"}, "localkey2": "other value 2", - "globalOnly": { - "globalOnlyKey": "globalOnlyValue" - } + "globalOnly": {"globalOnlyKey": "globalOnlyValue"}, } - } + }, } arrays_with_mutually_exclusive_elements_must_be_concatenated = { "global": [1, 2, 3], "local": [11, 12, 13], - "expected_output": [ - 1,2,3, - 11,12,13 - ] + "expected_output": [1, 2, 3, 11, 12, 13], } arrays_with_duplicate_elements_must_be_concatenated = { "global": ["a", "b", "c", "z"], "local": ["a", "b", "x", "y", "z"], - "expected_output": [ - "a", "b", "c", "z", - "a", "b", "x", "y", "z" - ] + "expected_output": ["a", "b", "c", "z", "a", "b", "x", "y", "z"], } arrays_with_nested_dict_must_be_concatenated = { "global": [{"a": 1}, {"b": 2}], "local": [{"x": 1}, {"y": 2}], - "expected_output": [ - {"a": 1}, {"b": 2}, - {"x": 1}, {"y": 2} - ] + "expected_output": [{"a": 1}, {"b": 2}, {"x": 1}, {"y": 2}], } arrays_with_mixed_element_types_must_be_concatenated = { "global": [1, 2, "foo", True, {"x": "y"}, ["nested", "array"]], "local": [False, 9, 8, "bar"], - "expected_output": [ - 1, 2, "foo", True, {"x": "y"}, ["nested", "array"], - False, 9, 8, "bar" - ] + "expected_output": [1, 2, "foo", True, {"x": "y"}, ["nested", "array"], False, 9, 8, "bar"], } arrays_with_exactly_same_values_must_be_concatenated = { "global": [{"a": 1}, {"b": 2}, "foo", 1, 2, True, False], "local": [{"a": 1}, {"b": 2}, "foo", 1, 2, True, False], - "expected_output": [ - {"a": 1}, {"b": 2}, "foo", 1, 2, True, False, - {"a": 1}, {"b": 2}, "foo", 1, 2, True, False - ] + "expected_output": [{"a": 1}, {"b": 2}, "foo", 1, 2, True, False, {"a": 1}, 
{"b": 2}, "foo", 1, 2, True, False], } # Arrays are concatenated. Other keys in dictionary are merged nested_dict_with_array_values_must_be_merged_and_concatenated = { - "global": { - "key": "global value", - "nested": { - "array_key": [1, 2, 3], - }, - "globalOnlyKey": "global value" - }, - "local": { - "key": "local value", - "nested": { - "array_key": [8, 9, 10], - }, - "localOnlyKey": "local value" - }, + "global": {"key": "global value", "nested": {"array_key": [1, 2, 3]}, "globalOnlyKey": "global value"}, + "local": {"key": "local value", "nested": {"array_key": [8, 9, 10]}, "localOnlyKey": "local value"}, "expected_output": { "key": "local value", - "nested": { - "array_key": [ - 1, 2, 3, - 8, 9, 10 - ], - }, + "nested": {"array_key": [1, 2, 3, 8, 9, 10]}, "globalOnlyKey": "global value", - "localOnlyKey": "local value" - } + "localOnlyKey": "local value", + }, } intrinsic_function_must_be_overridden = { - "global": { - "Ref": "foo" - }, - "local": { - "Fn::Spooky": "bar" - }, - "expected_output": { - "Fn::Spooky": "bar" - } + "global": {"Ref": "foo"}, + "local": {"Fn::Spooky": "bar"}, + "expected_output": {"Fn::Spooky": "bar"}, } intrinsic_function_in_global_must_override_dict_value_in_local = { - "global": { - "Ref": "foo" - }, - "local": { - "a": "b" - }, - "expected_output": { - "a": "b" - } + "global": {"Ref": "foo"}, + "local": {"a": "b"}, + "expected_output": {"a": "b"}, } intrinsic_function_in_local_must_override_dict_value_in_global = { - "global": { - "a": "b" - }, - "local": { - "Fn::Something": "value" - }, - "expected_output": { - "Fn::Something": "value" - } + "global": {"a": "b"}, + "local": {"Fn::Something": "value"}, + "expected_output": {"Fn::Something": "value"}, } intrinsic_function_in_nested_dict_must_be_overridden = { - "global": { - "key1": { - "key2": { - "key3": { - "Ref": "foo" - }, - "globalOnlyKey": "global value" - } - } - }, - - "local": { - "key1": { - "key2": { - "key3": { - "Fn::Something": "New value" - } - }, - } - }, - + "global": {"key1": {"key2": {"key3": {"Ref": "foo"}, "globalOnlyKey": "global value"}}}, + "local": {"key1": {"key2": {"key3": {"Fn::Something": "New value"}}}}, "expected_output": { - "key1": { - "key2": { - "key3": { - "Fn::Something": "New value" - }, - "globalOnlyKey": "global value" - }, - } - } + "key1": {"key2": {"key3": {"Fn::Something": "New value"}, "globalOnlyKey": "global value"}} + }, } invalid_intrinsic_function_dict_must_be_merged = { "global": { # This is not an intrinsic function because the dict contains two keys - "Ref": "foo", - "key": "global value" - }, - - "local": { - "Fn::Something": "bar", - "other": "local value" - }, - - "expected_output": { "Ref": "foo", "key": "global value", - "Fn::Something": "bar", - "other": "local value" - } + }, + "local": {"Fn::Something": "bar", "other": "local value"}, + "expected_output": {"Ref": "foo", "key": "global value", "Fn::Something": "bar", "other": "local value"}, } intrinsic_function_in_local_must_override_invalid_intrinsic_in_global = { "global": { # This is not an intrinsic function because the dict contains two keys "Ref": "foo", - "key": "global value" + "key": "global value", }, - "local": { # This is an intrinsic function which essentially resolves to a primitive type. # So local is primitive type whereas global is a dictionary. 
Prefer local "Fn::Something": "bar" }, - - "expected_output": { - "Fn::Something": "bar" - } + "expected_output": {"Fn::Something": "bar"}, } - primitive_type_inputs_must_be_handled = { - "global": "input string", - "local": 123, - "expected_output": 123 - } + primitive_type_inputs_must_be_handled = {"global": "input string", "local": 123, "expected_output": 123} - mixed_type_inputs_must_be_handled = { - "global": {"a": "b"}, - "local": [1, 2, 3], - "expected_output": [1, 2, 3] - } + mixed_type_inputs_must_be_handled = {"global": {"a": "b"}, "local": [1, 2, 3], "expected_output": [1, 2, 3]} class TestGlobalPropertiesMerge(TestCase): # Get all attributes of the test case object which is not a built-in method like __str__ - @parameterized.expand([d for d in dir(GlobalPropertiesTestCases) - if not d.startswith("__") - ]) + @parameterized.expand([d for d in dir(GlobalPropertiesTestCases) if not d.startswith("__")]) def test_global_properties_merge(self, testcase): configuration = getattr(GlobalPropertiesTestCases, testcase) @@ -397,8 +188,8 @@ def test_global_properties_merge(self, testcase): self.assertEqual(actual, configuration["expected_output"]) -class TestGlobalsPropertiesEdgeCases(TestCase): +class TestGlobalsPropertiesEdgeCases(TestCase): @patch.object(GlobalProperties, "_token_of") def test_merge_with_objects_of_unsupported_token_type(self, token_of_mock): @@ -409,12 +200,12 @@ def test_merge_with_objects_of_unsupported_token_type(self, token_of_mock): # Raise type error because token type is invalid properties.merge("local value") -class TestGlobalsObject(TestCase): +class TestGlobalsObject(TestCase): def setUp(self): self._originals = { "resource_prefix": Globals._RESOURCE_PREFIX, - "supported_properties": Globals.supported_properties + "supported_properties": Globals.supported_properties, } Globals._RESOURCE_PREFIX = "prefix_" Globals.supported_properties = { @@ -424,14 +215,8 @@ def setUp(self): self.template = { "Globals": { - "type1": { - "prop1": "value1", - "prop2": "value2" - }, - "type2": { - "otherprop1": "value1", - "otherprop2": "value2" - } + "type1": {"prop1": "value1", "prop2": "value2"}, + "type2": {"otherprop1": "value1", "otherprop2": "value2"}, } } @@ -449,63 +234,37 @@ def test_parse_should_parse_all_known_resource_types(self): self.assertTrue("prefix_type2" in parsed_globals) self.assertEqual(self.template["Globals"]["type2"], parsed_globals["prefix_type2"].global_properties) - def test_parse_should_error_if_globals_is_not_dict(self): - template = { - "Globals": "hello" - } + template = {"Globals": "hello"} with self.assertRaises(InvalidGlobalsSectionException): Globals(template) def test_parse_should_error_if_globals_contains_unknown_types(self): - template = { - "Globals": { - "random_type": { - "key": "value" - }, - "type1": { - "key": "value" - } - } - } + template = {"Globals": {"random_type": {"key": "value"}, "type1": {"key": "value"}}} with self.assertRaises(InvalidGlobalsSectionException): Globals(template) def test_parse_should_error_if_globals_contains_unknown_properties_of_known_type(self): - template = { - "Globals": { - "type1": { - "unknown_property": "value" - } - } - } + template = {"Globals": {"type1": {"unknown_property": "value"}}} with self.assertRaises(InvalidGlobalsSectionException): Globals(template) def test_parse_should_error_if_value_is_not_dictionary(self): - template = { - "Globals": { - "type1": "string value" - } - } + template = {"Globals": {"type1": "string value"}} with self.assertRaises(InvalidGlobalsSectionException): 
Globals(template) def test_parse_should_not_error_if_value_is_empty(self): - template = { - "Globals": { - "type1": {} # empty value - } - } + template = {"Globals": {"type1": {}}} # empty value globals = Globals(template) parsed = globals._parse(template["Globals"]) @@ -515,9 +274,7 @@ def test_parse_should_not_error_if_value_is_empty(self): def test_init_without_globals_section_in_template(self): - template = { - "a": "b" - } + template = {"a": "b"} global_obj = Globals(template) self.assertEqual({}, global_obj.template_globals) @@ -530,13 +287,9 @@ def test_del_section_with_globals_section_in_template(self): self.assertEqual(expected, template) def test_del_section_with_no_globals_section_in_template(self): - template = { - "a": "b" - } + template = {"a": "b"} - expected = { - "a": "b" - } + expected = {"a": "b"} Globals.del_section(template) self.assertEqual(expected, template) @@ -546,10 +299,7 @@ def test_merge_must_actually_do_merge(self, parse_mock): type1_mock = Mock() type2_mock = Mock() - parse_mock.return_value = { - "type1": type1_mock, - "type2": type2_mock, - } + parse_mock.return_value = {"type1": type1_mock, "type2": type2_mock} local_properties = {"a": "b"} expected = "some merged value" @@ -567,9 +317,7 @@ def test_merge_must_actually_do_merge(self, parse_mock): def test_merge_must_skip_unsupported_types(self, parse_mock): type1_mock = Mock() - parse_mock.return_value = { - "type1": type1_mock - } + parse_mock.return_value = {"type1": type1_mock} local_properties = {"a": "b"} expected = {"a": "b"} @@ -585,8 +333,7 @@ def test_merge_must_skip_unsupported_types(self, parse_mock): @patch.object(Globals, "_parse") def test_merge_must_skip_with_no_types(self, parse_mock): - parse_mock.return_value = { - } + parse_mock.return_value = {} local_properties = {"a": "b"} expected = {"a": "b"} @@ -601,18 +348,9 @@ def test_merge_must_skip_with_no_types(self, parse_mock): def test_merge_end_to_end_on_known_type1(self): type = "prefix_type1" - properties = { - "prop1": "overridden value", - "a": "b", - "key": [1,2,3] - } + properties = {"prop1": "overridden value", "a": "b", "key": [1, 2, 3]} - expected = { - "prop1": "overridden value", - "prop2": "value2", # inherited from global - "a": "b", - "key": [1,2,3] - } + expected = {"prop1": "overridden value", "prop2": "value2", "a": "b", "key": [1, 2, 3]} # inherited from global globals = Globals(self.template) result = globals.merge(type, properties) @@ -622,16 +360,13 @@ def test_merge_end_to_end_on_known_type1(self): def test_merge_end_to_end_on_known_type2(self): type = "prefix_type2" - properties = { - "a": "b", - "key": [1,2,3] - } + properties = {"a": "b", "key": [1, 2, 3]} expected = { "otherprop1": "value1", # inherited from global "otherprop2": "value2", # inherited from global "a": "b", - "key": [1,2,3] + "key": [1, 2, 3], } globals = Globals(self.template) @@ -642,30 +377,19 @@ def test_merge_end_to_end_on_known_type2(self): def test_merge_end_to_end_unknown_type(self): type = "some unknown type" - properties = { - "a": "b", - "key": [1,2,3] - } + properties = {"a": "b", "key": [1, 2, 3]} # Output equals input - expected = { - "a": "b", - "key": [1,2,3] - } + expected = {"a": "b", "key": [1, 2, 3]} globals = Globals(self.template) result = globals.merge(type, properties) self.assertEqual(expected, result) + class TestGlobalsOpenApi(TestCase): - template = { - "Globals": { - "Api": { - "OpenApiVersion": "3.0" - } - } - } + template = {"Globals": {"Api": {"OpenApiVersion": "3.0"}}} table_driven = [ { @@ -677,10 +401,8 @@ class 
TestGlobalsOpenApi(TestCase): "Properties": { "__MANAGE_SWAGGER": True, "OpenApiVersion": "3.0", - "DefinitionBody": { - "swagger": "2.0" - } - } + "DefinitionBody": {"swagger": "2.0"}, + }, } } }, @@ -691,40 +413,24 @@ class TestGlobalsOpenApi(TestCase): "Properties": { "__MANAGE_SWAGGER": True, "OpenApiVersion": "3.0", - "DefinitionBody": { - "openapi": "3.0" - } - } + "DefinitionBody": {"openapi": "3.0"}, + }, } } - } + }, }, { "name": "no openapi", "input": { "Resources": { - "MyApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": { - "swagger": "2.0" - } - } - } + "MyApi": {"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"swagger": "2.0"}}} } }, "expected": { "Resources": { - "MyApi": { - "Type": "AWS::Serverless::Api", - "Properties": { - "DefinitionBody": { - "swagger": "2.0" - } - } - } + "MyApi": {"Type": "AWS::Serverless::Api", "Properties": {"DefinitionBody": {"swagger": "2.0"}}} } - } + }, }, { "name": "Openapi set to 2.0", @@ -735,10 +441,8 @@ class TestGlobalsOpenApi(TestCase): "Properties": { "__MANAGE_SWAGGER": True, "OpenApiVersion": "2.0", - "DefinitionBody": { - "swagger": "2.0" - } - } + "DefinitionBody": {"swagger": "2.0"}, + }, } } }, @@ -749,13 +453,11 @@ class TestGlobalsOpenApi(TestCase): "Properties": { "__MANAGE_SWAGGER": True, "OpenApiVersion": "2.0", - "DefinitionBody": { - "swagger": "2.0" - } - } + "DefinitionBody": {"swagger": "2.0"}, + }, } } - } + }, }, { "name": "No definition body", @@ -763,10 +465,7 @@ class TestGlobalsOpenApi(TestCase): "Resources": { "MyApi": { "Type": "AWS::Serverless::Api", - "Properties": { - "__MANAGE_SWAGGER": True, - "OpenApiVersion": "3.0" - } + "Properties": {"__MANAGE_SWAGGER": True, "OpenApiVersion": "3.0"}, } } }, @@ -774,13 +473,10 @@ class TestGlobalsOpenApi(TestCase): "Resources": { "MyApi": { "Type": "AWS::Serverless::Api", - "Properties": { - "__MANAGE_SWAGGER": True, - "OpenApiVersion": "3.0" - } + "Properties": {"__MANAGE_SWAGGER": True, "OpenApiVersion": "3.0"}, } } - } + }, }, { "name": "ignore customer defined swagger", @@ -788,12 +484,7 @@ class TestGlobalsOpenApi(TestCase): "Resources": { "MyApi": { "Type": "AWS::Serverless::Api", - "Properties": { - "OpenApiVersion": "3.0", - "DefinitionBody": { - "swagger": "2.0" - } - } + "Properties": {"OpenApiVersion": "3.0", "DefinitionBody": {"swagger": "2.0"}}, } } }, @@ -801,27 +492,16 @@ class TestGlobalsOpenApi(TestCase): "Resources": { "MyApi": { "Type": "AWS::Serverless::Api", - "Properties": { - "OpenApiVersion": "3.0", - "DefinitionBody": { - "swagger": "2.0" - } - } + "Properties": {"OpenApiVersion": "3.0", "DefinitionBody": {"swagger": "2.0"}}, } } - } + }, }, { "name": "No Resources", - "input": { - "some": "other", - "swagger": "donottouch" - }, - "expected": { - "some": "other", - "swagger": "donottouch" - } - } + "input": {"some": "other", "swagger": "donottouch"}, + "expected": {"some": "other", "swagger": "donottouch"}, + }, ] def test_openapi_postprocess(self): diff --git a/tests/plugins/globals/test_globals_plugin.py b/tests/plugins/globals/test_globals_plugin.py index 354c9d1638..925211d018 100644 --- a/tests/plugins/globals/test_globals_plugin.py +++ b/tests/plugins/globals/test_globals_plugin.py @@ -6,6 +6,7 @@ from samtranslator.plugins.globals.globals_plugin import GlobalsPlugin from samtranslator.plugins.globals.globals import InvalidGlobalsSectionException + class TestGlobalsPlugin(TestCase): """ Unit testing Globals Plugin diff --git a/tests/plugins/policies/test_policy_templates_plugin.py 
b/tests/plugins/policies/test_policy_templates_plugin.py index e3d254b62c..9faad31d6c 100644 --- a/tests/plugins/policies/test_policy_templates_plugin.py +++ b/tests/plugins/policies/test_policy_templates_plugin.py @@ -9,7 +9,6 @@ class TestPolicyTemplatesForFunctionPlugin(TestCase): - def setUp(self): self._policy_template_processor_mock = Mock() self.plugin = PolicyTemplatesForFunctionPlugin(self._policy_template_processor_mock) @@ -41,19 +40,9 @@ def test_on_before_transform_resource_must_work_on_every_policy_template(self, f function_policies_class_mock.return_value = function_policies_obj_mock function_policies_class_mock.POLICIES_PROPERTY_NAME = "Policies" - template1 = { - "MyTemplate1": { - "Param1": "value1" - } - } - template2 = { - "MyTemplate2": { - "Param2": "value2" - } - } - resource_properties = { - "Policies": [template1, template2] - } + template1 = {"MyTemplate1": {"Param1": "value1"}} + template2 = {"MyTemplate2": {"Param2": "value2"}} + resource_properties = {"Policies": [template1, template2]} policies = [ PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE), @@ -76,10 +65,9 @@ def test_on_before_transform_resource_must_work_on_every_policy_template(self, f # This will overwrite the resource_properties input array self.assertEqual(expected, resource_properties["Policies"]) function_policies_obj_mock.get.assert_called_once_with() - self._policy_template_processor_mock.convert.assert_has_calls([ - call("MyTemplate1", {"Param1": "value1"}), - call("MyTemplate2", {"Param2": "value2"}) - ]) + self._policy_template_processor_mock.convert.assert_has_calls( + [call("MyTemplate1", {"Param1": "value1"}), call("MyTemplate2", {"Param2": "value2"})] + ) @patch("samtranslator.plugins.policies.policy_templates_plugin.FunctionPolicies") def test_on_before_transform_resource_must_skip_non_policy_templates(self, function_policies_class_mock): @@ -91,20 +79,10 @@ def test_on_before_transform_resource_must_skip_non_policy_templates(self, funct function_policies_class_mock.return_value = function_policies_obj_mock function_policies_class_mock.POLICIES_PROPERTY_NAME = "Policies" - template1 = { - "MyTemplate1": { - "Param1": "value1" - } - } - template2 = { - "MyTemplate2": { - "Param2": "value2" - } - } + template1 = {"MyTemplate1": {"Param1": "value1"}} + template2 = {"MyTemplate2": {"Param2": "value2"}} regular_policy = {"regular policy": "something"} - resource_properties = { - "Policies": [template1, regular_policy, template2] - } + resource_properties = {"Policies": [template1, regular_policy, template2]} policies = [ PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE), @@ -122,17 +100,19 @@ def test_on_before_transform_resource_must_skip_non_policy_templates(self, funct {"Statement2": {"key2": "value2"}}, ] - expected = [{"Statement1": {"key1": "value1"}}, {"regular policy": "something"}, {"Statement2": {"key2": "value2"}}] + expected = [ + {"Statement1": {"key1": "value1"}}, + {"regular policy": "something"}, + {"Statement2": {"key2": "value2"}}, + ] self.plugin.on_before_transform_resource("logicalId", "resource_type", resource_properties) # This will overwrite the resource_properties input array self.assertEqual(expected, resource_properties["Policies"]) function_policies_obj_mock.get.assert_called_once_with() - self._policy_template_processor_mock.convert.assert_has_calls([ - call("MyTemplate1", {"Param1": "value1"}), - call("MyTemplate2", {"Param2": "value2"}) - ]) - + self._policy_template_processor_mock.convert.assert_has_calls( + [call("MyTemplate1", 
{"Param1": "value1"}), call("MyTemplate2", {"Param2": "value2"})] + ) @patch("samtranslator.plugins.policies.policy_templates_plugin.FunctionPolicies") def test_on_before_transform_must_raise_on_insufficient_parameter_values(self, function_policies_class_mock): @@ -143,18 +123,10 @@ def test_on_before_transform_must_raise_on_insufficient_parameter_values(self, f function_policies_obj_mock = MagicMock() function_policies_class_mock.return_value = function_policies_obj_mock - template1 = { - "MyTemplate1": { - "Param1": "value1" - } - } - resource_properties = { - "Policies": template1 - } + template1 = {"MyTemplate1": {"Param1": "value1"}} + resource_properties = {"Policies": template1} - policies = [ - PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE) - ] + policies = [PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE)] # Setup to return all the policies function_policies_obj_mock.__len__.return_value = 1 @@ -167,7 +139,7 @@ def test_on_before_transform_must_raise_on_insufficient_parameter_values(self, f self.plugin.on_before_transform_resource("logicalId", "resource_type", resource_properties) # Make sure the input was not changed - self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": { "Param1": "value1"}}}) + self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": {"Param1": "value1"}}}) @patch("samtranslator.plugins.policies.policy_templates_plugin.FunctionPolicies") def test_on_before_transform_must_raise_on_invalid_parameter_values(self, function_policies_class_mock): @@ -178,18 +150,10 @@ def test_on_before_transform_must_raise_on_invalid_parameter_values(self, functi function_policies_obj_mock = MagicMock() function_policies_class_mock.return_value = function_policies_obj_mock - template1 = { - "MyTemplate1": { - "Param1": "value1" - } - } - resource_properties = { - "Policies": template1 - } + template1 = {"MyTemplate1": {"Param1": "value1"}} + resource_properties = {"Policies": template1} - policies = [ - PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE) - ] + policies = [PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE)] # Setup to return all the policies function_policies_obj_mock.__len__.return_value = 1 @@ -201,7 +165,7 @@ def test_on_before_transform_must_raise_on_invalid_parameter_values(self, functi self.plugin.on_before_transform_resource("logicalId", "resource_type", resource_properties) # Make sure the input was not changed - self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": { "Param1": "value1"}}}) + self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": {"Param1": "value1"}}}) @patch("samtranslator.plugins.policies.policy_templates_plugin.FunctionPolicies") def test_on_before_transform_must_bubble_exception(self, function_policies_class_mock): @@ -212,30 +176,22 @@ def test_on_before_transform_must_bubble_exception(self, function_policies_class function_policies_obj_mock = MagicMock() function_policies_class_mock.return_value = function_policies_obj_mock - template1 = { - "MyTemplate1": { - "Param1": "value1" - } - } - resource_properties = { - "Policies": template1 - } + template1 = {"MyTemplate1": {"Param1": "value1"}} + resource_properties = {"Policies": template1} - policies = [ - PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE) - ] + policies = [PolicyEntry(data=template1, type=PolicyTypes.POLICY_TEMPLATE)] # Setup to return all the policies function_policies_obj_mock.__len__.return_value = 1 function_policies_obj_mock.get.return_value = 
iter(policies) - self._policy_template_processor_mock.convert.side_effect = TypeError('message') + self._policy_template_processor_mock.convert.side_effect = TypeError("message") with self.assertRaises(TypeError): self.plugin.on_before_transform_resource("logicalId", "resource_type", resource_properties) # Make sure the input was not changed - self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": { "Param1": "value1"}}}) + self.assertEqual(resource_properties, {"Policies": {"MyTemplate1": {"Param1": "value1"}}}) def test_on_before_transform_resource_must_skip_unsupported_resources(self): diff --git a/tests/policy_template_processor/test_processor.py b/tests/policy_template_processor/test_processor.py index 83d6c460e0..9319080f44 100644 --- a/tests/policy_template_processor/test_processor.py +++ b/tests/policy_template_processor/test_processor.py @@ -9,13 +9,11 @@ from samtranslator.policy_template_processor.template import Template from samtranslator.policy_template_processor.exceptions import TemplateNotFoundException -class TestPolicyTemplateProcessor(TestCase): +class TestPolicyTemplateProcessor(TestCase): @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") def test_init_must_validate_against_default_schema(self, is_valid_templates_dict_mock): - policy_templates_dict = { - "Templates": {} - } + policy_templates_dict = {"Templates": {}} is_valid_templates_dict_mock.return_value = True @@ -24,9 +22,7 @@ def test_init_must_validate_against_default_schema(self, is_valid_templates_dict @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") def test_init_must_validate_against_input_schema(self, is_valid_templates_dict_mock): - policy_templates_dict = { - "Templates": {} - } + policy_templates_dict = {"Templates": {}} schema = "something" is_valid_templates_dict_mock.return_value = True @@ -36,9 +32,7 @@ def test_init_must_validate_against_input_schema(self, is_valid_templates_dict_m @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") def test_init_must_raise_on_invalid_template(self, is_valid_templates_dict_mock): - policy_templates_dict = { - "Templates": {} - } + policy_templates_dict = {"Templates": {}} is_valid_templates_dict_mock.side_effect = ValueError() with self.assertRaises(ValueError): @@ -46,13 +40,10 @@ def test_init_must_raise_on_invalid_template(self, is_valid_templates_dict_mock) @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") - def test_init_must_convert_template_value_dict_to_object(self, template_from_dict_mock, is_valid_templates_dict_mock): - policy_templates_dict = { - "Templates": { - "key1": "value1", - "key2": "value2" - } - } + def test_init_must_convert_template_value_dict_to_object( + self, template_from_dict_mock, is_valid_templates_dict_mock + ): + policy_templates_dict = {"Templates": {"key1": "value1", "key2": "value2"}} is_valid_templates_dict_mock.return_value = True template_from_dict_mock.return_value = "Something" @@ -70,12 +61,7 @@ def test_init_must_convert_template_value_dict_to_object(self, template_from_dic @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") def test_has_method_must_work_for_known_template_names(self, template_from_dict_mock, is_valid_templates_dict_mock): - policy_templates_dict = { - "Templates": { - "key1": "value1", - "key2": "value2" - } - } + policy_templates_dict = {"Templates": {"key1": "value1", "key2": "value2"}} processor = 
PolicyTemplatesProcessor(policy_templates_dict) @@ -85,12 +71,7 @@ def test_has_method_must_work_for_known_template_names(self, template_from_dict_ @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") def test_has_method_must_work_for_not_known_template_names(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1", - "key2": "value2" - } - } + policy_templates_dict = {"Templates": {"key1": "value1", "key2": "value2"}} processor = PolicyTemplatesProcessor(policy_templates_dict) @@ -98,12 +79,10 @@ def test_has_method_must_work_for_not_known_template_names(self, template_from_d @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") - def test_get_method_must_return_template_object_for_known_template_names(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1" - } - } + def test_get_method_must_return_template_object_for_known_template_names( + self, template_from_dict_mock, is_valid_templates_dict + ): + policy_templates_dict = {"Templates": {"key1": "value1"}} template_obj = "some value" template_from_dict_mock.return_value = template_obj @@ -114,12 +93,10 @@ def test_get_method_must_return_template_object_for_known_template_names(self, t @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") - def test_get_method_must_return_none_for_unknown_template_names(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1" - } - } + def test_get_method_must_return_none_for_unknown_template_names( + self, template_from_dict_mock, is_valid_templates_dict + ): + policy_templates_dict = {"Templates": {"key1": "value1"}} template_obj = "some value" template_from_dict_mock.return_value = template_obj @@ -131,14 +108,8 @@ def test_get_method_must_return_none_for_unknown_template_names(self, template_f @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") def test_convert_must_work_for_known_template_names(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1" - } - } - parameter_values = { - "a": "b" - } + policy_templates_dict = {"Templates": {"key1": "value1"}} + parameter_values = {"a": "b"} template_obj_mock = Mock() template_from_dict_mock.return_value = template_obj_mock @@ -155,14 +126,8 @@ def test_convert_must_work_for_known_template_names(self, template_from_dict_moc @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") def test_convert_must_raise_if_template_name_not_found(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1" - } - } - parameter_values = { - "a": "b" - } + policy_templates_dict = {"Templates": {"key1": "value1"}} + parameter_values = {"a": "b"} processor = PolicyTemplatesProcessor(policy_templates_dict) @@ -172,14 +137,8 @@ def test_convert_must_raise_if_template_name_not_found(self, template_from_dict_ @patch.object(PolicyTemplatesProcessor, "_is_valid_templates_dict") @patch.object(Template, "from_dict") def test_convert_must_bubble_exceptions(self, template_from_dict_mock, is_valid_templates_dict): - policy_templates_dict = { - "Templates": { - "key1": "value1" - } - } - parameter_values = { - "a": "b" - } 
+ policy_templates_dict = {"Templates": {"key1": "value1"}} + parameter_values = {"a": "b"} template_obj_mock = Mock() template_from_dict_mock.return_value = template_obj_mock @@ -194,9 +153,7 @@ def test_convert_must_bubble_exceptions(self, template_from_dict_mock, is_valid_ @patch.object(jsonschema, "validate") @patch.object(PolicyTemplatesProcessor, "_read_schema") def test_is_valid_templates_dict_must_use_default_schema(self, read_schema_mock, jsonschema_validate_mock): - policy_templates_dict = { - "key": "value" - } + policy_templates_dict = {"key": "value"} schema = "some schema" read_schema_mock.return_value = schema @@ -211,9 +168,7 @@ def test_is_valid_templates_dict_must_use_default_schema(self, read_schema_mock, @patch.object(jsonschema, "validate") @patch.object(PolicyTemplatesProcessor, "_read_schema") def test_is_valid_templates_dict_must_use_input_schema(self, read_schema_mock, jsonschema_validate_mock): - policy_templates_dict = { - "key": "value" - } + policy_templates_dict = {"key": "value"} schema = "some schema" jsonschema_validate_mock.return_value = True @@ -227,9 +182,7 @@ def test_is_valid_templates_dict_must_use_input_schema(self, read_schema_mock, j @patch.object(jsonschema, "validate") @patch.object(PolicyTemplatesProcessor, "_read_schema") def test_is_valid_templates_dict_must_raise_for_invalid_input(self, read_schema_mock, jsonschema_validate_mock): - policy_templates_dict = { - "key": "value" - } + policy_templates_dict = {"key": "value"} schema = "some schema" exception_msg = "exception" @@ -246,9 +199,7 @@ def test_is_valid_templates_dict_must_raise_for_invalid_input(self, read_schema_ @patch.object(jsonschema, "validate") @patch.object(PolicyTemplatesProcessor, "_read_schema") def test_is_valid_templates_dict_must_bubble_unhandled_exceptions(self, read_schema_mock, jsonschema_validate_mock): - policy_templates_dict = { - "key": "value" - } + policy_templates_dict = {"key": "value"} schema = "some schema" exception_msg = "exception" @@ -259,7 +210,7 @@ def test_is_valid_templates_dict_must_bubble_unhandled_exceptions(self, read_sch with self.assertRaises(TypeError): PolicyTemplatesProcessor._is_valid_templates_dict(policy_templates_dict) - @patch.object(json, 'loads') + @patch.object(json, "loads") def test_read_json_must_read_from_file(self, json_loads_mock): filepath = "some file" @@ -275,7 +226,7 @@ def test_read_json_must_read_from_file(self, json_loads_mock): open_mock.assert_called_once_with(filepath, "r") self.assertEqual(1, json_loads_mock.call_count) - @patch.object(PolicyTemplatesProcessor, '_read_json') + @patch.object(PolicyTemplatesProcessor, "_read_json") def test_read_schema_must_use_default_schema_location(self, _read_file_mock): expected = "something" _read_file_mock.return_value = expected @@ -284,7 +235,7 @@ def test_read_schema_must_use_default_schema_location(self, _read_file_mock): self.assertEqual(result, expected) _read_file_mock.assert_called_once_with(PolicyTemplatesProcessor.SCHEMA_LOCATION) - @patch.object(PolicyTemplatesProcessor, '_read_json') + @patch.object(PolicyTemplatesProcessor, "_read_json") def test_get_default_policy_template_json_must_work(self, _read_file_mock): expected = "something" _read_file_mock.return_value = expected diff --git a/tests/policy_template_processor/test_schema.py b/tests/policy_template_processor/test_schema.py index e803217725..b19ee2496d 100644 --- a/tests/policy_template_processor/test_schema.py +++ b/tests/policy_template_processor/test_schema.py @@ -3,6 +3,7 @@ from parameterized import 
parameterized from unittest import TestCase + class TestTemplates(object): """ Write your test cases here as different variables that store the entire template file. Start the variable with @@ -12,33 +13,17 @@ class TestTemplates(object): Just trying out a BDDish test runner """ - succeed_with_no_template = { - "Version": "1.2.3", - "Templates": {} - } + succeed_with_no_template = {"Version": "1.2.3", "Templates": {}} succeed_with_single_statement = { "Version": "1.0.0", "Templates": { "ManagedPolicy1Policy": { - "Description": "Very first managed policy", - - "Parameters": { - "Param1": { - "Description": "some desc" - }, - "Param2": { - "Description": "some desc" - } - }, - "Definition": { - "Statement": [{ - "key": "value" - }] - } + "Parameters": {"Param1": {"Description": "some desc"}, "Param2": {"Description": "some desc"}}, + "Definition": {"Statement": [{"key": "value"}]}, } - } + }, } succeed_with_multiple_templates = { @@ -46,147 +31,67 @@ class TestTemplates(object): "Templates": { "ManagedPolicy1Policy": { "Description": "Very first managed policy", - "Parameters": { - "Param1": { - "Description": "some desc" - }, - "Param2": { - "Description": "some desc" - } - }, - "Definition": { - "Statement": [ - { - "key": "value" - }, - { - "otherkey": "othervalue" - } - ] - } + "Parameters": {"Param1": {"Description": "some desc"}, "Param2": {"Description": "some desc"}}, + "Definition": {"Statement": [{"key": "value"}, {"otherkey": "othervalue"}]}, }, - "ManagedPolicy2Policy": { "Description": "Second managed policy", - "Parameters": { - "1stParam": { - "Description": "some desc" - } - }, - "Definition": { - "Statement": [{ - "key": "value" - }] - } - } - } + "Parameters": {"1stParam": {"Description": "some desc"}}, + "Definition": {"Statement": [{"key": "value"}]}, + }, + }, } - fail_for_template_with_no_version = { - "Templates": {} - } + fail_for_template_with_no_version = {"Templates": {}} - fail_without_template = { - "Version": "1.2.3" - } + fail_without_template = {"Version": "1.2.3"} - fail_with_non_semantic_version = { - "Version": "version1", - "Templates": {} - } + fail_with_non_semantic_version = {"Version": "version1", "Templates": {}} fail_without_all_three_parts_of_semver = { # Yes! 
you need all three parts "Version": "1.0", - "Templates": {} - } - - fail_with_additional_properties = { - "Version": "1.2.3", "Templates": {}, - "Something": "value" } + fail_with_additional_properties = {"Version": "1.2.3", "Templates": {}, "Something": "value"} + fail_for_bad_template_name = { "Version": "1.0.0", "Templates": { # Names must have the suffix "Policy" - "ThisNameDoesNotHaveTheSuffix": { - "Parameters": {}, - "Definition": { - "Statement": [{ - "key": "value" - }] - } - } - } + "ThisNameDoesNotHaveTheSuffix": {"Parameters": {}, "Definition": {"Statement": [{"key": "value"}]}} + }, } fail_for_template_without_parameters = { "Version": "1.0.0", - "Templates": { - "MyTemplate": { - "Definition": { - "Statement": [{ - "key": "value" - }] - } - } - } + "Templates": {"MyTemplate": {"Definition": {"Statement": [{"key": "value"}]}}}, } - fail_for_template_without_definition = { - "Version": "1.0.0", - "Templates": { - "MyTemplate": { - "Parameters": {}, - } - } - } + fail_for_template_without_definition = {"Version": "1.0.0", "Templates": {"MyTemplate": {"Parameters": {}}}} fail_for_template_with_empty_definition = { "Version": "1.0.0", - "Templates": { - "MyTemplate": { - "Parameters": {}, - "Definition": {} - } - } + "Templates": {"MyTemplate": {"Parameters": {}, "Definition": {}}}, } fail_for_parameter_with_no_description = { "Version": "1.0.0", - "Templates": { - "MyTemplate": { - "Parameters": { - "Param1": {} - }, - "Definition": { - "Statement": [{ - "key": "value" - }] - } - } - } + "Templates": {"MyTemplate": {"Parameters": {"Param1": {}}, "Definition": {"Statement": [{"key": "value"}]}}}, } fail_for_parameter_name_with_underscores = { "Version": "1.0.0", "Templates": { "MyTemplate": { - "Parameters": { # Underscores are not allowed. Following CFN naming convention here "invalid_name": {"Description": "value"} }, - - "Definition": { - "Statement": [{ - "key": "value" - }] - } + "Definition": {"Statement": [{"key": "value"}]}, } - } + }, } fail_for_definition_is_an_array = { @@ -194,17 +99,13 @@ class TestTemplates(object): "Templates": { "MyTemplate": { "Parameters": {"Param1": {"Description": "value"}}, - # Definition must be a direct statement object. 
This is not allowed - "Definition": [{ - "Statement": [{ - "key": "value" - }] - }] + "Definition": [{"Statement": [{"key": "value"}]}], } - } + }, } + class TestPolicyTemplateSchema(TestCase): """ Some basic test cases to validate that the JSON Schema representing policy templates actually work as intended diff --git a/tests/policy_template_processor/test_template.py b/tests/policy_template_processor/test_template.py index 7754ac3891..9753b3539a 100644 --- a/tests/policy_template_processor/test_template.py +++ b/tests/policy_template_processor/test_template.py @@ -4,8 +4,8 @@ from samtranslator.policy_template_processor.template import Template from samtranslator.policy_template_processor.exceptions import InvalidParameterValues, InsufficientParameterValues -class TestTemplateObject(TestCase): +class TestTemplateObject(TestCase): @patch.object(Template, "check_parameters_exist") def test_init_must_check_for_existence_of_all_parameters(self, check_parameters_exist_mock): @@ -26,10 +26,7 @@ def test_from_dict_must_return_object(self, check_parameters_exist_mock): parameters = {"A": "B"} template_definition = {"key": "value"} - template_dict = { - "Parameters": parameters, - "Definition": template_definition - } + template_dict = {"Parameters": parameters, "Definition": template_definition} template = Template.from_dict(template_name, template_dict) @@ -43,13 +40,11 @@ def test_from_dict_must_work_when_parameters_is_absent(self, check_parameters_ex template_name = "template_name" template_definition = {"key": "value"} - template_dict = { - "Definition": template_definition - } + template_dict = {"Definition": template_definition} template = Template.from_dict(template_name, template_dict) - self.assertEqual(template.parameters, {}) # Defaults to {} + self.assertEqual(template.parameters, {}) # Defaults to {} self.assertEqual(template.definition, template_definition) @patch.object(Template, "check_parameters_exist") @@ -57,23 +52,16 @@ def test_from_dict_must_work_when_template_definition_is_absent(self, check_para template_name = "template_name" parameters = {"key": "value"} - template_dict = { - "Parameters": parameters - } + template_dict = {"Parameters": parameters} template = Template.from_dict(template_name, template_dict) self.assertEqual(template.parameters, parameters) - self.assertEqual(template.definition, {}) # Defaults to {} + self.assertEqual(template.definition, {}) # Defaults to {} def test_missing_parameter_values_must_work_when_input_has_less_keys(self): - template_parameters = { - "param1": {"Description": "foo"}, - "param2": {"Description": "bar"} - } - parameter_values = { - "param1": "value1" - } + template_parameters = {"param1": {"Description": "foo"}, "param2": {"Description": "bar"}} + parameter_values = {"param1": "value1"} expected = ["param2"] template = Template("name", template_parameters, {}) @@ -82,14 +70,8 @@ def test_missing_parameter_values_must_work_when_input_has_less_keys(self): self.assertEqual(expected, result) def test_missing_parameter_values_must_work_when_input_has_all_keys(self): - template_parameters = { - "param1": {"Description": "foo"}, - "param2": {"Description": "bar"} - } - parameter_values = { - "param1": "value1", - "param2": "value3" - } + template_parameters = {"param1": {"Description": "foo"}, "param2": {"Description": "bar"}} + parameter_values = {"param1": "value1", "param2": "value3"} expected = [] template = Template("name", template_parameters, {}) @@ -98,16 +80,9 @@ def 
test_missing_parameter_values_must_work_when_input_has_all_keys(self): self.assertEqual(expected, result) def test_missing_parameter_values_must_work_when_input_has_more_keys(self): - template_parameters = { - "param1": {"Description": "foo"}, - "param2": {"Description": "bar"} - } - parameter_values = { - "param1": "value1", - "param2": "value2", - "newparam": "new value" - } - expected = [] # We do a set-difference. So new keys won't make it here + template_parameters = {"param1": {"Description": "foo"}, "param2": {"Description": "bar"}} + parameter_values = {"param1": "value1", "param2": "value2", "newparam": "new value"} + expected = [] # We do a set-difference. So new keys won't make it here template = Template("name", template_parameters, {}) result = template.missing_parameter_values(parameter_values) @@ -115,11 +90,8 @@ def test_missing_parameter_values_must_work_when_input_has_more_keys(self): self.assertEqual(expected, result) def test_missing_parameter_values_must_raise_on_invalid_input(self): - template_parameters = { - "param1": {"Description": "foo"}, - "param2": {"Description": "bar"} - } - parameter_values = [1,2,3] + template_parameters = {"param1": {"Description": "foo"}, "param2": {"Description": "bar"}} + parameter_values = [1, 2, 3] template = Template("name", template_parameters, {}) @@ -138,7 +110,7 @@ def test_is_valid_parameter_values_must_fail_for_none_value(self): def test_is_valid_parameter_values_must_fail_for_non_dict(self): - parameter_values = [1,2,3] + parameter_values = [1, 2, 3] self.assertFalse(Template._is_valid_parameter_values(parameter_values)) @patch("samtranslator.policy_template_processor.template.IntrinsicsResolver") diff --git a/tests/sdk/test_parameter.py b/tests/sdk/test_parameter.py index 10c749a658..d69aa77b2b 100644 --- a/tests/sdk/test_parameter.py +++ b/tests/sdk/test_parameter.py @@ -5,164 +5,98 @@ from samtranslator.sdk.parameter import SamParameterValues from mock import patch -class TestSAMParameterValues(TestCase): +class TestSAMParameterValues(TestCase): def test_add_default_parameter_values_must_merge(self): - parameter_values = { - "Param1": "value1" - } - - sam_template = { - "Parameters": { - "Param2": { - "Type": "String", - "Default": "template default" - } - } - } - - expected = { - "Param1": "value1", - "Param2": "template default" - } + parameter_values = {"Param1": "value1"} + + sam_template = {"Parameters": {"Param2": {"Type": "String", "Default": "template default"}}} + + expected = {"Param1": "value1", "Param2": "template default"} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_default_parameter_values(sam_template) self.assertEqual(expected, sam_parameter_values.parameter_values) def test_add_default_parameter_values_must_override_user_specified_values(self): - parameter_values = { - "Param1": "value1" - } - - sam_template = { - "Parameters": { - "Param1": { - "Type": "String", - "Default": "template default" - } - } - } - - expected = { - "Param1": "value1" - } + parameter_values = {"Param1": "value1"} - sam_parameter_values = SamParameterValues(parameter_values) - sam_parameter_values.add_default_parameter_values(sam_template) - self.assertEqual(expected, sam_parameter_values.parameter_values) + sam_template = {"Parameters": {"Param1": {"Type": "String", "Default": "template default"}}} - def test_add_default_parameter_values_must_skip_params_without_defaults(self): - parameter_values = { - "Param1": "value1" - } - - sam_template = { - "Parameters": { - "Param1": { - "Type": 
"String" - }, - "Param2": { - "Type": "String" - } - } - } - - expected = { - "Param1": "value1" - } + expected = {"Param1": "value1"} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_default_parameter_values(sam_template) self.assertEqual(expected, sam_parameter_values.parameter_values) + def test_add_default_parameter_values_must_skip_params_without_defaults(self): + parameter_values = {"Param1": "value1"} - @parameterized.expand([ - # Array - param(["1", "2"]), + sam_template = {"Parameters": {"Param1": {"Type": "String"}, "Param2": {"Type": "String"}}} - # String - param("something"), + expected = {"Param1": "value1"} - # Some other non-parameter looking dictionary - param({"Param1": {"Foo": "Bar"}}), + sam_parameter_values = SamParameterValues(parameter_values) + sam_parameter_values.add_default_parameter_values(sam_template) + self.assertEqual(expected, sam_parameter_values.parameter_values) - param(None) - ]) + @parameterized.expand( + [ + # Array + param(["1", "2"]), + # String + param("something"), + # Some other non-parameter looking dictionary + param({"Param1": {"Foo": "Bar"}}), + param(None), + ] + ) def test_add_default_parameter_values_must_ignore_invalid_template_parameters(self, template_parameters): - parameter_values = { - "Param1": "value1" - } + parameter_values = {"Param1": "value1"} - expected = { - "Param1": "value1" - } + expected = {"Param1": "value1"} - sam_template = { - "Parameters": template_parameters - } + sam_template = {"Parameters": template_parameters} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_default_parameter_values(sam_template) self.assertEqual(expected, sam_parameter_values.parameter_values) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_add_pseudo_parameter_values_aws_region(self): - parameter_values = { - "Param1": "value1" - } + parameter_values = {"Param1": "value1"} - expected = { - "Param1": "value1", - "AWS::Region": "ap-southeast-1", - "AWS::Partition": "aws" - } + expected = {"Param1": "value1", "AWS::Region": "ap-southeast-1", "AWS::Partition": "aws"} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_pseudo_parameter_values() self.assertEqual(expected, sam_parameter_values.parameter_values) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_add_pseudo_parameter_values_aws_region_not_override(self): - parameter_values = { - "AWS::Region": "value1" - } + parameter_values = {"AWS::Region": "value1"} - expected = { - "AWS::Region": "value1", - "AWS::Partition": "aws" - } + expected = {"AWS::Region": "value1", "AWS::Partition": "aws"} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_pseudo_parameter_values() self.assertEqual(expected, sam_parameter_values.parameter_values) - @patch('boto3.session.Session.region_name', 'us-gov-west-1') + @patch("boto3.session.Session.region_name", "us-gov-west-1") def test_add_pseudo_parameter_values_aws_partition(self): - parameter_values = { - "Param1": "value1" - } + parameter_values = {"Param1": "value1"} - expected = { - "Param1": "value1", - "AWS::Region": "us-gov-west-1", - "AWS::Partition": "aws-us-gov" - } + expected = {"Param1": "value1", "AWS::Region": "us-gov-west-1", "AWS::Partition": "aws-us-gov"} sam_parameter_values = SamParameterValues(parameter_values) 
sam_parameter_values.add_pseudo_parameter_values() self.assertEqual(expected, sam_parameter_values.parameter_values) - @patch('boto3.session.Session.region_name', 'us-gov-west-1') + @patch("boto3.session.Session.region_name", "us-gov-west-1") def test_add_pseudo_parameter_values_aws_partition_not_override(self): - parameter_values = { - "AWS::Partition": "aws" - } - - expected = { - "AWS::Partition": "aws", - "AWS::Region": "us-gov-west-1" - } + parameter_values = {"AWS::Partition": "aws"} + + expected = {"AWS::Partition": "aws", "AWS::Region": "us-gov-west-1"} sam_parameter_values = SamParameterValues(parameter_values) sam_parameter_values.add_pseudo_parameter_values() diff --git a/tests/sdk/test_resource.py b/tests/sdk/test_resource.py index 3514bc9a4a..ee97e8394b 100644 --- a/tests/sdk/test_resource.py +++ b/tests/sdk/test_resource.py @@ -2,37 +2,17 @@ from samtranslator.sdk.resource import SamResource, SamResourceType -class TestSamResource(TestCase): +class TestSamResource(TestCase): def setUp(self): - self.function_dict = { - "Type": "AWS::Serverless::Function", - "Properties": { - "a": "b" - } - } + self.function_dict = {"Type": "AWS::Serverless::Function", "Properties": {"a": "b"}} - self.api_dict = { - "Type": "AWS::Serverless::Api", - "Properties": { - "a": "b" - } - } + self.api_dict = {"Type": "AWS::Serverless::Api", "Properties": {"a": "b"}} - self.simple_table_dict = { - "Type": "AWS::Serverless::SimpleTable", - "Properties": { - "a": "b" - } - } + self.simple_table_dict = {"Type": "AWS::Serverless::SimpleTable", "Properties": {"a": "b"}} def test_init_must_extract_type_and_properties(self): - resource_dict = { - "Type": "foo", - "Properties": { - "a": "b" - } - } + resource_dict = {"Type": "foo", "Properties": {"a": "b"}} resource = SamResource(resource_dict) self.assertEqual(resource.type, "foo") @@ -58,12 +38,7 @@ def test_valid_must_not_work_with_resource_without_type(self): self.assertFalse(SamResource({"a": "b"}).valid()) def test_to_dict_must_update_type_and_properties(self): - resource_dict = { - "Type": "AWS::Serverless::Function", - "Properties": { - "a": "b" - } - } + resource_dict = {"Type": "AWS::Serverless::Function", "Properties": {"a": "b"}} resource = SamResource(resource_dict) resource.type = "AWS::Serverless::Api" @@ -76,7 +51,6 @@ def test_to_dict_must_update_type_and_properties(self): class TestSamResourceTypeEnum(TestCase): - def test_contains_sam_resources(self): self.assertEqual(SamResourceType.Function.value, "AWS::Serverless::Function") self.assertEqual(SamResourceType.Api.value, "AWS::Serverless::Api") diff --git a/tests/sdk/test_template.py b/tests/sdk/test_template.py index 2e1f82625d..161833b2aa 100644 --- a/tests/sdk/test_template.py +++ b/tests/sdk/test_template.py @@ -4,36 +4,20 @@ from samtranslator.sdk.template import SamTemplate from samtranslator.sdk.resource import SamResource -class TestSamTemplate(TestCase): +class TestSamTemplate(TestCase): def setUp(self): self.template_dict = { - "Properties": { - "c": "d" - }, - "Metadata": { - "a": "b" - }, - "Resources": { - "Function1": { - "Type": "AWS::Serverless::Function", - "DependsOn": "SomeOtherResource" - }, - "Function2": { - "Type": "AWS::Serverless::Function", - "a": "b" - }, - "Api": { - "Type": "AWS::Serverless::Api" - }, - "Layer": { - "Type": "AWS::Serverless::LayerVersion" - }, - "NonSam": { - "Type": "AWS::Lambda::Function" - } - } + "Properties": {"c": "d"}, + "Metadata": {"a": "b"}, + "Resources": { + "Function1": {"Type": "AWS::Serverless::Function", "DependsOn": 
"SomeOtherResource"}, + "Function2": {"Type": "AWS::Serverless::Function", "a": "b"}, + "Api": {"Type": "AWS::Serverless::Api"}, + "Layer": {"Type": "AWS::Serverless::LayerVersion"}, + "NonSam": {"Type": "AWS::Lambda::Function"}, + }, } def test_iterate_must_yield_sam_resources_only(self): @@ -66,9 +50,7 @@ def test_iterate_must_filter_by_layers_resource_type(self): template = SamTemplate(self.template_dict) type = "AWS::Serverless::LayerVersion" - expected = [ - ("Layer", {"Type": "AWS::Serverless::LayerVersion", "Properties": {}}), - ] + expected = [("Layer", {"Type": "AWS::Serverless::LayerVersion", "Properties": {}})] actual = [(id, resource.to_dict()) for id, resource in template.iterate(type)] self.assertEqual(expected, actual) @@ -106,11 +88,7 @@ def test_set_must_work_with_sam_resource_input(self): self.assertEqual(self.template_dict["Resources"].get("NewResource"), {"Type": "something"}) def test_get_must_return_resource(self): - expected = { - "Type": "AWS::Serverless::Function", - "DependsOn": "SomeOtherResource", - "Properties": {} - } + expected = {"Type": "AWS::Serverless::Function", "DependsOn": "SomeOtherResource", "Properties": {}} template = SamTemplate(self.template_dict) diff --git a/tests/swagger/test_swagger.py b/tests/swagger/test_swagger.py index c85cbbeeec..c4b409d62c 100644 --- a/tests/swagger/test_swagger.py +++ b/tests/swagger/test_swagger.py @@ -10,26 +10,20 @@ from tests.translator.test_translator import deep_sort_lists _X_INTEGRATION = "x-amazon-apigateway-integration" -_X_ANY_METHOD = 'x-amazon-apigateway-any-method' -_X_POLICY = 'x-amazon-apigateway-policy' +_X_ANY_METHOD = "x-amazon-apigateway-any-method" +_X_POLICY = "x-amazon-apigateway-policy" _ALLOW_CREDENTALS_TRUE = "'true'" -class TestSwaggerEditor_init(TestCase): +class TestSwaggerEditor_init(TestCase): def test_must_raise_on_invalid_swagger(self): - invalid_swagger = {"paths": {}} # Missing "Swagger" keyword + invalid_swagger = {"paths": {}} # Missing "Swagger" keyword with self.assertRaises(ValueError): SwaggerEditor(invalid_swagger) def test_must_succeed_on_valid_swagger(self): - valid_swagger = { - "swagger": "2.0", - "paths": { - "/foo": {}, - "/bar": {} - } - } + valid_swagger = {"swagger": "2.0", "paths": {"/foo": {}, "/bar": {}}} editor = SwaggerEditor(valid_swagger) self.assertIsNotNone(editor) @@ -37,37 +31,19 @@ def test_must_succeed_on_valid_swagger(self): self.assertEqual(editor.paths, {"/foo": {}, "/bar": {}}) def test_must_fail_on_invalid_openapi_version(self): - invalid_swagger = { - "openapi": "2.3.0", - "paths": { - "/foo": {}, - "/bar": {} - } - } + invalid_swagger = {"openapi": "2.3.0", "paths": {"/foo": {}, "/bar": {}}} with self.assertRaises(ValueError): SwaggerEditor(invalid_swagger) def test_must_fail_on_invalid_openapi_version_2(self): - invalid_swagger = { - "openapi": "3.1.1.1", - "paths": { - "/foo": {}, - "/bar": {} - } - } + invalid_swagger = {"openapi": "3.1.1.1", "paths": {"/foo": {}, "/bar": {}}} with self.assertRaises(ValueError): SwaggerEditor(invalid_swagger) def test_must_succeed_on_valid_openapi3(self): - valid_swagger = { - "openapi": "3.0.1", - "paths": { - "/foo": {}, - "/bar": {} - } - } + valid_swagger = {"openapi": "3.0.1", "paths": {"/foo": {}, "/bar": {}}} editor = SwaggerEditor(valid_swagger) self.assertIsNotNone(editor) @@ -76,21 +52,14 @@ def test_must_succeed_on_valid_openapi3(self): class TestSwaggerEditor_has_path(TestCase): - def setUp(self): self.swagger = { "swagger": "2.0", "paths": { - "/foo": { - "get": {}, - "somemethod": {} - }, - "/bar": { 
- "post": {}, - _X_ANY_METHOD: {} - }, - "badpath": "string value" - } + "/foo": {"get": {}, "somemethod": {}}, + "/bar": {"post": {}, _X_ANY_METHOD: {}}, + "badpath": "string value", + }, } self.editor = SwaggerEditor(self.swagger) @@ -114,7 +83,7 @@ def test_must_work_with_any_method(self): Method name "ANY" is special. It must be converted to the x-amazon style value before search """ self.assertTrue(self.editor.has_path("/bar", "any")) - self.assertTrue(self.editor.has_path("/bar", "AnY")) # Case insensitive + self.assertTrue(self.editor.has_path("/bar", "AnY")) # Case insensitive self.assertTrue(self.editor.has_path("/bar", _X_ANY_METHOD)) self.assertFalse(self.editor.has_path("/foo", "any")) @@ -136,48 +105,19 @@ def test_must_not_fail_on_bad_path(self): class TestSwaggerEditor_has_integration(TestCase): - def setUp(self): self.swagger = { "swagger": "2.0", "paths": { "/foo": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - }, - "post": { - "Fn::If": [ - "Condition", - { - _X_INTEGRATION: { - "a": "b" - } - }, - {"Ref": "AWS::NoValue"} - ] - }, - "delete": { - "Fn::If": [ - "Condition", - {"Ref": "AWS::NoValue"}, - { - _X_INTEGRATION: { - "a": "b" - } - } - ] - }, - "somemethod": { - "foo": "value", - }, - "emptyintegration": { - _X_INTEGRATION: {} - }, - "badmethod": "string value" - }, - } + "get": {_X_INTEGRATION: {"a": "b"}}, + "post": {"Fn::If": ["Condition", {_X_INTEGRATION: {"a": "b"}}, {"Ref": "AWS::NoValue"}]}, + "delete": {"Fn::If": ["Condition", {"Ref": "AWS::NoValue"}, {_X_INTEGRATION: {"a": "b"}}]}, + "somemethod": {"foo": "value"}, + "emptyintegration": {_X_INTEGRATION: {}}, + "badmethod": "string value", + } + }, } self.editor = SwaggerEditor(self.swagger) @@ -202,33 +142,28 @@ def test_must_handle_bad_value_for_method(self): class TestSwaggerEditor_add_path(TestCase): - def setUp(self): self.original_swagger = { "swagger": "2.0", - "paths": { - "/foo": { - "get": {"a": "b"} - }, - "/bar": {}, - "/badpath": "string value" - } + "paths": {"/foo": {"get": {"a": "b"}}, "/bar": {}, "/badpath": "string value"}, } self.editor = SwaggerEditor(self.original_swagger) - @parameterized.expand([ - param("/new", "get", "new path, new method"), - param("/foo", "new method", "existing path, new method"), - param("/bar", "get", "existing path, new method"), - ]) + @parameterized.expand( + [ + param("/new", "get", "new path, new method"), + param("/foo", "new method", "existing path, new method"), + param("/bar", "get", "existing path, new method"), + ] + ) def test_must_add_new_path_and_method(self, path, method, case): self.assertFalse(self.editor.has_path(path, method)) self.editor.add_path(path, method) - self.assertTrue(self.editor.has_path(path, method), "must add for "+case) + self.assertTrue(self.editor.has_path(path, method), "must add for " + case) self.assertEqual(self.editor.swagger["paths"][path][method], {}) def test_must_raise_non_dict_path_values(self): @@ -255,28 +190,14 @@ def test_must_skip_existing_path(self): class TestSwaggerEditor_add_lambda_integration(TestCase): - def setUp(self): self.original_swagger = { "swagger": "2.0", "paths": { - "/foo": { - "post": { - "a": [1, 2, "b"], - "responses": { - "something": "is already here" - } - } - }, - "/bar": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - } - }, - } + "/foo": {"post": {"a": [1, 2, "b"], "responses": {"something": "is already here"}}}, + "/bar": {"get": {_X_INTEGRATION: {"a": "b"}}}, + }, } self.editor = SwaggerEditor(self.original_swagger) @@ -287,11 +208,7 @@ def 
test_must_add_new_integration_to_new_path(self): integration_uri = "something" expected = { "responses": {}, - _X_INTEGRATION: { - "type": "aws_proxy", - "httpMethod": "POST", - "uri": integration_uri - } + _X_INTEGRATION: {"type": "aws_proxy", "httpMethod": "POST", "uri": integration_uri}, } self.editor.add_lambda_integration(path, method, integration_uri) @@ -313,20 +230,10 @@ def test_must_add_new_integration_with_conditions_to_new_path(self): _X_INTEGRATION: { "type": "aws_proxy", "httpMethod": "POST", - "uri": { - "Fn::If": [ - "condition", - integration_uri, - { - "Ref": "AWS::NoValue" - } - ] - } - } + "uri": {"Fn::If": ["condition", integration_uri, {"Ref": "AWS::NoValue"}]}, + }, }, - { - "Ref": "AWS::NoValue" - } + {"Ref": "AWS::NoValue"}, ] } @@ -343,18 +250,10 @@ def test_must_add_new_integration_to_existing_path(self): expected = { # Current values present in the dictionary *MUST* be preserved "a": [1, 2, "b"], - # Responses key must be untouched - "responses": { - "something": "is already here" - }, - + "responses": {"something": "is already here"}, # New values must be added - _X_INTEGRATION: { - "type": "aws_proxy", - "httpMethod": "POST", - "uri": integration_uri - } + _X_INTEGRATION: {"type": "aws_proxy", "httpMethod": "POST", "uri": integration_uri}, } # Just make sure test is working on an existing path @@ -374,46 +273,30 @@ def test_must_add_credentials_to_the_integration(self): path = "/newpath" method = "get" integration_uri = "something" - expected = 'arn:aws:iam::*:user/*' - api_auth_config = { - "DefaultAuthorizer": "AWS_IAM", - "InvokeRole": "CALLER_CREDENTIALS" - } + expected = "arn:aws:iam::*:user/*" + api_auth_config = {"DefaultAuthorizer": "AWS_IAM", "InvokeRole": "CALLER_CREDENTIALS"} self.editor.add_lambda_integration(path, method, integration_uri, None, api_auth_config) - actual = self.editor.swagger["paths"][path][method][_X_INTEGRATION]['credentials'] + actual = self.editor.swagger["paths"][path][method][_X_INTEGRATION]["credentials"] self.assertEqual(expected, actual) def test_must_add_credentials_to_the_integration_overrides(self): path = "/newpath" method = "get" integration_uri = "something" - expected = 'arn:aws:iam::*:role/xxxxxx' - api_auth_config = { - "DefaultAuthorizer": "MyAuth", - } - method_auth_config = { - "Authorizer": "AWS_IAM", - "InvokeRole": "arn:aws:iam::*:role/xxxxxx" - } + expected = "arn:aws:iam::*:role/xxxxxx" + api_auth_config = {"DefaultAuthorizer": "MyAuth"} + method_auth_config = {"Authorizer": "AWS_IAM", "InvokeRole": "arn:aws:iam::*:role/xxxxxx"} self.editor.add_lambda_integration(path, method, integration_uri, method_auth_config, api_auth_config) - actual = self.editor.swagger["paths"][path][method][_X_INTEGRATION]['credentials'] + actual = self.editor.swagger["paths"][path][method][_X_INTEGRATION]["credentials"] self.assertEqual(expected, actual) class TestSwaggerEditor_iter_on_path(TestCase): - def setUp(self): - self.original_swagger = { - "swagger": "2.0", - "paths": { - "/foo": {}, - "/bar": {}, - "/baz": "some value" - } - } + self.original_swagger = {"swagger": "2.0", "paths": {"/foo": {}, "/bar": {}, "/baz": "some value"}} self.editor = SwaggerEditor(self.original_swagger) @@ -426,18 +309,11 @@ def test_must_iterate_on_paths(self): class TestSwaggerEditor_add_cors(TestCase): - def setUp(self): self.original_swagger = { "swagger": "2.0", - "paths": { - "/foo": {}, - "/withoptions": { - "options": {"some": "value"} - }, - "/bad": "some value" - } + "paths": {"/foo": {}, "/withoptions": {"options": {"some": "value"}}, 
"/bad": "some value"}, } self.editor = SwaggerEditor(self.original_swagger) @@ -457,11 +333,9 @@ def test_must_add_options_to_new_path(self): self.editor.add_cors(path, allowed_origins, allowed_headers, allowed_methods, max_age, allow_credentials) self.assertEqual(expected, self.editor.swagger["paths"][path]["options"]) - self.editor._options_method_response_for_cors.assert_called_with(allowed_origins, - allowed_headers, - allowed_methods, - max_age, - options_method_response_allow_credentials) + self.editor._options_method_response_for_cors.assert_called_with( + allowed_origins, allowed_headers, allowed_methods, max_age, options_method_response_allow_credentials + ) def test_must_skip_existing_path(self): path = "/withoptions" @@ -485,7 +359,7 @@ def test_must_fail_for_invalid_allowed_origin(self): def test_must_work_for_optional_allowed_headers(self): allowed_origins = "origins" - allowed_headers = None # No Value + allowed_headers = None # No Value allowed_methods = "methods" max_age = 60 allow_credentials = True @@ -501,11 +375,9 @@ def test_must_work_for_optional_allowed_headers(self): self.assertEqual(expected, self.editor.swagger["paths"][path]["options"]) - self.editor._options_method_response_for_cors.assert_called_with(allowed_origins, - allowed_headers, - allowed_methods, - max_age, - options_method_response_allow_credentials) + self.editor._options_method_response_for_cors.assert_called_with( + allowed_origins, allowed_headers, allowed_methods, max_age, options_method_response_allow_credentials + ) def test_must_make_default_value_with_optional_allowed_methods(self): @@ -531,13 +403,15 @@ def test_must_make_default_value_with_optional_allowed_methods(self): self.assertEqual(expected, self.editor.swagger["paths"][path]["options"]) - self.editor._options_method_response_for_cors.assert_called_with(allowed_origins, - allowed_headers, - # Must be called with default value. - # And value must be quoted - default_allow_methods_value_with_quotes, - max_age, - options_method_response_allow_credentials) + self.editor._options_method_response_for_cors.assert_called_with( + allowed_origins, + allowed_headers, + # Must be called with default value. 
+ # And value must be quoted + default_allow_methods_value_with_quotes, + max_age, + options_method_response_allow_credentials, + ) def test_must_accept_none_allow_credentials(self): allowed_origins = "origins" @@ -554,20 +428,17 @@ def test_must_accept_none_allow_credentials(self): self.editor.add_cors(path, allowed_origins, allowed_headers, allowed_methods, max_age, allow_credentials) self.assertEqual(expected, self.editor.swagger["paths"][path]["options"]) - self.editor._options_method_response_for_cors.assert_called_with(allowed_origins, - allowed_headers, - allowed_methods, - max_age, - options_method_response_allow_credentials) + self.editor._options_method_response_for_cors.assert_called_with( + allowed_origins, allowed_headers, allowed_methods, max_age, options_method_response_allow_credentials + ) class TestSwaggerEditor_options_method_response_for_cors(TestCase): - def test_correct_value_is_returned(self): self.maxDiff = None headers = "foo" methods = {"a": "b"} - origins = [1,2,3] + origins = [1, 2, 3] max_age = 60 allow_credentials = True @@ -577,9 +448,7 @@ def test_correct_value_is_returned(self): "produces": ["application/json"], _X_INTEGRATION: { "type": "mock", - "requestTemplates": { - "application/json": "{\n \"statusCode\" : 200\n}\n" - }, + "requestTemplates": {"application/json": '{\n "statusCode" : 200\n}\n'}, "responses": { "default": { "statusCode": "200", @@ -590,43 +459,31 @@ def test_correct_value_is_returned(self): "method.response.header.Access-Control-Allow-Origin": origins, "method.response.header.Access-Control-Max-Age": max_age, }, - "responseTemplates": { - "application/json": "{}\n" - } + "responseTemplates": {"application/json": "{}\n"}, } - } + }, }, "responses": { "200": { "description": "Default response for CORS method", "headers": { - "Access-Control-Allow-Credentials": { - "type": "string" - }, - "Access-Control-Allow-Headers": { - "type": "string" - }, - "Access-Control-Allow-Methods": { - "type": "string" - }, - "Access-Control-Allow-Origin": { - "type": "string" - }, - "Access-Control-Max-Age": { - "type": "integer" - } - } + "Access-Control-Allow-Credentials": {"type": "string"}, + "Access-Control-Allow-Headers": {"type": "string"}, + "Access-Control-Allow-Methods": {"type": "string"}, + "Access-Control-Allow-Origin": {"type": "string"}, + "Access-Control-Max-Age": {"type": "integer"}, + }, } - } + }, } - actual = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors(origins, headers, - methods, max_age, - allow_credentials) + actual = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( + origins, headers, methods, max_age, allow_credentials + ) self.assertEqual(expected, actual) def test_allow_headers_is_skipped_with_no_value(self): - headers = None # No value + headers = None # No value methods = "methods" origins = "origins" allow_credentials = True @@ -638,19 +495,14 @@ def test_allow_headers_is_skipped_with_no_value(self): } expected_headers = { - "Access-Control-Allow-Credentials": { - "type": "string" - }, - "Access-Control-Allow-Methods": { - "type": "string" - }, - "Access-Control-Allow-Origin": { - "type": "string" - } + "Access-Control-Allow-Credentials": {"type": "string"}, + "Access-Control-Allow-Methods": {"type": "string"}, + "Access-Control-Allow-Origin": {"type": "string"}, } options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( - origins, headers, methods, allow_credentials=allow_credentials) + origins, headers, methods, 
allow_credentials=allow_credentials + ) actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"] self.assertEqual(expected, actual) @@ -658,7 +510,7 @@ def test_allow_headers_is_skipped_with_no_value(self): def test_allow_methods_is_skipped_with_no_value(self): headers = "headers" - methods = None # No value + methods = None # No value origins = "origins" allow_credentials = True @@ -669,7 +521,8 @@ def test_allow_methods_is_skipped_with_no_value(self): } options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( - origins, headers, methods, allow_credentials=allow_credentials) + origins, headers, methods, allow_credentials=allow_credentials + ) actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"] self.assertEqual(expected, actual) @@ -686,7 +539,8 @@ def test_allow_origins_is_not_skipped_with_no_value(self): } options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( - origins, headers, methods, allow_credentials=allow_credentials) + origins, headers, methods, allow_credentials=allow_credentials + ) actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"] self.assertEqual(expected, actual) @@ -706,7 +560,8 @@ def test_max_age_can_be_set_to_zero(self): } options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( - origins, headers, methods, max_age, allow_credentials) + origins, headers, methods, max_age, allow_credentials + ) actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"] self.assertEqual(expected, actual) @@ -724,42 +579,36 @@ def test_allow_credentials_is_skipped_with_false_value(self): } options_config = SwaggerEditor(SwaggerEditor.gen_skeleton())._options_method_response_for_cors( - origins, headers, methods, allow_credentials=allow_credentials) + origins, headers, methods, allow_credentials=allow_credentials + ) actual = options_config[_X_INTEGRATION]["responses"]["default"]["responseParameters"] self.assertEqual(expected, actual) class TestSwaggerEditor_make_cors_allowed_methods_for_path(TestCase): - def setUp(self): - self.editor = SwaggerEditor({ - "swagger": "2.0", - "paths": { - "/foo": { - "get": {}, - "POST": {}, - "DeLeTe": {} - }, - "/withany": { - "head": {}, - _X_ANY_METHOD: {} + self.editor = SwaggerEditor( + { + "swagger": "2.0", + "paths": { + "/foo": {"get": {}, "POST": {}, "DeLeTe": {}}, + "/withany": {"head": {}, _X_ANY_METHOD: {}}, + "/nothing": {}, }, - "/nothing": { - } } - }) + ) def test_must_return_all_defined_methods(self): path = "/foo" - expected = "DELETE,GET,OPTIONS,POST" # Result should be sorted alphabetically + expected = "DELETE,GET,OPTIONS,POST" # Result should be sorted alphabetically actual = self.editor._make_cors_allowed_methods_for_path(path) self.assertEqual(expected, actual) def test_must_work_for_any_method(self): path = "/withany" - expected = "DELETE,GET,HEAD,OPTIONS,PATCH,POST,PUT" # Result should be sorted alphabetically + expected = "DELETE,GET,HEAD,OPTIONS,PATCH,POST,PUT" # Result should be sorted alphabetically actual = self.editor._make_cors_allowed_methods_for_path(path) self.assertEqual(expected, actual) @@ -780,187 +629,119 @@ def test_must_skip_non_existent_path(self): class TestSwaggerEditor_normalize_method_name(TestCase): - - @parameterized.expand([ - param("GET", "get", "must lowercase"), - param("PoST", "post", "must lowercase"), - param("ANY", _X_ANY_METHOD, "must convert any method"), - 
param(None, None, "must skip empty values"), - param({"a": "b"}, {"a": "b"}, "must skip non-string values"), - param([1, 2], [1, 2], "must skip non-string values"), - ]) + @parameterized.expand( + [ + param("GET", "get", "must lowercase"), + param("PoST", "post", "must lowercase"), + param("ANY", _X_ANY_METHOD, "must convert any method"), + param(None, None, "must skip empty values"), + param({"a": "b"}, {"a": "b"}, "must skip non-string values"), + param([1, 2], [1, 2], "must skip non-string values"), + ] + ) def test_must_normalize(self, input, expected, msg): self.assertEqual(expected, SwaggerEditor._normalize_method_name(input), msg) class TestSwaggerEditor_swagger_property(TestCase): - def test_must_return_copy_of_swagger(self): - input = { - "swagger": "2.0", - "paths": {} - } + input = {"swagger": "2.0", "paths": {}} editor = SwaggerEditor(input) - self.assertEqual(input, editor.swagger) # They are equal in content + self.assertEqual(input, editor.swagger) # They are equal in content input["swagger"] = "3" - self.assertEqual("2.0", editor.swagger["swagger"]) # Editor works on a diff copy of input + self.assertEqual("2.0", editor.swagger["swagger"]) # Editor works on a diff copy of input editor.add_path("/foo", "get") self.assertEqual({"/foo": {"get": {}}}, editor.swagger["paths"]) - self.assertEqual({}, input["paths"]) # Editor works on a diff copy of input + self.assertEqual({}, input["paths"]) # Editor works on a diff copy of input class TestSwaggerEditor_is_valid(TestCase): - - @parameterized.expand([ - param(SwaggerEditor.gen_skeleton()), - - # Dict can contain any other unrecognized properties - param({"swagger": "anyvalue", "paths": {}, "foo": "bar", "baz": "bar"}) - ]) + @parameterized.expand( + [ + param(SwaggerEditor.gen_skeleton()), + # Dict can contain any other unrecognized properties + param({"swagger": "anyvalue", "paths": {}, "foo": "bar", "baz": "bar"}), + ] + ) def test_must_work_on_valid_values(self, swagger): self.assertTrue(SwaggerEditor.is_valid(swagger)) - @parameterized.expand([ - ({}, "empty dictionary"), - ([1, 2, 3], "array data type"), - ({"paths": {}}, "missing swagger property"), - ({"swagger": "hello"}, "missing paths property"), - ({"swagger": "hello", "paths": [1, 2, 3]}, "array value for paths property"), - ]) + @parameterized.expand( + [ + ({}, "empty dictionary"), + ([1, 2, 3], "array data type"), + ({"paths": {}}, "missing swagger property"), + ({"swagger": "hello"}, "missing paths property"), + ({"swagger": "hello", "paths": [1, 2, 3]}, "array value for paths property"), + ] + ) def test_must_fail_for_invalid_values(self, data, case): self.assertFalse(SwaggerEditor.is_valid(data), "Swagger dictionary with {} must not be valid".format(case)) -class TestSwaggerEditor_add_models(TestCase): +class TestSwaggerEditor_add_models(TestCase): def setUp(self): - self.original_swagger = { - "swagger": "2.0", - "paths": { - "/foo": {} - } - } + self.original_swagger = {"swagger": "2.0", "paths": {"/foo": {}}} self.editor = SwaggerEditor(self.original_swagger) def test_must_add_definitions(self): - models = { - 'User': { - 'type': 'object', - 'properties': { - 'username': { - 'type': 'string' - } - } - } - } + models = {"User": {"type": "object", "properties": {"username": {"type": "string"}}}} self.editor.add_models(models) - expected = { - 'user': { - 'type': 'object', - 'properties': { - 'username': { - 'type': 'string' - } - } - } - } + expected = {"user": {"type": "object", "properties": {"username": {"type": "string"}}}} - self.assertEqual(expected, 
self.editor.swagger['definitions']) + self.assertEqual(expected, self.editor.swagger["definitions"]) def test_must_fail_without_type_in_model(self): - models = { - 'User': { - 'properties': { - 'username': { - 'type': 'string' - } - } - } - } + models = {"User": {"properties": {"username": {"type": "string"}}}} with self.assertRaises(ValueError): self.editor.add_models(models) def test_must_fail_without_properties_in_model(self): - models = { - 'User': { - 'type': 'object' - } - } + models = {"User": {"type": "object"}} with self.assertRaises(ValueError): self.editor.add_models(models) -class TestSwaggerEditor_add_request_model_to_method(TestCase): +class TestSwaggerEditor_add_request_model_to_method(TestCase): def setUp(self): self.original_swagger = { "swagger": "2.0", - "paths": { - "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - } - } - } - } + "paths": {"/foo": {"get": {"x-amazon-apigateway-integration": {"test": "must have integration"}}}}, } self.editor = SwaggerEditor(self.original_swagger) def test_must_add_body_parameter_to_method_with_required_true(self): - model = { - 'Model': 'User', - 'Required': True - } + model = {"Model": "User", "Required": True} - self.editor.add_request_model_to_method('/foo', 'get', model) + self.editor.add_request_model_to_method("/foo", "get", model) - expected = [ - { - 'in': 'body', - 'required': True, - 'name': 'user', - 'schema': { - '$ref': '#/definitions/user' - } - } - ] + expected = [{"in": "body", "required": True, "name": "user", "schema": {"$ref": "#/definitions/user"}}] - self.assertEqual(expected, self.editor.swagger['paths']['/foo']['get']['parameters']) + self.assertEqual(expected, self.editor.swagger["paths"]["/foo"]["get"]["parameters"]) def test_must_add_body_parameter_to_method_with_required_false(self): - model = { - 'Model': 'User', - 'Required': False - } + model = {"Model": "User", "Required": False} - self.editor.add_request_model_to_method('/foo', 'get', model) + self.editor.add_request_model_to_method("/foo", "get", model) - expected = [ - { - 'in': 'body', - 'required': False, - 'name': 'user', - 'schema': { - '$ref': '#/definitions/user' - } - } - ] + expected = [{"in": "body", "required": False, "name": "user", "schema": {"$ref": "#/definitions/user"}}] - self.assertEqual(expected, self.editor.swagger['paths']['/foo']['get']['parameters']) + self.assertEqual(expected, self.editor.swagger["paths"]["/foo"]["get"]["parameters"]) def test_must_add_body_parameter_to_existing_method_parameters(self): @@ -968,208 +749,114 @@ def test_must_add_body_parameter_to_existing_method_parameters(self): "swagger": "2.0", "paths": { "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - }, - 'parameters': [{'test': 'existing parameter'}] + "get": { + "x-amazon-apigateway-integration": {"test": "must have integration"}, + "parameters": [{"test": "existing parameter"}], } } - } + }, } editor = SwaggerEditor(original_swagger) - model = { - 'Model': 'User', - 'Required': True - } + model = {"Model": "User", "Required": True} - editor.add_request_model_to_method('/foo', 'get', model) + editor.add_request_model_to_method("/foo", "get", model) expected = [ - { - 'test': 'existing parameter' - }, - { - 'in': 'body', - 'required': True, - 'name': 'user', - 'schema': { - '$ref': '#/definitions/user' - } - } + {"test": "existing parameter"}, + {"in": "body", "required": True, "name": "user", "schema": {"$ref": "#/definitions/user"}}, ] - 
self.assertEqual(expected, editor.swagger['paths']['/foo']['get']['parameters']) + self.assertEqual(expected, editor.swagger["paths"]["/foo"]["get"]["parameters"]) def test_must_not_add_body_parameter_to_method_without_integration(self): - original_swagger = { - "swagger": "2.0", - "paths": { - "/foo": { - 'get': {} - } - } - } + original_swagger = {"swagger": "2.0", "paths": {"/foo": {"get": {}}}} editor = SwaggerEditor(original_swagger) - model = { - 'Model': 'User', - 'Required': True - } + model = {"Model": "User", "Required": True} - editor.add_request_model_to_method('/foo', 'get', model) + editor.add_request_model_to_method("/foo", "get", model) expected = {} - self.assertEqual(expected, editor.swagger['paths']['/foo']['get']) + self.assertEqual(expected, editor.swagger["paths"]["/foo"]["get"]) def test_must_add_body_parameter_to_method_without_required(self): - model = { - 'Model': 'User' - } + model = {"Model": "User"} - self.editor.add_request_model_to_method('/foo', 'get', model) + self.editor.add_request_model_to_method("/foo", "get", model) - expected = [ - { - 'in': 'body', - 'name': 'user', - 'schema': { - '$ref': '#/definitions/user' - } - } - ] + expected = [{"in": "body", "name": "user", "schema": {"$ref": "#/definitions/user"}}] - self.assertEqual(expected, self.editor.swagger['paths']['/foo']['get']['parameters']) + self.assertEqual(expected, self.editor.swagger["paths"]["/foo"]["get"]["parameters"]) def test_must_add_body_parameter_to_method_openapi_without_required(self): original_openapi = { "openapi": "3.0.1", - "paths": { - "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - } - } - } - } + "paths": {"/foo": {"get": {"x-amazon-apigateway-integration": {"test": "must have integration"}}}}, } editor = SwaggerEditor(original_openapi) - model = { - 'Model': 'User', - 'Required': True - } + model = {"Model": "User", "Required": True} - editor.add_request_model_to_method('/foo', 'get', model) + editor.add_request_model_to_method("/foo", "get", model) expected = { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/user' - } - } - }, - 'required': True + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/user"}}}, + "required": True, } - self.assertEqual(expected, editor.swagger['paths']['/foo']['get']['requestBody']) + self.assertEqual(expected, editor.swagger["paths"]["/foo"]["get"]["requestBody"]) def test_must_add_body_parameter_to_method_openapi_required_true(self): original_openapi = { "openapi": "3.0.1", - "paths": { - "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - } - } - } - } + "paths": {"/foo": {"get": {"x-amazon-apigateway-integration": {"test": "must have integration"}}}}, } editor = SwaggerEditor(original_openapi) - model = { - 'Model': 'User' - } + model = {"Model": "User"} - editor.add_request_model_to_method('/foo', 'get', model) + editor.add_request_model_to_method("/foo", "get", model) - expected = { - 'content': { - 'application/json': { - 'schema': { - '$ref': '#/components/schemas/user' - } - } - } - } + expected = {"content": {"application/json": {"schema": {"$ref": "#/components/schemas/user"}}}} - self.assertEqual(expected, editor.swagger['paths']['/foo']['get']['requestBody']) + self.assertEqual(expected, editor.swagger["paths"]["/foo"]["get"]["requestBody"]) class TestSwaggerEditor_add_auth(TestCase): - def setUp(self): self.original_swagger = { "swagger": "2.0", "paths": { - "/foo": { - "get": { - 
_X_INTEGRATION: { - "a": "b" - } - }, - "post":{ - _X_INTEGRATION: { - "a": "b" - } - } - }, - "/bar": { - "get": { - _X_INTEGRATION: { - "a": "b" - } - } - }, - } + "/foo": {"get": {_X_INTEGRATION: {"a": "b"}}, "post": {_X_INTEGRATION: {"a": "b"}}}, + "/bar": {"get": {_X_INTEGRATION: {"a": "b"}}}, + }, } self.editor = SwaggerEditor(self.original_swagger) def test_add_apikey_security_definition_is_added(self): - expected = { - "type": "apiKey", - "name": "x-api-key", - "in": "header" - } + expected = {"type": "apiKey", "name": "x-api-key", "in": "header"} self.editor.add_apikey_security_definition() - self.assertIn('securityDefinitions', self.editor.swagger) - self.assertIn('api_key', self.editor.swagger["securityDefinitions"]) - self.assertEqual(expected, self.editor.swagger["securityDefinitions"]['api_key']) + self.assertIn("securityDefinitions", self.editor.swagger) + self.assertIn("api_key", self.editor.swagger["securityDefinitions"]) + self.assertEqual(expected, self.editor.swagger["securityDefinitions"]["api_key"]) def test_must_add_default_apikey_to_all_paths(self): - expected = [{ - "api_key": [] - }] + expected = [{"api_key": []}] path = "/foo" - self.editor.set_path_default_apikey_required(path) methods = self.editor.swagger["paths"][path] for method in methods: @@ -1180,46 +867,24 @@ def test_add_default_apikey_to_all_paths_correctly_handles_method_level_settings "swagger": "2.0", "paths": { "/foo": { - "apikeyfalse": { - _X_INTEGRATION: { - "a": "b" - }, - "security":[ - {"api_key_false":[]} - ] - }, - "apikeytrue": { - _X_INTEGRATION: { - "a": "b" - }, - "security":[ - {"api_key":[]} - ] - }, - "apikeydefault":{ - _X_INTEGRATION: { - "a": "b" - } - } - }, - } + "apikeyfalse": {_X_INTEGRATION: {"a": "b"}, "security": [{"api_key_false": []}]}, + "apikeytrue": {_X_INTEGRATION: {"a": "b"}, "security": [{"api_key": []}]}, + "apikeydefault": {_X_INTEGRATION: {"a": "b"}}, + } + }, } self.editor = SwaggerEditor(self.original_swagger) - api_key_exists = [{ - "api_key": [] - }] + api_key_exists = [{"api_key": []}] path = "/foo" self.editor.set_path_default_apikey_required(path) - self.assertEqual([], self.editor.swagger["paths"][path]['apikeyfalse']["security"]) - self.assertEqual(api_key_exists, self.editor.swagger["paths"][path]['apikeytrue']["security"]) - self.assertEqual(api_key_exists, self.editor.swagger["paths"][path]['apikeydefault']["security"]) + self.assertEqual([], self.editor.swagger["paths"][path]["apikeyfalse"]["security"]) + self.assertEqual(api_key_exists, self.editor.swagger["paths"][path]["apikeytrue"]["security"]) + self.assertEqual(api_key_exists, self.editor.swagger["paths"][path]["apikeydefault"]["security"]) def test_set_method_apikey_handling_apikeyrequired_false(self): - expected = [{ - "api_key_false": [] - }] + expected = [{"api_key_false": []}] path = "/bar" method = "get" @@ -1227,9 +892,7 @@ def test_set_method_apikey_handling_apikeyrequired_false(self): self.assertEqual(expected, self.editor.swagger["paths"][path][method]["security"]) def test_set_method_apikey_handling_apikeyrequired_true(self): - expected = [{ - "api_key": [] - }] + expected = [{"api_key": []}] path = "/bar" method = "get" @@ -1238,70 +901,39 @@ def test_set_method_apikey_handling_apikeyrequired_true(self): class TestSwaggerEditor_add_request_parameter_to_method(TestCase): - def setUp(self): self.original_swagger = { "swagger": "2.0", - "paths": { - "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - } - } - } - } + "paths": {"/foo": {"get": 
{"x-amazon-apigateway-integration": {"test": "must have integration"}}}}, } self.editor = SwaggerEditor(self.original_swagger) def test_must_add_parameter_to_method_with_required_and_caching_true(self): - parameters = [{ - 'Name': 'method.request.header.Authorization', - 'Required': True, - 'Caching': True - }] + parameters = [{"Name": "method.request.header.Authorization", "Required": True, "Caching": True}] - self.editor.add_request_parameters_to_method('/foo', 'get', parameters) + self.editor.add_request_parameters_to_method("/foo", "get", parameters) - expected_parameters = [ - { - 'in': 'header', - 'required': True, - 'name': 'Authorization', - 'type': 'string' - } - ] + expected_parameters = [{"in": "header", "required": True, "name": "Authorization", "type": "string"}] - method_swagger = self.editor.swagger['paths']['/foo']['get'] + method_swagger = self.editor.swagger["paths"]["/foo"]["get"] - self.assertEqual(expected_parameters, method_swagger['parameters']) - self.assertEqual(['method.request.header.Authorization'], method_swagger[_X_INTEGRATION]['cacheKeyParameters']) + self.assertEqual(expected_parameters, method_swagger["parameters"]) + self.assertEqual(["method.request.header.Authorization"], method_swagger[_X_INTEGRATION]["cacheKeyParameters"]) def test_must_add_parameter_to_method_with_required_and_caching_false(self): - parameters = [{ - 'Name': 'method.request.header.Authorization', - 'Required': False, - 'Caching': False - }] + parameters = [{"Name": "method.request.header.Authorization", "Required": False, "Caching": False}] - self.editor.add_request_parameters_to_method('/foo', 'get', parameters) + self.editor.add_request_parameters_to_method("/foo", "get", parameters) - expected_parameters = [ - { - 'in': 'header', - 'required': False, - 'name': 'Authorization', - 'type': 'string' - } - ] + expected_parameters = [{"in": "header", "required": False, "name": "Authorization", "type": "string"}] - method_swagger = self.editor.swagger['paths']['/foo']['get'] + method_swagger = self.editor.swagger["paths"]["/foo"]["get"] - self.assertEqual(expected_parameters, method_swagger['parameters']) - self.assertNotIn('cacheKeyParameters', method_swagger[_X_INTEGRATION].keys()) + self.assertEqual(expected_parameters, method_swagger["parameters"]) + self.assertNotIn("cacheKeyParameters", method_swagger[_X_INTEGRATION].keys()) def test_must_add_parameter_to_method_with_existing_parameters(self): @@ -1309,94 +941,58 @@ def test_must_add_parameter_to_method_with_existing_parameters(self): "swagger": "2.0", "paths": { "/foo": { - 'get': { - 'x-amazon-apigateway-integration': { - 'test': 'must have integration' - }, - 'parameters': [{'test': 'existing parameter'}] + "get": { + "x-amazon-apigateway-integration": {"test": "must have integration"}, + "parameters": [{"test": "existing parameter"}], } } - } + }, } editor = SwaggerEditor(original_swagger) - parameters = [{ - 'Name': 'method.request.header.Authorization', - 'Required': False, - 'Caching': False - }] + parameters = [{"Name": "method.request.header.Authorization", "Required": False, "Caching": False}] - editor.add_request_parameters_to_method('/foo', 'get', parameters) + editor.add_request_parameters_to_method("/foo", "get", parameters) expected_parameters = [ - { - 'test': 'existing parameter' - }, - { - 'in': 'header', - 'required': False, - 'name': 'Authorization', - 'type': 'string' - } + {"test": "existing parameter"}, + {"in": "header", "required": False, "name": "Authorization", "type": "string"}, ] - method_swagger = 
editor.swagger['paths']['/foo']['get'] + method_swagger = editor.swagger["paths"]["/foo"]["get"] - self.assertEqual(expected_parameters, method_swagger['parameters']) - self.assertNotIn('cacheKeyParameters', method_swagger[_X_INTEGRATION].keys()) + self.assertEqual(expected_parameters, method_swagger["parameters"]) + self.assertNotIn("cacheKeyParameters", method_swagger[_X_INTEGRATION].keys()) def test_must_not_add_parameter_to_method_without_integration(self): - original_swagger = { - "swagger": "2.0", - "paths": { - "/foo": { - 'get': {} - } - } - } + original_swagger = {"swagger": "2.0", "paths": {"/foo": {"get": {}}}} editor = SwaggerEditor(original_swagger) - parameters = [{ - 'Name': 'method.request.header.Authorization', - 'Required': True, - 'Caching': True - }] + parameters = [{"Name": "method.request.header.Authorization", "Required": True, "Caching": True}] - editor.add_request_parameters_to_method('/foo', 'get', parameters) + editor.add_request_parameters_to_method("/foo", "get", parameters) expected = {} - self.assertEqual(expected, editor.swagger['paths']['/foo']['get']) + self.assertEqual(expected, editor.swagger["paths"]["/foo"]["get"]) class TestSwaggerEditor_add_resource_policy(TestCase): def setUp(self): - self.original_swagger = { - "swagger": "2.0", - "paths": { - "/foo": { - "get": {}, - "put": {} - } - } - } + self.original_swagger = {"swagger": "2.0", "paths": {"/foo": {"get": {}, "put": {}}}} self.editor = SwaggerEditor(self.original_swagger) def test_must_add_custom_statements(self): resourcePolicy = { - 'CustomStatements': [{ - 'Action': 'execute-api:Invoke', - 'Resource': ['execute-api:/*/*/*'] - }, - { - 'Action': 'execute-api:blah', - 'Resource': ['execute-api:/*/*/*'] - }] + "CustomStatements": [ + {"Action": "execute-api:Invoke", "Resource": ["execute-api:/*/*/*"]}, + {"Action": "execute-api:blah", "Resource": ["execute-api:/*/*/*"]}, + ] } self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") @@ -1404,456 +1000,280 @@ def test_must_add_custom_statements(self): expected = { "Version": "2012-10-17", "Statement": [ - { - "Action": "execute-api:Invoke", - "Resource": [ - "execute-api:/*/*/*" - ] - }, - { - "Action": "execute-api:blah", - "Resource": [ - "execute-api:/*/*/*" - ] - } - ] + {"Action": "execute-api:Invoke", "Resource": ["execute-api:/*/*/*"]}, + {"Action": "execute-api:blah", "Resource": ["execute-api:/*/*/*"]}, + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_iam_allow(self): -## fails - resourcePolicy = { - 'AwsAccountWhitelist': [ - '123456' - ] - } + ## fails + resourcePolicy = {"AwsAccountWhitelist": ["123456"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': { - 'AWS': ['123456'] - } - } + "Version": "2012-10-17", + "Statement": { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": {"AWS": ["123456"]}, + }, } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def 
test_must_add_iam_deny(self): - resourcePolicy = { - 'AwsAccountBlacklist': [ - '123456' - ] - } + resourcePolicy = {"AwsAccountBlacklist": ["123456"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Principal': { - 'AWS': ['123456'] - } - } + "Version": "2012-10-17", + "Statement": { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Principal": {"AWS": ["123456"]}, + }, } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_ip_allow(self): - resourcePolicy = { - 'IpRangeWhitelist': [ - '1.2.3.4' - ] - } + resourcePolicy = {"IpRangeWhitelist": ["1.2.3.4"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': '*' - }, - { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Condition': { - 'NotIpAddress': { - 'aws:SourceIp': ['1.2.3.4'] - } + "Version": "2012-10-17", + "Statement": [ + { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": "*", }, - 'Principal': '*' - }] + { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Condition": {"NotIpAddress": {"aws:SourceIp": ["1.2.3.4"]}}, + "Principal": "*", + }, + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_ip_deny(self): - resourcePolicy = { - 'IpRangeBlacklist': [ - '1.2.3.4' - ] - } + resourcePolicy = {"IpRangeBlacklist": ["1.2.3.4"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': '*' - }, - { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Condition': { - 'IpAddress': { - 'aws:SourceIp': ['1.2.3.4'] - } + "Version": "2012-10-17", + "Statement": [ + { + "Action": 
"execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": "*", + }, + { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Condition": {"IpAddress": {"aws:SourceIp": ["1.2.3.4"]}}, + "Principal": "*", }, - 'Principal': '*' - }] + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_vpc_allow(self): - resourcePolicy = { - 'SourceVpcWhitelist': [ - 'vpc-123', - 'vpce-345' - ] - } + resourcePolicy = {"SourceVpcWhitelist": ["vpc-123", "vpce-345"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': [ + "Version": "2012-10-17", + "Statement": [ { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': '*' + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": "*", }, { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Condition': { - 'StringNotEquals': { - 'aws:SourceVpc': 'vpc-123' - } - }, - 'Principal': '*' + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Condition": {"StringNotEquals": {"aws:SourceVpc": "vpc-123"}}, + "Principal": "*", }, { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Condition': { - 'StringNotEquals': { - 'aws:SourceVpce': 'vpce-345' - } - }, - 'Principal': '*' - } - ] + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Condition": {"StringNotEquals": {"aws:SourceVpce": "vpce-345"}}, + "Principal": "*", + }, + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_vpc_deny(self): - resourcePolicy = { - 'SourceVpcBlacklist': [ - 'vpc-123' - ] - } + resourcePolicy = {"SourceVpcBlacklist": ["vpc-123"]} self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': [ + "Version": "2012-10-17", + "Statement": [ { - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 
'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': '*' + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": "*", }, { - 'Action': 'execute-api:Invoke', - 'Resource': [ { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Deny', - 'Condition': { - 'StringEquals': { - 'aws:SourceVpc': 'vpc-123' - } - }, - 'Principal': '*' - } - ] + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Deny", + "Condition": {"StringEquals": {"aws:SourceVpc": "vpc-123"}}, + "Principal": "*", + }, + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) def test_must_add_iam_allow_and_custom(self): resourcePolicy = { - 'AwsAccountWhitelist': [ - '123456' + "AwsAccountWhitelist": ["123456"], + "CustomStatements": [ + {"Action": "execute-api:Invoke", "Resource": ["execute-api:/*/*/*"]}, + {"Action": "execute-api:blah", "Resource": ["execute-api:/*/*/*"]}, ], - 'CustomStatements': [{ - 'Action': 'execute-api:Invoke', - 'Resource': ['execute-api:/*/*/*'] - }, - { - 'Action': 'execute-api:blah', - 'Resource': ['execute-api:/*/*/*'] - }] } self.editor.add_resource_policy(resourcePolicy, "/foo", "123", "prod") expected = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': 'execute-api:Invoke', - 'Resource': [{ - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/PUT/foo', - {'__Stage__': 'prod'} - ] - }, - { - 'Fn::Sub': [ - 'execute-api:/${__Stage__}/GET/foo', - {'__Stage__': 'prod'} - ] - }], - 'Effect': 'Allow', - 'Principal': { - 'AWS': ['123456'] - } - }, - { - "Action": "execute-api:Invoke", - "Resource": [ - "execute-api:/*/*/*" - ] - }, - { - "Action": "execute-api:blah", - "Resource": [ - "execute-api:/*/*/*" - ] - }] + "Version": "2012-10-17", + "Statement": [ + { + "Action": "execute-api:Invoke", + "Resource": [ + {"Fn::Sub": ["execute-api:/${__Stage__}/PUT/foo", {"__Stage__": "prod"}]}, + {"Fn::Sub": ["execute-api:/${__Stage__}/GET/foo", {"__Stage__": "prod"}]}, + ], + "Effect": "Allow", + "Principal": {"AWS": ["123456"]}, + }, + {"Action": "execute-api:Invoke", "Resource": ["execute-api:/*/*/*"]}, + {"Action": "execute-api:blah", "Resource": ["execute-api:/*/*/*"]}, + ], } self.assertEqual(deep_sort_lists(expected), deep_sort_lists(self.editor.swagger[_X_POLICY])) + class TestSwaggerEditor_add_authorization_scopes(TestCase): def setUp(self): self.api = api = { - 'Auth':{ - 'Authorizers': { - 'MyOtherCognitoAuth':{}, - 'MyCognitoAuth': {} - }, - 'DefaultAuthorizer': "MyCognitoAuth" + "Auth": { + "Authorizers": {"MyOtherCognitoAuth": {}, "MyCognitoAuth": {}}, + "DefaultAuthorizer": "MyCognitoAuth", } } - self.editor = SwaggerEditor({ - "swagger": "2.0", - "paths": { - "/cognito": { - "get": { - "x-amazon-apigateway-integration": { - "httpMethod": "POST", - "type": "aws_proxy", - "uri": { - "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyFn.Arn}/invocations" + self.editor = SwaggerEditor( + { + "swagger": "2.0", + "paths": { + "/cognito": { + "get": { + "x-amazon-apigateway-integration": { + "httpMethod": "POST", + "type": "aws_proxy", + 
"uri": { + "Fn::Sub": "arn:aws:apigateway:${AWS::Region}:lambda:path/2015-03-31/functions/${MyFn.Arn}/invocations" + }, + }, + "security": [], + "responses": {}, + } } - }, - "security": [], - "responses": {} - } }, } - }) + ) def test_should_include_auth_scopes_if_defined_with_authorizer(self): - auth = { - 'AuthorizationScopes': ["ResourceName/method.scope"], - 'Authorizer':"MyOtherCognitoAuth" - } + auth = {"AuthorizationScopes": ["ResourceName/method.scope"], "Authorizer": "MyOtherCognitoAuth"} self.editor.add_auth_to_method("/cognito", "get", auth, self.api) - self.assertEqual([{"MyOtherCognitoAuth": ["ResourceName/method.scope"]}], - self.editor.swagger["paths"]["/cognito"]["get"]["security"]) + self.assertEqual( + [{"MyOtherCognitoAuth": ["ResourceName/method.scope"]}], + self.editor.swagger["paths"]["/cognito"]["get"]["security"], + ) def test_should_include_auth_scopes_with_default_authorizer(self): - auth = { - 'AuthorizationScopes': ["ResourceName/method.scope"], - 'Authorizer': 'MyCognitoAuth' - } + auth = {"AuthorizationScopes": ["ResourceName/method.scope"], "Authorizer": "MyCognitoAuth"} self.editor.add_auth_to_method("/cognito", "get", auth, self.api) - self.assertEqual([{"MyCognitoAuth": ["ResourceName/method.scope"]}], - self.editor.swagger["paths"]["/cognito"]["get"]["security"]) + self.assertEqual( + [{"MyCognitoAuth": ["ResourceName/method.scope"]}], + self.editor.swagger["paths"]["/cognito"]["get"]["security"], + ) def test_should_include_only_specified_authorizer_auth_if_no_scopes_defined(self): - auth = { - 'Authorizer':"MyOtherCognitoAuth" - } + auth = {"Authorizer": "MyOtherCognitoAuth"} self.editor.add_auth_to_method("/cognito", "get", auth, self.api) - self.assertEqual([{"MyOtherCognitoAuth": []}], - self.editor.swagger["paths"]["/cognito"]["get"]["security"]) + self.assertEqual([{"MyOtherCognitoAuth": []}], self.editor.swagger["paths"]["/cognito"]["get"]["security"]) def test_should_include_none_if_default_is_overwritte(self): - auth = { - 'Authorizer':"NONE" - } + auth = {"Authorizer": "NONE"} self.editor.add_auth_to_method("/cognito", "get", auth, self.api) - self.assertEqual([{"NONE": []}], - self.editor.swagger["paths"]["/cognito"]["get"]["security"]) - + self.assertEqual([{"NONE": []}], self.editor.swagger["paths"]["/cognito"]["get"]["security"]) diff --git a/tests/test_intrinsics.py b/tests/test_intrinsics.py index 68ca9a97d9..af094f7e2d 100644 --- a/tests/test_intrinsics.py +++ b/tests/test_intrinsics.py @@ -3,20 +3,12 @@ from samtranslator.model.intrinsics import is_instrinsic, make_shorthand -class TestIntrinsics(TestCase): - @parameterized.expand([ - "Ref", - "Condition", - "Fn::foo", - "Fn::sub", - "Fn::something" - ]) +class TestIntrinsics(TestCase): + @parameterized.expand(["Ref", "Condition", "Fn::foo", "Fn::sub", "Fn::something"]) def test_is_intrinsic_must_detect_intrinsics(self, intrinsic_name): - input = { - intrinsic_name: ["some value"] - } + input = {intrinsic_name: ["some value"]} self.assertTrue(is_instrinsic(input)) @@ -24,23 +16,17 @@ def test_is_intrinsic_on_empty_input(self): self.assertFalse(is_instrinsic(None)) def test_is_intrinsic_on_non_dict_input(self): - self.assertFalse(is_instrinsic([1,2,3])) + self.assertFalse(is_instrinsic([1, 2, 3])) def test_is_intrinsic_on_intrinsic_like_dict_input(self): - self.assertFalse(is_instrinsic({ - "Ref": "foo", - "key": "bar" - })) - - @parameterized.expand([ - ({"Ref": "foo"}, "${foo}"), - ({"Fn::GetAtt": ["foo", "Arn"]}, "${foo.Arn}") - ]) + self.assertFalse(is_instrinsic({"Ref": "foo", 
"key": "bar"})) + + @parameterized.expand([({"Ref": "foo"}, "${foo}"), ({"Fn::GetAtt": ["foo", "Arn"]}, "${foo.Arn}")]) def test_make_shorthand_success(self, input, expected): self.assertEqual(make_shorthand(input), expected) def test_make_short_hand_failure(self): - input = { "Fn::Sub": "something" } + input = {"Fn::Sub": "something"} with self.assertRaises(NotImplementedError): make_shorthand(input) diff --git a/tests/test_model.py b/tests/test_model.py index 089b899e03..446e5b5759 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,6 +7,7 @@ from samtranslator.intrinsics.resource_refs import SupportedResourceReferences from samtranslator.plugins import LifeCycleEvents + def valid_if_true(value, should_raise=True): """Validator that passes if the input is True.""" if value is True: @@ -16,53 +17,85 @@ def valid_if_true(value, should_raise=True): raise TypeError return False + class DummyResource(Resource): - resource_type = 'AWS::Dummy::Resource' + resource_type = "AWS::Dummy::Resource" property_types = { - 'RequiredProperty': PropertyType(True, valid_if_true), - 'OptionalProperty': PropertyType(False, valid_if_true) + "RequiredProperty": PropertyType(True, valid_if_true), + "OptionalProperty": PropertyType(False, valid_if_true), } -@pytest.mark.parametrize('logical_id,resource_dict,expected_exception', [ - # Valid required property - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': True}}, None), - # Valid required property and valid optional property - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': True, 'OptionalProperty': True}}, None), - # Required property not provided - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'OptionalProperty': True}}, InvalidResourceException), - # Required property provided, but invalid - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': False}}, InvalidResourceException), - # Property with invalid name provided - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': True, 'InvalidProperty': True}}, InvalidResourceException), - # Missing Properties - ('id', {'Type': 'AWS::Other::Other'}, InvalidResourceException), - # Missing Type - ('id', {'Properties': {'RequiredProperty': True, 'OptionalProperty': True}}, InvalidResourceException), - # Valid Type with invalid Properties - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': 'Invalid'}, InvalidResourceException), - # Valid Properties with invalid Type - ('id', {'Type': 'AWS::Invalid::Invalid', 'Properties': {'RequiredProperty': True, 'OptionalProperty': True}}, InvalidResourceException), - # Invalid logical_id - ('invalid_id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': True, 'OptionalProperty': True}}, InvalidResourceException), - # intrinsic function - ('id', {'Type': 'AWS::Dummy::Resource', 'Properties': {'RequiredProperty': {'Fn::Any': ['logicalid', 'Arn']}}}, None) -]) + +@pytest.mark.parametrize( + "logical_id,resource_dict,expected_exception", + [ + # Valid required property + ("id", {"Type": "AWS::Dummy::Resource", "Properties": {"RequiredProperty": True}}, None), + # Valid required property and valid optional property + ( + "id", + {"Type": "AWS::Dummy::Resource", "Properties": {"RequiredProperty": True, "OptionalProperty": True}}, + None, + ), + # Required property not provided + ("id", {"Type": "AWS::Dummy::Resource", "Properties": {"OptionalProperty": True}}, InvalidResourceException), + # Required property provided, but invalid + ("id", {"Type": 
"AWS::Dummy::Resource", "Properties": {"RequiredProperty": False}}, InvalidResourceException), + # Property with invalid name provided + ( + "id", + {"Type": "AWS::Dummy::Resource", "Properties": {"RequiredProperty": True, "InvalidProperty": True}}, + InvalidResourceException, + ), + # Missing Properties + ("id", {"Type": "AWS::Other::Other"}, InvalidResourceException), + # Missing Type + ("id", {"Properties": {"RequiredProperty": True, "OptionalProperty": True}}, InvalidResourceException), + # Valid Type with invalid Properties + ("id", {"Type": "AWS::Dummy::Resource", "Properties": "Invalid"}, InvalidResourceException), + # Valid Properties with invalid Type + ( + "id", + {"Type": "AWS::Invalid::Invalid", "Properties": {"RequiredProperty": True, "OptionalProperty": True}}, + InvalidResourceException, + ), + # Invalid logical_id + ( + "invalid_id", + {"Type": "AWS::Dummy::Resource", "Properties": {"RequiredProperty": True, "OptionalProperty": True}}, + InvalidResourceException, + ), + # intrinsic function + ( + "id", + {"Type": "AWS::Dummy::Resource", "Properties": {"RequiredProperty": {"Fn::Any": ["logicalid", "Arn"]}}}, + None, + ), + ], +) def test_resource_type_validation(logical_id, resource_dict, expected_exception): if not expected_exception: resource = DummyResource.from_dict(logical_id, resource_dict) - for name, value in resource_dict['Properties'].items(): - assert getattr(resource, name) == value, "resource did not have expected property attribute {property_name} with value {property_value}".format(property_name=name, property_value=value) + for name, value in resource_dict["Properties"].items(): + assert ( + getattr(resource, name) == value + ), "resource did not have expected property attribute {property_name} with value {property_value}".format( + property_name=name, property_value=value + ) actual_to_dict = resource.to_dict() - expected_to_dict = {'id': resource_dict} - assert actual_to_dict == expected_to_dict, "to_dict() returned different values from what was passed to from_dict(); expected {expected}, got {actual}".format(expected=expected_to_dict, actual=actual_to_dict) + expected_to_dict = {"id": resource_dict} + assert ( + actual_to_dict == expected_to_dict + ), "to_dict() returned different values from what was passed to from_dict(); expected {expected}, got {actual}".format( + expected=expected_to_dict, actual=actual_to_dict + ) else: with pytest.raises(expected_exception): resource = DummyResource.from_dict(logical_id, resource_dict) class TestResourceAttributes(TestCase): - class MyResource(Resource): resource_type = "foo" property_types = {} @@ -72,7 +105,9 @@ def test_to_dict(self): """ empty_resource_dict = {"id": {"Type": "foo", "Properties": {}}} - dict_with_attributes = {"id": {"Type": "foo", "Properties": {}, "UpdatePolicy": "update", "DeletionPolicy": {"foo": "bar"}}} + dict_with_attributes = { + "id": {"Type": "foo", "Properties": {}, "UpdatePolicy": "update", "DeletionPolicy": {"foo": "bar"}} + } r = self.MyResource("id") self.assertEqual(r.to_dict(), empty_resource_dict) @@ -87,7 +122,12 @@ def test_invalid_attr(self): self.MyResource("id", attributes={"foo": "bar"}) # But unsupported properties will silently be ignored when deserialization from dictionary - with_unsupported_attributes = {"Type": "foo", "Properties": {}, "DeletionPolicy": "foo", "UnsupportedPolicy": "bar"} + with_unsupported_attributes = { + "Type": "foo", + "Properties": {}, + "DeletionPolicy": "foo", + "UnsupportedPolicy": "bar", + } r = self.MyResource.from_dict("id", 
resource_dict=with_unsupported_attributes) self.assertEqual(r.get_resource_attribute("DeletionPolicy"), "foo") @@ -96,26 +136,27 @@ def test_invalid_attr(self): def test_from_dict(self): no_attribute = {"Type": "foo", "Properties": {}} - all_supported_attributes = {"Type": "foo", "Properties": {}, "UpdatePolicy": "update", "DeletionPolicy": [1,2,3]} + all_supported_attributes = { + "Type": "foo", + "Properties": {}, + "UpdatePolicy": "update", + "DeletionPolicy": [1, 2, 3], + } r = self.MyResource.from_dict("id", resource_dict=no_attribute) - self.assertEqual(r.logical_id, "id") # Just making sure the resource got created + self.assertEqual(r.logical_id, "id") # Just making sure the resource got created r = self.MyResource.from_dict("id", resource_dict=all_supported_attributes) - self.assertEqual(r.get_resource_attribute("DeletionPolicy"), [1,2,3]) + self.assertEqual(r.get_resource_attribute("DeletionPolicy"), [1, 2, 3]) self.assertEqual(r.get_resource_attribute("UpdatePolicy"), "update") -class TestResourceRuntimeAttributes(TestCase): +class TestResourceRuntimeAttributes(TestCase): def test_resource_must_override_runtime_attributes(self): - class NewResource(Resource): resource_type = "foo" property_types = {} - runtime_attrs = { - "attr1": Mock(), - "attr2": Mock() - } + runtime_attrs = {"attr1": Mock(), "attr2": Mock()} runtime_attrs["attr1"].return_value = "value1" runtime_attrs["attr2"].return_value = "value2" @@ -137,8 +178,8 @@ class NewResource(Resource): resource = NewResource("SomeId") self.assertEqual(0, len(resource.runtime_attrs)) -class TestSamResourceReferableProperties(TestCase): +class TestSamResourceReferableProperties(TestCase): class ResourceType1(Resource): resource_type = "resource_type1" property_types = {} @@ -158,19 +199,13 @@ def test_must_get_property_for_available_resources(self): class NewSamResource(SamResourceMacro): resource_type = "foo" property_types = {} - referable_properties = { - "prop1": "resource_type1", - "prop2": "resource_type2", - "prop3": "resource_type3" - } + referable_properties = {"prop1": "resource_type1", "prop2": "resource_type2", "prop3": "resource_type3"} sam_resource = NewSamResource("SamLogicalId") - cfn_resources = [self.ResourceType1("logicalId1"), - self.ResourceType2("logicalId2")] + cfn_resources = [self.ResourceType1("logicalId1"), self.ResourceType2("logicalId2")] - self.supported_resource_refs = \ - sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) + self.supported_resource_refs = sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) self.assertEqual("logicalId1", self.supported_resource_refs.get("SamLogicalId", "prop1")) self.assertEqual("logicalId2", self.supported_resource_refs.get("SamLogicalId", "prop2")) @@ -185,24 +220,20 @@ def test_must_work_with_two_resources_of_same_type(self): class NewSamResource(SamResourceMacro): resource_type = "foo" property_types = {} - referable_properties = { - "prop1": "resource_type1", - "prop2": "resource_type2", - "prop3": "resource_type3" - } + referable_properties = {"prop1": "resource_type1", "prop2": "resource_type2", "prop3": "resource_type3"} sam_resource1 = NewSamResource("SamLogicalId1") sam_resource2 = NewSamResource("SamLogicalId2") - cfn_resources = [self.ResourceType1("logicalId1"), - self.ResourceType2("logicalId2") - ] + cfn_resources = [self.ResourceType1("logicalId1"), self.ResourceType2("logicalId2")] - self.supported_resource_refs = \ - sam_resource1.get_resource_references(cfn_resources, 
self.supported_resource_refs) + self.supported_resource_refs = sam_resource1.get_resource_references( + cfn_resources, self.supported_resource_refs + ) - self.supported_resource_refs = \ - sam_resource2.get_resource_references(cfn_resources, self.supported_resource_refs) + self.supported_resource_refs = sam_resource2.get_resource_references( + cfn_resources, self.supported_resource_refs + ) self.assertEqual("logicalId1", self.supported_resource_refs.get("SamLogicalId1", "prop1")) self.assertEqual("logicalId2", self.supported_resource_refs.get("SamLogicalId1", "prop2")) @@ -215,19 +246,14 @@ def test_must_skip_unknown_resource_types(self): class NewSamResource(SamResourceMacro): resource_type = "foo" property_types = {} - referable_properties = { - "prop1": "foo", - "prop2": "bar", - } + referable_properties = {"prop1": "foo", "prop2": "bar"} sam_resource = NewSamResource("SamLogicalId") # None of the CFN resource types are in the referable list - cfn_resources = [self.ResourceType1("logicalId1"), - self.ResourceType2("logicalId2")] + cfn_resources = [self.ResourceType1("logicalId1"), self.ResourceType2("logicalId2")] - self.supported_resource_refs = \ - sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) + self.supported_resource_refs = sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) self.assertEqual(0, len(self.supported_resource_refs)) @@ -239,11 +265,9 @@ class NewSamResource(SamResourceMacro): sam_resource = NewSamResource("SamLogicalId") - cfn_resources = [self.ResourceType1("logicalId1"), - self.ResourceType2("logicalId2")] + cfn_resources = [self.ResourceType1("logicalId1"), self.ResourceType2("logicalId2")] - self.supported_resource_refs = \ - sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) + self.supported_resource_refs = sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) self.assertEqual(0, len(self.supported_resource_refs)) @@ -257,8 +281,7 @@ class NewSamResource(SamResourceMacro): cfn_resources = [] - self.supported_resource_refs = \ - sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) + self.supported_resource_refs = sam_resource.get_resource_references(cfn_resources, self.supported_resource_refs) self.assertEqual(0, len(self.supported_resource_refs)) @@ -277,7 +300,6 @@ class NewSamResource(SamResourceMacro): class TestResourceTypeResolver(TestCase): - def test_can_resolve_must_handle_null_resource_dict(self): resolver = ResourceTypeResolver() @@ -295,37 +317,29 @@ def test_can_resolve_must_handle_dict_without_type(self): def test_can_resolve_must_handle_known_types(self): resolver = ResourceTypeResolver() - resolver.resource_types = {"type1": DummyResource("id")} + resolver.resource_types = {"type1": DummyResource("id")} self.assertTrue(resolver.can_resolve({"Type": "type1"})) def test_can_resolve_must_handle_unknown_types(self): resolver = ResourceTypeResolver() - resolver.resource_types = {"type1": DummyResource("id")} + resolver.resource_types = {"type1": DummyResource("id")} self.assertFalse(resolver.can_resolve({"Type": "AWS::Lambda::Function"})) -class TestSamPluginsInResource(TestCase): +class TestSamPluginsInResource(TestCase): def test_must_act_on_plugins_before_resource_creation(self): resource_type = "AWS::Dummy::Resource" - resource_dict = { - "Type": resource_type, - "Properties": { - "RequiredProperty": True - } - } - expected_properties = { - "RequiredProperty": True - } + resource_dict = {"Type": 
resource_type, "Properties": {"RequiredProperty": True}} + expected_properties = {"RequiredProperty": True} mock_sam_plugins = Mock() - DummyResource.from_dict("logicalId", resource_dict,sam_plugins=mock_sam_plugins) + DummyResource.from_dict("logicalId", resource_dict, sam_plugins=mock_sam_plugins) - mock_sam_plugins.act.assert_called_once_with(LifeCycleEvents.before_transform_resource, - "logicalId", - resource_type, - expected_properties) + mock_sam_plugins.act.assert_called_once_with( + LifeCycleEvents.before_transform_resource, "logicalId", resource_type, expected_properties + ) def test_must_act_on_plugins_for_resource_having_no_properties(self): resource_type = "MyResourceType" @@ -335,16 +349,12 @@ class MyResource(Resource): resource_type = "MyResourceType" # No Properties for this resource - resource_dict = { - "Type": resource_type - } + resource_dict = {"Type": resource_type} expected_properties = {} mock_sam_plugins = Mock() MyResource.from_dict("logicalId", resource_dict, sam_plugins=mock_sam_plugins) - mock_sam_plugins.act.assert_called_once_with(LifeCycleEvents.before_transform_resource, - "logicalId", - resource_type, - expected_properties) - + mock_sam_plugins.act.assert_called_once_with( + LifeCycleEvents.before_transform_resource, "logicalId", resource_type, expected_properties + ) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 37a48f0eb7..a1bfb97ab7 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -4,8 +4,8 @@ from unittest import TestCase from mock import Mock, patch, call -class TestSamPluginsRegistration(TestCase): +class TestSamPluginsRegistration(TestCase): def setUp(self): # Setup the plugin to be a "subclass" of the BasePlugin @@ -15,7 +15,6 @@ def setUp(self): self.sam_plugins = SamPlugins() - def test_register_must_work(self): self.sam_plugins.register(self.mock_plugin) @@ -114,8 +113,8 @@ def test_get_must_handle_non_registered_plugins(self): self.sam_plugins.register(self.mock_plugin) self.assertIsNone(self.sam_plugins._get("some plugin")) -class TestSamPluginsAct(TestCase): +class TestSamPluginsAct(TestCase): def setUp(self): self.sam_plugins = SamPlugins() @@ -131,7 +130,7 @@ def test_act_must_invoke_correct_hook_method(self): # Setup the plugin to return a mock when the "on_" method is invoked plugin = _make_mock_plugin("plugin") hook_method = Mock() - setattr(plugin, "on_"+self.my_event.name, hook_method) + setattr(plugin, "on_" + self.my_event.name, hook_method) self.sam_plugins.register(plugin) @@ -147,9 +146,12 @@ def test_act_must_invoke_correct_hook_method(self): def test_act_must_invoke_hook_on_all_plugins(self): # Create three plugins, and setup hook methods on it - plugin1 = _make_mock_plugin("plugin1"); setattr(plugin1, "on_"+self.my_event.name, Mock()) - plugin2 = _make_mock_plugin("plugin2"); setattr(plugin2, "on_"+self.my_event.name, Mock()) - plugin3 = _make_mock_plugin("plugin3"); setattr(plugin3, "on_"+self.my_event.name, Mock()) + plugin1 = _make_mock_plugin("plugin1") + setattr(plugin1, "on_" + self.my_event.name, Mock()) + plugin2 = _make_mock_plugin("plugin2") + setattr(plugin2, "on_" + self.my_event.name, Mock()) + plugin3 = _make_mock_plugin("plugin3") + setattr(plugin3, "on_" + self.my_event.name, Mock()) self.sam_plugins.register(plugin1) self.sam_plugins.register(plugin2) @@ -169,9 +171,12 @@ def test_act_must_invoke_hook_on_all_plugins(self): def test_act_must_invoke_plugins_in_sequence(self): # Create three plugins, and setup hook methods on it - plugin1 = _make_mock_plugin("plugin1"); 
setattr(plugin1, "on_"+self.my_event.name, Mock()) - plugin2 = _make_mock_plugin("plugin2"); setattr(plugin2, "on_"+self.my_event.name, Mock()) - plugin3 = _make_mock_plugin("plugin3"); setattr(plugin3, "on_"+self.my_event.name, Mock()) + plugin1 = _make_mock_plugin("plugin1") + setattr(plugin1, "on_" + self.my_event.name, Mock()) + plugin2 = _make_mock_plugin("plugin2") + setattr(plugin2, "on_" + self.my_event.name, Mock()) + plugin3 = _make_mock_plugin("plugin3") + setattr(plugin3, "on_" + self.my_event.name, Mock()) # Create a parent mock and attach child mocks to help assert order of the calls # https://stackoverflow.com/questions/32463321/how-to-assert-method-call-order-with-python-mock @@ -193,7 +198,8 @@ def test_act_must_invoke_plugins_in_sequence(self): def test_act_must_skip_if_no_plugins_are_registered(self): # Create three plugins, and setup hook methods on it - plugin1 = _make_mock_plugin("plugin1"); setattr(plugin1, "on_"+self.my_event.name, Mock()) + plugin1 = _make_mock_plugin("plugin1") + setattr(plugin1, "on_" + self.my_event.name, Mock()) # Don't register any plugin @@ -202,7 +208,6 @@ def test_act_must_skip_if_no_plugins_are_registered(self): plugin1.on_my_event.assert_not_called() - def test_act_must_fail_on_invalid_event_type_string(self): with self.assertRaises(ValueError): @@ -214,7 +219,6 @@ def test_act_must_fail_on_invalid_event_type_object(self): self.sam_plugins.act(Mock()) def test_act_must_fail_on_invalid_event_type_enum(self): - class SomeEnum(Enum): A = 1 @@ -236,7 +240,8 @@ def test_act_must_fail_on_non_existent_hook_method(self): def test_act_must_raise_exceptions_raised_by_plugins(self): # Create a plugin but setup hook method with wrong name - plugin1 = _make_mock_plugin("plugin1"); setattr(plugin1, "on_"+self.my_event.name, Mock()) + plugin1 = _make_mock_plugin("plugin1") + setattr(plugin1, "on_" + self.my_event.name, Mock()) self.sam_plugins.register(plugin1) # Setup the hook to raise exception @@ -251,9 +256,12 @@ def test_act_must_abort_hooks_after_exception(self): # ie. 
after a hook raises an exception, subsequent hooks must NOT be run # Create three plugins, and setup hook methods on it - plugin1 = _make_mock_plugin("plugin1"); setattr(plugin1, "on_"+self.my_event.name, Mock()) - plugin2 = _make_mock_plugin("plugin2"); setattr(plugin2, "on_"+self.my_event.name, Mock()) - plugin3 = _make_mock_plugin("plugin3"); setattr(plugin3, "on_"+self.my_event.name, Mock()) + plugin1 = _make_mock_plugin("plugin1") + setattr(plugin1, "on_" + self.my_event.name, Mock()) + plugin2 = _make_mock_plugin("plugin2") + setattr(plugin2, "on_" + self.my_event.name, Mock()) + plugin3 = _make_mock_plugin("plugin3") + setattr(plugin3, "on_" + self.my_event.name, Mock()) # Create a parent mock and attach child mocks to help assert order of the calls # https://stackoverflow.com/questions/32463321/how-to-assert-method-call-order-with-python-mock @@ -276,8 +284,8 @@ def test_act_must_abort_hooks_after_exception(self): # Since Plugin2 raised the exception, plugin3's hook must NEVER be called parent_mock.assert_has_calls([call.plugin1_hook(), call.plugin2_hook()]) -class TestBasePlugin(TestCase): +class TestBasePlugin(TestCase): def test_initialization_should_set_name(self): name = "some name" diff --git a/tests/test_types.py b/tests/test_types.py index b4f7ba6efb..c995e0dbab 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -2,16 +2,13 @@ from samtranslator.model.types import is_type, list_of, dict_of, one_of + class DummyType(object): pass + def test_is_type_validator(): - example_properties = [ - (1, int), - ("Hello, World!", str), - ({'1': 1}, dict), - (DummyType(), DummyType) - ] + example_properties = [(1, int), ("Hello, World!", str), ({"1": 1}, dict), (DummyType(), DummyType)] for value, value_type in example_properties: # Check that is_type(value_type) passes for value @@ -22,61 +19,85 @@ def test_is_type_validator(): for _, other_type in example_properties: if value_type != other_type: validate = is_type(other_type) - assert not validate(value, should_raise=False), "is_type validator unexpectedly succeeded for type {}, value {}".format(value_type, value) + assert not validate( + value, should_raise=False + ), "is_type validator unexpectedly succeeded for type {}, value {}".format(value_type, value) with pytest.raises(TypeError): validate(value) -@pytest.mark.parametrize('value,item_type,should_pass', [ - # List of expected type - ([ 1, 2, 3 ], int, True), - # List of mixed types - ([ 1, 2, "Hello, world!", 3 ], int, False), - # Not a list - (1, int, False), -]) + +@pytest.mark.parametrize( + "value,item_type,should_pass", + [ + # List of expected type + ([1, 2, 3], int, True), + # List of mixed types + ([1, 2, "Hello, world!", 3], int, False), + # Not a list + (1, int, False), + ], +) def test_list_of_validator(value, item_type, should_pass): validate = list_of(is_type(item_type)) if should_pass: assert validate(value), "list_of validator failed for item type {}, value {}".format(item_type, value) else: - assert not validate(value, should_raise=False), "list_of validator unexpectedly succeeded for item type {}, value {}".format(item_type, value) + assert not validate( + value, should_raise=False + ), "list_of validator unexpectedly succeeded for item type {}, value {}".format(item_type, value) with pytest.raises(TypeError): validate(value) -@pytest.mark.parametrize('value,key_type,value_type,should_pass', [ - # Dict of expected types - ({ str(i): i for i in range(5) }, str, int, True), - # Dict of mixed keys - ({ '1': 1, 2: 2 }, str, int, False), - # Dict of mixed 
values - ({ '1': '1', '2': 2 }, str, int, False), - # Dict of mixed keys and values - ({ '1': '1', 2: 2 }, str, int, False), - # Not a dict - (('1', 2), str, int, False) -]) + +@pytest.mark.parametrize( + "value,key_type,value_type,should_pass", + [ + # Dict of expected types + ({str(i): i for i in range(5)}, str, int, True), + # Dict of mixed keys + ({"1": 1, 2: 2}, str, int, False), + # Dict of mixed values + ({"1": "1", "2": 2}, str, int, False), + # Dict of mixed keys and values + ({"1": "1", 2: 2}, str, int, False), + # Not a dict + (("1", 2), str, int, False), + ], +) def test_dict_of_validator(value, key_type, value_type, should_pass): validate = dict_of(is_type(key_type), is_type(value_type)) if should_pass: - assert validate(value), "dict_of validator failed for key type {}, item type {}, value {}".format(key_type, value_type, value) + assert validate(value), "dict_of validator failed for key type {}, item type {}, value {}".format( + key_type, value_type, value + ) else: - assert not validate(value, should_raise=False), "dict_of validator unexpectedly succeeded for key type {}, item type {}, value {}".format(key_type, value_type, value) + assert not validate( + value, should_raise=False + ), "dict_of validator unexpectedly succeeded for key type {}, item type {}, value {}".format( + key_type, value_type, value + ) with pytest.raises(TypeError): validate(value) -@pytest.mark.parametrize('value,validators,should_pass', [ - # Value of first expected type - (1, [ is_type(int), list_of(is_type(int)) ], True), - # Value of second expected type - ([ 1, 2, 3 ], [ is_type(int), list_of(is_type(int)) ], True), - # Value of neither expected type - ("Hello, World!", [ is_type(int), list_of(is_type(int)) ], False) -]) + +@pytest.mark.parametrize( + "value,validators,should_pass", + [ + # Value of first expected type + (1, [is_type(int), list_of(is_type(int))], True), + # Value of second expected type + ([1, 2, 3], [is_type(int), list_of(is_type(int))], True), + # Value of neither expected type + ("Hello, World!", [is_type(int), list_of(is_type(int))], False), + ], +) def test_one_of_validator(value, validators, should_pass): validate = one_of(*validators) if should_pass: assert validate(value), "one_of validator failed for validators {}, value {}".format(validators, value) else: - assert not validate(value, should_raise=False), "one_of validator unexpectedly succeeded for validators {}, value {}".format(validators, value) + assert not validate( + value, should_raise=False + ), "one_of validator unexpectedly succeeded for validators {}, value {}".format(validators, value) with pytest.raises(TypeError): validate(value) diff --git a/tests/translator/helpers.py b/tests/translator/helpers.py index dddf5a14fd..56d0c4b597 100644 --- a/tests/translator/helpers.py +++ b/tests/translator/helpers.py @@ -1,5 +1,2 @@ def get_template_parameter_values(): - return { - "param1": "value1", - "param2": "value2" - } + return {"param1": "value1", "param2": "value2"} diff --git a/tests/translator/model/preferences/test_deployment_preference.py b/tests/translator/model/preferences/test_deployment_preference.py index 332a8dc6b0..19331b3e42 100644 --- a/tests/translator/model/preferences/test_deployment_preference.py +++ b/tests/translator/model/preferences/test_deployment_preference.py @@ -5,52 +5,68 @@ class TestDeploymentPreference(TestCase): - def setUp(self): - self.deployment_type = 'AllAtOnce' - self.pre_traffic_hook = 'pre_traffic_function_ref' - self.post_traffic_hook = 'post_traffic_function_ref' - 
self.alarms = ['alarm1ref', 'alarm2ref'] + self.deployment_type = "AllAtOnce" + self.pre_traffic_hook = "pre_traffic_function_ref" + self.post_traffic_hook = "post_traffic_function_ref" + self.alarms = ["alarm1ref", "alarm2ref"] self.role = {"Ref": "MyRole"} self.trigger_configurations = { - "TriggerEvents": [ - "DeploymentSuccess", - "DeploymentFailure" - ], - "TriggerTargetArn": { - "Ref": "MySNSTopic" - }, - "TriggerName": "TestTrigger" - } - self.expected_deployment_preference = DeploymentPreference(self.deployment_type, self.pre_traffic_hook, - self.post_traffic_hook, self.alarms, True, self.role, - self.trigger_configurations) + "TriggerEvents": ["DeploymentSuccess", "DeploymentFailure"], + "TriggerTargetArn": {"Ref": "MySNSTopic"}, + "TriggerName": "TestTrigger", + } + self.expected_deployment_preference = DeploymentPreference( + self.deployment_type, + self.pre_traffic_hook, + self.post_traffic_hook, + self.alarms, + True, + self.role, + self.trigger_configurations, + ) def test_from_dict_with_intrinsic_function_type(self): type = {"Ref": "SomeType"} - expected_deployment_preference = DeploymentPreference(type, self.pre_traffic_hook, - self.post_traffic_hook, self.alarms, True, self.role, - self.trigger_configurations) + expected_deployment_preference = DeploymentPreference( + type, + self.pre_traffic_hook, + self.post_traffic_hook, + self.alarms, + True, + self.role, + self.trigger_configurations, + ) deployment_preference_yaml_dict = dict() - deployment_preference_yaml_dict['Type'] = type - deployment_preference_yaml_dict['Hooks'] = {'PreTraffic': self.pre_traffic_hook, 'PostTraffic': self.post_traffic_hook} - deployment_preference_yaml_dict['Alarms'] = self.alarms - deployment_preference_yaml_dict['Role'] = self.role - deployment_preference_yaml_dict['TriggerConfigurations'] = self.trigger_configurations - deployment_preference_from_yaml_dict = DeploymentPreference.from_dict('logical_id', deployment_preference_yaml_dict) + deployment_preference_yaml_dict["Type"] = type + deployment_preference_yaml_dict["Hooks"] = { + "PreTraffic": self.pre_traffic_hook, + "PostTraffic": self.post_traffic_hook, + } + deployment_preference_yaml_dict["Alarms"] = self.alarms + deployment_preference_yaml_dict["Role"] = self.role + deployment_preference_yaml_dict["TriggerConfigurations"] = self.trigger_configurations + deployment_preference_from_yaml_dict = DeploymentPreference.from_dict( + "logical_id", deployment_preference_yaml_dict + ) self.assertEqual(expected_deployment_preference, deployment_preference_from_yaml_dict) def test_from_dict(self): deployment_preference_yaml_dict = dict() - deployment_preference_yaml_dict['Type'] = self.deployment_type - deployment_preference_yaml_dict['Hooks'] = {'PreTraffic': self.pre_traffic_hook, 'PostTraffic': self.post_traffic_hook} - deployment_preference_yaml_dict['Alarms'] = self.alarms - deployment_preference_yaml_dict['Role'] = self.role - deployment_preference_yaml_dict['TriggerConfigurations'] = self.trigger_configurations - deployment_preference_from_yaml_dict = DeploymentPreference.from_dict('logical_id', deployment_preference_yaml_dict) + deployment_preference_yaml_dict["Type"] = self.deployment_type + deployment_preference_yaml_dict["Hooks"] = { + "PreTraffic": self.pre_traffic_hook, + "PostTraffic": self.post_traffic_hook, + } + deployment_preference_yaml_dict["Alarms"] = self.alarms + deployment_preference_yaml_dict["Role"] = self.role + deployment_preference_yaml_dict["TriggerConfigurations"] = self.trigger_configurations + 
deployment_preference_from_yaml_dict = DeploymentPreference.from_dict( + "logical_id", deployment_preference_yaml_dict + ) self.assertEqual(self.expected_deployment_preference, deployment_preference_from_yaml_dict) @@ -58,16 +74,17 @@ def test_from_dict_with_disabled_preference_does_not_require_other_parameters(se expected_deployment_preference = DeploymentPreference(None, None, None, None, False, None, None) deployment_preference_yaml_dict = dict() - deployment_preference_yaml_dict['Enabled'] = False - deployment_preference_from_yaml_dict = DeploymentPreference.from_dict('logical_id', - deployment_preference_yaml_dict) + deployment_preference_yaml_dict["Enabled"] = False + deployment_preference_from_yaml_dict = DeploymentPreference.from_dict( + "logical_id", deployment_preference_yaml_dict + ) self.assertEqual(expected_deployment_preference, deployment_preference_from_yaml_dict) def test_from_dict_with_non_dict_hooks_raises_invalid_resource_exception(self): with self.assertRaises(InvalidResourceException): - DeploymentPreference.from_dict('logical_id', {'Type': 'Canary', 'Hooks': 'badhook'}) + DeploymentPreference.from_dict("logical_id", {"Type": "Canary", "Hooks": "badhook"}) def test_from_dict_with_missing_type_raises_invalid_resource_exception(self): with self.assertRaises(InvalidResourceException): - DeploymentPreference.from_dict('logical_id', dict()) + DeploymentPreference.from_dict("logical_id", dict()) diff --git a/tests/translator/model/preferences/test_deployment_preference_collection.py b/tests/translator/model/preferences/test_deployment_preference_collection.py index 4415c01c2f..6307248087 100644 --- a/tests/translator/model/preferences/test_deployment_preference_collection.py +++ b/tests/translator/model/preferences/test_deployment_preference_collection.py @@ -13,186 +13,179 @@ class TestDeploymentPreferenceCollection(TestCase): def setup_method(self, method): - self.deployment_type_global = 'AllAtOnce' - self.alarms_global = [{'Ref': 'alarm1'}, {'Ref': 'alarm2'}] - self.post_traffic_host_global = 'post_traffic_function_ref' - self.pre_traffic_hook_global = 'pre_traffic_function_ref' - self.function_logical_id = 'FunctionLogicalId' + self.deployment_type_global = "AllAtOnce" + self.alarms_global = [{"Ref": "alarm1"}, {"Ref": "alarm2"}] + self.post_traffic_host_global = "post_traffic_function_ref" + self.pre_traffic_hook_global = "pre_traffic_function_ref" + self.function_logical_id = "FunctionLogicalId" - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_when_no_global_dict_each_local_deployment_preference_requires_parameters(self): with self.assertRaises(InvalidResourceException): - DeploymentPreferenceCollection().add('', dict()) + DeploymentPreferenceCollection().add("", dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_add_when_logical_id_previously_added_raises_value_error(self): with self.assertRaises(ValueError): deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add('1', {'Type': 'Canary'}) - deployment_preference_collection.add('1', {'Type': 'Linear'}) + deployment_preference_collection.add("1", {"Type": "Canary"}) + deployment_preference_collection.add("1", {"Type": "Linear"}) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def 
test_codedeploy_application(self): expected_codedeploy_application_resource = CodeDeployApplication(CODEDEPLOY_APPLICATION_LOGICAL_ID) - expected_codedeploy_application_resource.ComputePlatform = 'Lambda' + expected_codedeploy_application_resource.ComputePlatform = "Lambda" - self.assertEqual(DeploymentPreferenceCollection().codedeploy_application.to_dict(), - expected_codedeploy_application_resource.to_dict()) + self.assertEqual( + DeploymentPreferenceCollection().codedeploy_application.to_dict(), + expected_codedeploy_application_resource.to_dict(), + ) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_codedeploy_iam_role(self): - expected_codedeploy_iam_role = IAMRole('CodeDeployServiceRole') + expected_codedeploy_iam_role = IAMRole("CodeDeployServiceRole") expected_codedeploy_iam_role.AssumeRolePolicyDocument = { - 'Version': '2012-10-17', - 'Statement': [{ - 'Action': ['sts:AssumeRole'], - 'Effect': 'Allow', - 'Principal': {'Service': ['codedeploy.amazonaws.com']} - }] + "Version": "2012-10-17", + "Statement": [ + { + "Action": ["sts:AssumeRole"], + "Effect": "Allow", + "Principal": {"Service": ["codedeploy.amazonaws.com"]}, + } + ], } - expected_codedeploy_iam_role.ManagedPolicyArns = ['arn:aws:iam::aws:policy/service-role/AWSCodeDeployRoleForLambda'] + expected_codedeploy_iam_role.ManagedPolicyArns = [ + "arn:aws:iam::aws:policy/service-role/AWSCodeDeployRoleForLambda" + ] - self.assertEqual(DeploymentPreferenceCollection().codedeploy_iam_role.to_dict(), - expected_codedeploy_iam_role.to_dict()) + self.assertEqual( + DeploymentPreferenceCollection().codedeploy_iam_role.to_dict(), expected_codedeploy_iam_role.to_dict() + ) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_deployment_group_with_minimal_parameters(self): - expected_deployment_group = CodeDeployDeploymentGroup(self.function_logical_id + 'DeploymentGroup') - expected_deployment_group.ApplicationName = {'Ref': CODEDEPLOY_APPLICATION_LOGICAL_ID} - expected_deployment_group.AutoRollbackConfiguration = {'Enabled': True, - 'Events': ['DEPLOYMENT_FAILURE', - 'DEPLOYMENT_STOP_ON_ALARM', - 'DEPLOYMENT_STOP_ON_REQUEST']} + expected_deployment_group = CodeDeployDeploymentGroup(self.function_logical_id + "DeploymentGroup") + expected_deployment_group.ApplicationName = {"Ref": CODEDEPLOY_APPLICATION_LOGICAL_ID} + expected_deployment_group.AutoRollbackConfiguration = { + "Enabled": True, + "Events": ["DEPLOYMENT_FAILURE", "DEPLOYMENT_STOP_ON_ALARM", "DEPLOYMENT_STOP_ON_REQUEST"], + } expected_deployment_group.DeploymentConfigName = { - 'Fn::Sub': [ - 'CodeDeployDefault.Lambda${ConfigName}', - { - 'ConfigName': self.deployment_type_global - } - ] + "Fn::Sub": ["CodeDeployDefault.Lambda${ConfigName}", {"ConfigName": self.deployment_type_global}] } - expected_deployment_group.DeploymentStyle = {'DeploymentType': 'BLUE_GREEN', - 'DeploymentOption': 'WITH_TRAFFIC_CONTROL'} - expected_deployment_group.ServiceRoleArn = {'Fn::GetAtt': [CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID, 'Arn']} + expected_deployment_group.DeploymentStyle = { + "DeploymentType": "BLUE_GREEN", + "DeploymentOption": "WITH_TRAFFIC_CONTROL", + } + expected_deployment_group.ServiceRoleArn = {"Fn::GetAtt": [CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID, "Arn"]} deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add(self.function_logical_id, {'Type': 
self.deployment_type_global}) + deployment_preference_collection.add(self.function_logical_id, {"Type": self.deployment_type_global}) deployment_group = deployment_preference_collection.deployment_group(self.function_logical_id) - self.assertEqual(deployment_group.to_dict(), - expected_deployment_group.to_dict()) + self.assertEqual(deployment_group.to_dict(), expected_deployment_group.to_dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_deployment_preference_with_codedeploy_custom_configuration(self): deployment_type = "TestDeploymentConfiguration" deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add(self.function_logical_id, {'Type': deployment_type}) + deployment_preference_collection.add(self.function_logical_id, {"Type": deployment_type}) deployment_group = deployment_preference_collection.deployment_group(self.function_logical_id) self.assertEqual(deployment_type, deployment_group.DeploymentConfigName) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_deployment_preference_with_codedeploy_predifined_configuration(self): deployment_type = "Canary10Percent5Minutes" expected_deployment_config_name = { - 'Fn::Sub': [ - 'CodeDeployDefault.Lambda${ConfigName}', - { - 'ConfigName': deployment_type - } - ] + "Fn::Sub": ["CodeDeployDefault.Lambda${ConfigName}", {"ConfigName": deployment_type}] } deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add(self.function_logical_id, {'Type': deployment_type}) + deployment_preference_collection.add(self.function_logical_id, {"Type": deployment_type}) deployment_group = deployment_preference_collection.deployment_group(self.function_logical_id) - print(deployment_group.DeploymentConfigName) + print (deployment_group.DeploymentConfigName) self.assertEqual(expected_deployment_config_name, deployment_group.DeploymentConfigName) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_deployment_preference_with_conditional_custom_configuration(self): - deployment_type = {'Fn::If': ['IsDevEnv', {'Fn::If': - ['IsDevEnv1', 'AllAtOnce', 'TestDeploymentConfiguration']}, - 'Canary10Percent15Minutes']} - - expected_deployment_config_name = {'Fn::If': - ['IsDevEnv', {'Fn::If': - ['IsDevEnv1', {'Fn::Sub': [ - 'CodeDeployDefault.Lambda${ConfigName}', - { - 'ConfigName': 'AllAtOnce' - } - ] - }, - 'TestDeploymentConfiguration']}, - {'Fn::Sub': [ - 'CodeDeployDefault.Lambda${ConfigName}', - { - 'ConfigName': 'Canary10Percent15Minutes' - } - ] - } - ] - } + deployment_type = { + "Fn::If": [ + "IsDevEnv", + {"Fn::If": ["IsDevEnv1", "AllAtOnce", "TestDeploymentConfiguration"]}, + "Canary10Percent15Minutes", + ] + } + + expected_deployment_config_name = { + "Fn::If": [ + "IsDevEnv", + { + "Fn::If": [ + "IsDevEnv1", + {"Fn::Sub": ["CodeDeployDefault.Lambda${ConfigName}", {"ConfigName": "AllAtOnce"}]}, + "TestDeploymentConfiguration", + ] + }, + {"Fn::Sub": ["CodeDeployDefault.Lambda${ConfigName}", {"ConfigName": "Canary10Percent15Minutes"}]}, + ] + } deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add(self.function_logical_id, {'Type': deployment_type}) + deployment_preference_collection.add(self.function_logical_id, {"Type": deployment_type}) 
deployment_group = deployment_preference_collection.deployment_group(self.function_logical_id) - print(deployment_group.DeploymentConfigName) + print (deployment_group.DeploymentConfigName) self.assertEqual(expected_deployment_config_name, deployment_group.DeploymentConfigName) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_deployment_group_with_all_parameters(self): - expected_deployment_group = CodeDeployDeploymentGroup(self.function_logical_id + 'DeploymentGroup') - expected_deployment_group.AlarmConfiguration = {'Enabled': True, - 'Alarms': [{'Name': {'Ref': 'alarm1'}}, - {'Name': {'Ref': 'alarm2'}}]} - expected_deployment_group.ApplicationName = {'Ref': CODEDEPLOY_APPLICATION_LOGICAL_ID} - expected_deployment_group.AutoRollbackConfiguration = {'Enabled': True, - 'Events': ['DEPLOYMENT_FAILURE', - 'DEPLOYMENT_STOP_ON_ALARM', - 'DEPLOYMENT_STOP_ON_REQUEST']} + expected_deployment_group = CodeDeployDeploymentGroup(self.function_logical_id + "DeploymentGroup") + expected_deployment_group.AlarmConfiguration = { + "Enabled": True, + "Alarms": [{"Name": {"Ref": "alarm1"}}, {"Name": {"Ref": "alarm2"}}], + } + expected_deployment_group.ApplicationName = {"Ref": CODEDEPLOY_APPLICATION_LOGICAL_ID} + expected_deployment_group.AutoRollbackConfiguration = { + "Enabled": True, + "Events": ["DEPLOYMENT_FAILURE", "DEPLOYMENT_STOP_ON_ALARM", "DEPLOYMENT_STOP_ON_REQUEST"], + } expected_deployment_group.DeploymentConfigName = { - 'Fn::Sub': [ - 'CodeDeployDefault.Lambda${ConfigName}', - { - 'ConfigName': self.deployment_type_global - } - ] + "Fn::Sub": ["CodeDeployDefault.Lambda${ConfigName}", {"ConfigName": self.deployment_type_global}] + } + expected_deployment_group.DeploymentStyle = { + "DeploymentType": "BLUE_GREEN", + "DeploymentOption": "WITH_TRAFFIC_CONTROL", } - expected_deployment_group.DeploymentStyle = {'DeploymentType': 'BLUE_GREEN', - 'DeploymentOption': 'WITH_TRAFFIC_CONTROL'} - expected_deployment_group.ServiceRoleArn = {'Fn::GetAtt': [CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID, 'Arn']} + expected_deployment_group.ServiceRoleArn = {"Fn::GetAtt": [CODE_DEPLOY_SERVICE_ROLE_LOGICAL_ID, "Arn"]} deployment_preference_collection = DeploymentPreferenceCollection() deployment_preference_collection.add(self.function_logical_id, self.global_deployment_preference_yaml_dict()) deployment_group = deployment_preference_collection.deployment_group(self.function_logical_id) - self.assertEqual(deployment_group.to_dict(), - expected_deployment_group.to_dict()) + self.assertEqual(deployment_group.to_dict(), expected_deployment_group.to_dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_update_policy_with_minimal_parameters(self): expected_update_policy = { - 'CodeDeployLambdaAliasUpdate': { - 'ApplicationName': {'Ref': CODEDEPLOY_APPLICATION_LOGICAL_ID}, - 'DeploymentGroupName': {'Ref': self.function_logical_id + 'DeploymentGroup'}, + "CodeDeployLambdaAliasUpdate": { + "ApplicationName": {"Ref": CODEDEPLOY_APPLICATION_LOGICAL_ID}, + "DeploymentGroupName": {"Ref": self.function_logical_id + "DeploymentGroup"}, } } deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add(self.function_logical_id, {'Type': 'CANARY'}) + deployment_preference_collection.add(self.function_logical_id, {"Type": "CANARY"}) update_policy = deployment_preference_collection.update_policy(self.function_logical_id) 
self.assertEqual(expected_update_policy, update_policy.to_dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_update_policy_with_all_parameters(self): expected_update_polcy = { - 'CodeDeployLambdaAliasUpdate': { - 'ApplicationName': {'Ref': CODEDEPLOY_APPLICATION_LOGICAL_ID}, - 'DeploymentGroupName': {'Ref': self.function_logical_id + 'DeploymentGroup'}, - 'BeforeAllowTrafficHook': self.pre_traffic_hook_global, - 'AfterAllowTrafficHook': self.post_traffic_host_global, + "CodeDeployLambdaAliasUpdate": { + "ApplicationName": {"Ref": CODEDEPLOY_APPLICATION_LOGICAL_ID}, + "DeploymentGroupName": {"Ref": self.function_logical_id + "DeploymentGroup"}, + "BeforeAllowTrafficHook": self.pre_traffic_hook_global, + "AfterAllowTrafficHook": self.post_traffic_host_global, } } @@ -202,56 +195,61 @@ def test_update_policy_with_all_parameters(self): self.assertEqual(expected_update_polcy, update_policy.to_dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_any_enabled_true_if_one_of_three_enabled(self): deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add('1', {'Type': 'LINEAR'}) - deployment_preference_collection.add('2', {'Type': 'LINEAR', 'Enabled': False}) - deployment_preference_collection.add('3', {'Type': 'CANARY', 'Enabled': False}) + deployment_preference_collection.add("1", {"Type": "LINEAR"}) + deployment_preference_collection.add("2", {"Type": "LINEAR", "Enabled": False}) + deployment_preference_collection.add("3", {"Type": "CANARY", "Enabled": False}) self.assertTrue(deployment_preference_collection.any_enabled()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_any_enabled_true_if_all_of_three_enabled(self): deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add('1', {'Type': 'LINEAR'}) - deployment_preference_collection.add('2', {'Type': 'LINEAR', 'Enabled': True}) - deployment_preference_collection.add('3', {'Type': 'CANARY', 'Enabled': True}) + deployment_preference_collection.add("1", {"Type": "LINEAR"}) + deployment_preference_collection.add("2", {"Type": "LINEAR", "Enabled": True}) + deployment_preference_collection.add("3", {"Type": "CANARY", "Enabled": True}) self.assertTrue(deployment_preference_collection.any_enabled()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_any_enabled_false_if_all_of_three_disabled(self): deployment_preference_collection = DeploymentPreferenceCollection() - deployment_preference_collection.add('1', {'Type': 'Linear', 'Enabled': False}) - deployment_preference_collection.add('2', {'Type': 'LINEAR', 'Enabled': False}) - deployment_preference_collection.add('3', {'Type': 'CANARY', 'Enabled': False}) + deployment_preference_collection.add("1", {"Type": "Linear", "Enabled": False}) + deployment_preference_collection.add("2", {"Type": "LINEAR", "Enabled": False}) + deployment_preference_collection.add("3", {"Type": "CANARY", "Enabled": False}) self.assertFalse(deployment_preference_collection.any_enabled()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def 
test_enabled_logical_ids_returns_one_if_one_of_three_enabled(self): deployment_preference_collection = DeploymentPreferenceCollection() - enabled_logical_id = '1' - deployment_preference_collection.add(enabled_logical_id, {'Type': 'LINEAR'}) - deployment_preference_collection.add('2', {'Type': 'LINEAR', 'Enabled': False}) - deployment_preference_collection.add('3', {'Type': 'CANARY', 'Enabled': False}) + enabled_logical_id = "1" + deployment_preference_collection.add(enabled_logical_id, {"Type": "LINEAR"}) + deployment_preference_collection.add("2", {"Type": "LINEAR", "Enabled": False}) + deployment_preference_collection.add("3", {"Type": "CANARY", "Enabled": False}) enabled_logical_ids = deployment_preference_collection.enabled_logical_ids() self.assertEqual(1, len(enabled_logical_ids)) - self.assertEqual(enabled_logical_id, - enabled_logical_ids[0]) + self.assertEqual(enabled_logical_id, enabled_logical_ids[0]) def global_deployment_preference_yaml_dict(self): deployment_preference_yaml_dict = dict() - deployment_preference_yaml_dict['Type'] = self.deployment_type_global - deployment_preference_yaml_dict['Hooks'] = {'PreTraffic': self.pre_traffic_hook_global, - 'PostTraffic': self.post_traffic_host_global} - deployment_preference_yaml_dict['Alarms'] = self.alarms_global + deployment_preference_yaml_dict["Type"] = self.deployment_type_global + deployment_preference_yaml_dict["Hooks"] = { + "PreTraffic": self.pre_traffic_hook_global, + "PostTraffic": self.post_traffic_host_global, + } + deployment_preference_yaml_dict["Alarms"] = self.alarms_global return deployment_preference_yaml_dict def global_deployment_preference(self): - expected_deployment_preference = DeploymentPreference(self.deployment_type_global, - self.pre_traffic_hook_global, - self.post_traffic_host_global, self.alarms_global, True) + expected_deployment_preference = DeploymentPreference( + self.deployment_type_global, + self.pre_traffic_hook_global, + self.post_traffic_host_global, + self.alarms_global, + True, + ) return expected_deployment_preference diff --git a/tests/translator/model/test_update_policy.py b/tests/translator/model/test_update_policy.py index 113b0ddff4..f2ad9362aa 100644 --- a/tests/translator/model/test_update_policy.py +++ b/tests/translator/model/test_update_policy.py @@ -6,27 +6,26 @@ class TestUpdatePolicy(TestCase): def test_to_dict_only_application_and_deployment_group(self): expected_dict = { - 'CodeDeployLambdaAliasUpdate': { - 'ApplicationName': 'application_name', - 'DeploymentGroupName': 'deployment_group_name' + "CodeDeployLambdaAliasUpdate": { + "ApplicationName": "application_name", + "DeploymentGroupName": "deployment_group_name", } } - update_policy_dict = UpdatePolicy('application_name', 'deployment_group_name', None, None).to_dict() + update_policy_dict = UpdatePolicy("application_name", "deployment_group_name", None, None).to_dict() self.assertEqual(expected_dict, update_policy_dict) def test_to_dict_all_fields(self): expected_dict = { - 'CodeDeployLambdaAliasUpdate': { - 'ApplicationName': 'application_name', - 'DeploymentGroupName': 'deployment_group_name', - 'BeforeAllowTrafficHook': 'before_allow_traffic_hook', - 'AfterAllowTrafficHook': 'after_allow_traffic_hook' + "CodeDeployLambdaAliasUpdate": { + "ApplicationName": "application_name", + "DeploymentGroupName": "deployment_group_name", + "BeforeAllowTrafficHook": "before_allow_traffic_hook", + "AfterAllowTrafficHook": "after_allow_traffic_hook", } } - update_policy_dict = UpdatePolicy('application_name', - 
'deployment_group_name', - 'before_allow_traffic_hook', - 'after_allow_traffic_hook').to_dict() + update_policy_dict = UpdatePolicy( + "application_name", "deployment_group_name", "before_allow_traffic_hook", "after_allow_traffic_hook" + ).to_dict() self.assertEqual(expected_dict, update_policy_dict) diff --git a/tests/translator/test_api_resource.py b/tests/translator/test_api_resource.py index 1c77c07220..fcf2c3e646 100644 --- a/tests/translator/test_api_resource.py +++ b/tests/translator/test_api_resource.py @@ -10,28 +10,26 @@ mock_policy_loader = MagicMock() mock_policy_loader.load.return_value = { - 'AmazonDynamoDBFullAccess': 'arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess', - 'AmazonDynamoDBReadOnlyAccess': 'arn:aws:iam::aws:policy/AmazonDynamoDBReadOnlyAccess', - 'AWSLambdaRole': 'arn:aws:iam::aws:policy/service-role/AWSLambdaRole', + "AmazonDynamoDBFullAccess": "arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess", + "AmazonDynamoDBReadOnlyAccess": "arn:aws:iam::aws:policy/AmazonDynamoDBReadOnlyAccess", + "AWSLambdaRole": "arn:aws:iam::aws:policy/service-role/AWSLambdaRole", } -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_redeploy_explicit_api(): """ Test to verify that we will redeploy an API when Swagger document changes :return: """ manifest = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'ExplicitApi': { - 'Type': "AWS::Serverless::Api", - "Properties": { - "StageName": "prod", - "DefinitionUri": "s3://mybucket/swagger.json?versionId=123" - } + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "ExplicitApi": { + "Type": "AWS::Serverless::Api", + "Properties": {"StageName": "prod", "DefinitionUri": "s3://mybucket/swagger.json?versionId=123"}, } - } + }, } original_deployment_ids = translate_and_find_deployment_ids(manifest) @@ -46,46 +44,30 @@ def test_redeploy_explicit_api(): assert updated_deployment_ids == translate_and_find_deployment_ids(manifest) -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_redeploy_implicit_api(): manifest = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'FirstLambdaFunction': { - 'Type': "AWS::Serverless::Function", + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "FirstLambdaFunction": { + "Type": "AWS::Serverless::Function", "Properties": { "CodeUri": "s3://bucket/code.zip", "Handler": "index.handler", "Runtime": "nodejs4.3", - "Events": { - "MyApi": { - "Type": "Api", - "Properties": { - "Path": "/first", - "Method": "get" - } - } - } - } + "Events": {"MyApi": {"Type": "Api", "Properties": {"Path": "/first", "Method": "get"}}}, + }, }, - 'SecondLambdaFunction': { - 'Type': "AWS::Serverless::Function", + "SecondLambdaFunction": { + "Type": "AWS::Serverless::Function", "Properties": { "CodeUri": "s3://bucket/code.zip", "Handler": "index.handler", "Runtime": "nodejs4.3", - "Events": { - "MyApi": { - "Type": "Api", - "Properties": { - "Path": "/second", - "Method": "get" - } - } - } - } - } - } + "Events": {"MyApi": {"Type": "Api", "Properties": {"Path": "/second", "Method": "get"}}}, + }, + }, + }, } original_deployment_ids = translate_and_find_deployment_ids(manifest) @@ -105,11 +87,11 @@ def test_redeploy_implicit_api(): assert second_updated_deployment_ids == translate_and_find_deployment_ids(manifest) 
-@patch('boto3.session.Session.region_name', 'ap-southeast-1') +@patch("boto3.session.Session.region_name", "ap-southeast-1") def translate_and_find_deployment_ids(manifest): parameter_values = get_template_parameter_values() output_fragment = transform(manifest, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) deployment_ids = set() for key, value in output_fragment["Resources"].items(): @@ -120,7 +102,6 @@ def translate_and_find_deployment_ids(manifest): class TestApiGatewayDeploymentResource(TestCase): - @patch("samtranslator.translator.logical_id_generator.LogicalIdGenerator") def test_make_auto_deployable_with_swagger_dict(self, LogicalIdGeneratorMock): prefix = "prefix" @@ -140,7 +121,7 @@ def test_make_auto_deployable_with_swagger_dict(self, LogicalIdGeneratorMock): LogicalIdGeneratorMock.assert_called_once_with(prefix, str(swagger)) generator_mock.gen.assert_called_once_with() - generator_mock.get_hash.assert_called_once_with(length=40) # getting full SHA + generator_mock.get_hash.assert_called_once_with(length=40) # getting full SHA stage.update_deployment_ref.assert_called_once_with(id_val) @patch("samtranslator.translator.logical_id_generator.LogicalIdGenerator") diff --git a/tests/translator/test_function_resources.py b/tests/translator/test_function_resources.py index 3c42a8e62b..11a91e72ef 100644 --- a/tests/translator/test_function_resources.py +++ b/tests/translator/test_function_resources.py @@ -9,8 +9,6 @@ class TestVersionsAndAliases(TestCase): - - def setUp(self): self.intrinsics_resolver_mock = Mock() @@ -21,17 +19,13 @@ def setUp(self): self.code_uri = "s3://bucket/key?versionId=version" self.func_dict = { "Type": "AWS::Serverless::Function", - "Properties": { - "CodeUri": self.code_uri, - "Runtime": "nodejs4.3", - "Handler": "index.handler" - } + "Properties": {"CodeUri": self.code_uri, "Runtime": "nodejs4.3", "Handler": "index.handler"}, } self.sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=self.func_dict) self.lambda_func = self._make_lambda_function(self.sam_func.logical_id) self.lambda_version = self._make_lambda_version("VersionLogicalId", self.sam_func) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") def test_sam_function_with_alias(self, get_resolved_alias_name_mock): alias_name = "AliasName" @@ -41,8 +35,8 @@ def test_sam_function_with_alias(self, get_resolved_alias_name_mock): "CodeUri": self.code_uri, "Runtime": "nodejs4.3", "Handler": "index.handler", - "AutoPublishAlias": alias_name - } + "AutoPublishAlias": alias_name, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -51,7 +45,11 @@ def test_sam_function_with_alias(self, get_resolved_alias_name_mock): kwargs["managed_policy_map"] = {"a": "b"} kwargs["event_resources"] = [] kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock - self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = {"S3Bucket": "bucket", "S3Key": "key", "S3ObjectVersion": "version"} + self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = { + "S3Bucket": "bucket", + "S3Key": "key", + "S3ObjectVersion": "version", + } get_resolved_alias_name_mock.return_value = alias_name resources = sam_func.to_cloudformation(**kwargs) @@ -69,7 +67,9 @@ def test_sam_function_with_alias(self, get_resolved_alias_name_mock): # We don't need to do any deeper 
validation here because there is a separate SAM template -> CFN template conversion test # that will care of validating all properties & connections - sam_func._get_resolved_alias_name.assert_called_once_with("AutoPublishAlias", alias_name, self.intrinsics_resolver_mock) + sam_func._get_resolved_alias_name.assert_called_once_with( + "AutoPublishAlias", alias_name, self.intrinsics_resolver_mock + ) def test_sam_function_with_alias_cannot_be_list(self): @@ -78,7 +78,7 @@ def test_sam_function_with_alias_cannot_be_list(self): self.func_dict["Properties"]["AutoPublishAlias"] = ["a", "b"] SamFunction.from_dict(logical_id="foo", resource_dict=self.func_dict) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") def test_sam_function_with_deployment_preference(self, get_resolved_alias_name_mock): deploy_preference_dict = {"Type": "LINEAR"} @@ -90,8 +90,8 @@ def test_sam_function_with_deployment_preference(self, get_resolved_alias_name_m "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -102,16 +102,18 @@ def test_sam_function_with_deployment_preference(self, get_resolved_alias_name_m kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock deployment_preference_collection = self._make_deployment_preference_collection() - kwargs['deployment_preference_collection'] = deployment_preference_collection + kwargs["deployment_preference_collection"] = deployment_preference_collection get_resolved_alias_name_mock.return_value = alias_name - self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = {"S3Bucket": "bucket", "S3Key": "key", - "S3ObjectVersion": "version"} + self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = { + "S3Bucket": "bucket", + "S3Key": "key", + "S3ObjectVersion": "version", + } resources = sam_func.to_cloudformation(**kwargs) deployment_preference_collection.update_policy.assert_called_once_with(self.sam_func.logical_id) - deployment_preference_collection.add.assert_called_once_with(self.sam_func.logical_id, - deploy_preference_dict) + deployment_preference_collection.add.assert_called_once_with(self.sam_func.logical_id, deploy_preference_dict) aliases = [r.to_dict() for r in resources if r.resource_type == LambdaAlias.resource_type] @@ -119,7 +121,9 @@ def test_sam_function_with_deployment_preference(self, get_resolved_alias_name_m self.assertEqual(list(aliases[0].values())[0]["UpdatePolicy"], self.update_policy().to_dict()) @patch.object(SamFunction, "_get_resolved_alias_name") - def test_sam_function_with_deployment_preference_missing_collection_raises_error(self, get_resolved_alias_name_mock): + def test_sam_function_with_deployment_preference_missing_collection_raises_error( + self, get_resolved_alias_name_mock + ): alias_name = "AliasName" deploy_preference_dict = {"Type": "LINEAR"} func = { @@ -129,8 +133,8 @@ def test_sam_function_with_deployment_preference_missing_collection_raises_error "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) 
@@ -140,16 +144,21 @@ def test_sam_function_with_deployment_preference_missing_collection_raises_error kwargs["event_resources"] = [] kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock - self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = {"S3Bucket": "bucket", "S3Key": "key", - "S3ObjectVersion": "version"} + self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = { + "S3Bucket": "bucket", + "S3Key": "key", + "S3ObjectVersion": "version", + } get_resolved_alias_name_mock.return_value = alias_name with self.assertRaises(ValueError): sam_func.to_cloudformation(**kwargs) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") - def test_sam_function_with_disabled_deployment_preference_does_not_add_update_policy(self, get_resolved_alias_name_mock): + def test_sam_function_with_disabled_deployment_preference_does_not_add_update_policy( + self, get_resolved_alias_name_mock + ): alias_name = "AliasName" enabled = False deploy_preference_dict = {"Enabled": enabled} @@ -160,8 +169,8 @@ def test_sam_function_with_disabled_deployment_preference_does_not_add_update_po "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -172,10 +181,11 @@ def test_sam_function_with_disabled_deployment_preference_does_not_add_update_po kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock preference_collection = self._make_deployment_preference_collection() - preference_collection.get.return_value = DeploymentPreference.from_dict(sam_func.logical_id, - deploy_preference_dict) + preference_collection.get.return_value = DeploymentPreference.from_dict( + sam_func.logical_id, deploy_preference_dict + ) - kwargs['deployment_preference_collection'] = preference_collection + kwargs["deployment_preference_collection"] = preference_collection self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = enabled get_resolved_alias_name_mock.return_value = alias_name @@ -196,8 +206,8 @@ def test_sam_function_cannot_be_with_deployment_preference_without_alias(self): "CodeUri": self.code_uri, "Runtime": "nodejs4.3", "Handler": "index.handler", - "DeploymentPreference": {"Type": "LINEAR"} - } + "DeploymentPreference": {"Type": "LINEAR"}, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -205,10 +215,10 @@ def test_sam_function_cannot_be_with_deployment_preference_without_alias(self): kwargs = dict() kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock - kwargs['deployment_preference_collection'] = self._make_deployment_preference_collection() + kwargs["deployment_preference_collection"] = self._make_deployment_preference_collection() sam_func.to_cloudformation(**kwargs) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") def test_sam_function_without_alias_allows_disabled_deployment_preference(self): enabled = False deploy_preference_dict = {"Enabled": enabled} @@ -218,8 +228,8 @@ def test_sam_function_without_alias_allows_disabled_deployment_preference(self): 
"CodeUri": self.code_uri, "Runtime": "nodejs4.3", "Handler": "index.handler", - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -231,10 +241,11 @@ def test_sam_function_without_alias_allows_disabled_deployment_preference(self): kwargs["mappings_resolver"] = self.mappings_resolver_mock preference_collection = self._make_deployment_preference_collection() - preference_collection.get.return_value = DeploymentPreference.from_dict(sam_func.logical_id, - deploy_preference_dict) + preference_collection.get.return_value = DeploymentPreference.from_dict( + sam_func.logical_id, deploy_preference_dict + ) - kwargs['deployment_preference_collection'] = preference_collection + kwargs["deployment_preference_collection"] = preference_collection self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = enabled resources = sam_func.to_cloudformation(**kwargs) @@ -242,9 +253,11 @@ def test_sam_function_without_alias_allows_disabled_deployment_preference(self): # Function, IAM Role self.assertEqual(len(resources), 2) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") - def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_boolean_parameter(self, get_resolved_alias_name_mock): + def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_boolean_parameter( + self, get_resolved_alias_name_mock + ): alias_name = "AliasName" enabled = {"Ref": "MyEnabledFlag"} deploy_preference_dict = {"Type": "LINEAR", "Enabled": enabled} @@ -255,8 +268,8 @@ def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_boolean_p "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -267,15 +280,14 @@ def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_boolean_p kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock deployment_preference_collection = self._make_deployment_preference_collection() - kwargs['deployment_preference_collection'] = deployment_preference_collection + kwargs["deployment_preference_collection"] = deployment_preference_collection self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = True get_resolved_alias_name_mock.return_value = alias_name resources = sam.to_cloudformation(**kwargs) deployment_preference_collection.update_policy.assert_called_once_with(self.sam_func.logical_id) - deployment_preference_collection.add.assert_called_once_with(self.sam_func.logical_id, - deploy_preference_dict) + deployment_preference_collection.add.assert_called_once_with(self.sam_func.logical_id, deploy_preference_dict) self.intrinsics_resolver_mock.resolve_parameter_refs.assert_any_call(enabled) aliases = [r.to_dict() for r in resources if r.resource_type == LambdaAlias.resource_type] @@ -283,9 +295,11 @@ def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_boolean_p self.assertTrue("UpdatePolicy" in list(aliases[0].values())[0]) self.assertEqual(list(aliases[0].values())[0]["UpdatePolicy"], self.update_policy().to_dict()) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + 
@patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") - def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_dict_parameter(self, get_resolved_alias_name_mock): + def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_dict_parameter( + self, get_resolved_alias_name_mock + ): alias_name = "AliasName" enabled = {"Ref": "MyEnabledFlag"} deploy_preference_dict = {"Type": "LINEAR", "Enabled": enabled} @@ -296,8 +310,8 @@ def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_dict_para "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -308,16 +322,18 @@ def test_sam_function_with_deployment_preference_intrinsic_ref_enabled_dict_para kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock deployment_preference_collection = self._make_deployment_preference_collection() - kwargs['deployment_preference_collection'] = deployment_preference_collection + kwargs["deployment_preference_collection"] = deployment_preference_collection self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = {"MyEnabledFlag": True} get_resolved_alias_name_mock.return_value = alias_name sam_func.to_cloudformation(**kwargs) - self.assertTrue(sam_func.DeploymentPreference['Enabled']) + self.assertTrue(sam_func.DeploymentPreference["Enabled"]) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') + @patch("boto3.session.Session.region_name", "ap-southeast-1") @patch.object(SamFunction, "_get_resolved_alias_name") - def test_sam_function_with_deployment_preference_intrinsic_findinmap_enabled_dict_parameter(self, get_resolved_alias_name_mock): + def test_sam_function_with_deployment_preference_intrinsic_findinmap_enabled_dict_parameter( + self, get_resolved_alias_name_mock + ): alias_name = "AliasName" enabled = {"Fn::FindInMap": ["FooMap", "FooKey", "Enabled"]} deploy_preference_dict = {"Type": "LINEAR", "Enabled": enabled} @@ -328,8 +344,8 @@ def test_sam_function_with_deployment_preference_intrinsic_findinmap_enabled_dic "Runtime": "nodejs4.3", "Handler": "index.handler", "AutoPublishAlias": alias_name, - "DeploymentPreference": deploy_preference_dict - } + "DeploymentPreference": deploy_preference_dict, + }, } sam_func = SamFunction.from_dict(logical_id="foo", resource_dict=func) @@ -340,13 +356,13 @@ def test_sam_function_with_deployment_preference_intrinsic_findinmap_enabled_dic kwargs["intrinsics_resolver"] = self.intrinsics_resolver_mock kwargs["mappings_resolver"] = self.mappings_resolver_mock deployment_preference_collection = self._make_deployment_preference_collection() - kwargs['deployment_preference_collection'] = deployment_preference_collection + kwargs["deployment_preference_collection"] = deployment_preference_collection self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = {"MyEnabledFlag": True} self.mappings_resolver_mock.resolve_parameter_refs.return_value = True get_resolved_alias_name_mock.return_value = alias_name sam_func.to_cloudformation(**kwargs) - self.assertTrue(sam_func.DeploymentPreference['Enabled']) + self.assertTrue(sam_func.DeploymentPreference["Enabled"]) @patch("samtranslator.translator.logical_id_generator.LogicalIdGenerator") def test_version_creation(self, 
LogicalIdGeneratorMock): @@ -398,11 +414,7 @@ def test_version_creation_intrinsic_function_in_code_s3key(self, LogicalIdGenera id_val = "SomeLogicalId" generator_mock.gen.return_value = id_val - self.lambda_func.Code = { - "S3Bucket": "bucket", - "S3Key": {"Ref": "keyparameter"}, - "S3ObjectVersion": "version" - } + self.lambda_func.Code = {"S3Bucket": "bucket", "S3Key": {"Ref": "keyparameter"}, "S3ObjectVersion": "version"} self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = self.lambda_func.Code version = self.sam_func._construct_version(self.lambda_func, self.intrinsics_resolver_mock) @@ -418,11 +430,7 @@ def test_version_creation_intrinsic_function_in_code_s3bucket(self, LogicalIdGen id_val = "SomeLogicalId" generator_mock.gen.return_value = id_val - self.lambda_func.Code = { - "S3Bucket": {"Ref": "bucketparameter"}, - "S3Key": "key", - "S3ObjectVersion": "version" - } + self.lambda_func.Code = {"S3Bucket": {"Ref": "bucketparameter"}, "S3Key": "key", "S3ObjectVersion": "version"} self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = self.lambda_func.Code version = self.sam_func._construct_version(self.lambda_func, self.intrinsics_resolver_mock) @@ -438,11 +446,7 @@ def test_version_creation_intrinsic_function_in_code_s3version(self, LogicalIdGe id_val = "SomeLogicalId" generator_mock.gen.return_value = id_val - self.lambda_func.Code = { - "S3Bucket": "bucket", - "S3Key": "key", - "S3ObjectVersion": {"Ref": "versionparameter"} - } + self.lambda_func.Code = {"S3Bucket": "bucket", "S3Key": "key", "S3ObjectVersion": {"Ref": "versionparameter"}} self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = self.lambda_func.Code version = self.sam_func._construct_version(self.lambda_func, self.intrinsics_resolver_mock) @@ -481,12 +485,7 @@ def test_version_logical_id_changes_with_intrinsic_functions(self, LogicalIdGene generator_mock.gen.return_value = id_val prefix = self.sam_func.logical_id + "Version" - self.lambda_func.Code = { - "S3Bucket": "bucket", - "S3Key": { - "Ref": "someparam" - } - } + self.lambda_func.Code = {"S3Bucket": "bucket", "S3Key": {"Ref": "someparam"}} self.intrinsics_resolver_mock.resolve_parameter_refs.return_value = self.lambda_func.Code self.sam_func._construct_version(self.lambda_func, self.intrinsics_resolver_mock) @@ -512,7 +511,6 @@ def test_alias_creation(self): self.assertEqual(alias.FunctionName, {"Ref": self.lambda_func.logical_id}) self.assertEqual(alias.FunctionVersion, {"Fn::GetAtt": [self.lambda_version.logical_id, "Version"]}) - def test_alias_creation_error(self): with self.assertRaises(InvalidResourceException): self.sam_func._construct_alias(None, self.lambda_func, self.lambda_version) @@ -530,8 +528,9 @@ def test_get_resolved_alias_name_must_work(self): def test_get_resolved_alias_name_must_error_if_intrinsics_are_not_resolved(self): property_name = "something" - expected_exception_msg = "Resource with id [{}] is invalid. '{}' must be a string or a Ref to a template parameter"\ - .format(self.sam_func.logical_id, property_name) + expected_exception_msg = "Resource with id [{}] is invalid. 
'{}' must be a string or a Ref to a template parameter".format( + self.sam_func.logical_id, property_name + ) alias_value = {"Ref": "param1"} # Unresolved @@ -546,8 +545,9 @@ def test_get_resolved_alias_name_must_error_if_intrinsics_are_not_resolved(self) def test_get_resolved_alias_name_must_error_if_intrinsics_are_not_resolved_with_list(self): property_name = "something" - expected_exception_msg = "Resource with id [{}] is invalid. '{}' must be a string or a Ref to a template parameter"\ - .format(self.sam_func.logical_id, property_name) + expected_exception_msg = "Resource with id [{}] is invalid. '{}' must be a string or a Ref to a template parameter".format( + self.sam_func.logical_id, property_name + ) alias_value = ["Ref", "param1"] # Unresolved @@ -561,11 +561,7 @@ def test_get_resolved_alias_name_must_error_if_intrinsics_are_not_resolved_with_ def _make_lambda_function(self, logical_id): func = LambdaFunction(logical_id) - func.Code = { - "S3Bucket": "bucket", - "S3Key": "key", - "S3ObjectVersion": "version" - } + func.Code = {"S3Bucket": "bucket", "S3Key": "key", "S3ObjectVersion": "version"} return func def _make_lambda_version(self, logical_id, func): @@ -584,7 +580,6 @@ def _make_deployment_preference_collection(self): class TestSupportedResourceReferences(TestCase): - def test_must_not_break_support(self): func = SamFunction("LogicalId") diff --git a/tests/translator/test_logical_id_generator.py b/tests/translator/test_logical_id_generator.py index f4356642a4..7205bc6e08 100644 --- a/tests/translator/test_logical_id_generator.py +++ b/tests/translator/test_logical_id_generator.py @@ -5,6 +5,7 @@ from mock import patch from samtranslator.translator.logical_id_generator import LogicalIdGenerator + class TestLogicalIdGenerator(TestCase): """ Test the implementation of LogicalIDGenerator @@ -29,7 +30,7 @@ def test_gen_no_data(self, stringify_mock): @patch.object(LogicalIdGenerator, "_stringify") def test_gen_dict_data(self, stringify_mock, get_hash_mock): data = {"foo": "bar"} - stringified_data = "stringified data" + stringified_data = "stringified data" hash_value = "some hash value" get_hash_mock.return_value = hash_value stringify_mock.return_value = stringified_data @@ -48,7 +49,7 @@ def test_gen_stability_with_copy(self): generator = LogicalIdGenerator(self.prefix, data_obj=data) old = generator.gen() - new = LogicalIdGenerator(self.prefix, data_obj=data.copy()).gen() # Create a copy of data obj + new = LogicalIdGenerator(self.prefix, data_obj=data.copy()).gen() # Create a copy of data obj self.assertEqual(old, new) def test_gen_stability_with_different_dict_ordering(self): @@ -62,7 +63,7 @@ def test_gen_stability_with_different_dict_ordering(self): def test_gen_changes_on_different_dict_data(self): data = {"foo": "bar", "nested": {"a": "b", "c": "d"}} - data_other = {"foo2": "bar", "nested": {"a": "b", "c": "d"}} # Just changing one key + data_other = {"foo2": "bar", "nested": {"a": "b", "c": "d"}} # Just changing one key old = LogicalIdGenerator(self.prefix, data_obj=data) new = LogicalIdGenerator(self.prefix, data_obj=data_other).gen() @@ -119,14 +120,14 @@ def testget_hash_no_data(self, stringify_mock): stringify_mock.assert_called_once_with(data) def test_stringify_basic_objects(self): - data = {"a": "b", "c": [4,3,1]} + data = {"a": "b", "c": [4, 3, 1]} expected = '{"a":"b","c":[4,3,1]}' generator = LogicalIdGenerator(self.prefix, data_obj=data) self.assertEqual(expected, generator._stringify(data)) def test_stringify_basic_objects_sorting(self): - data = {"c": 
[4,3,1], "a": "b"} + data = {"c": [4, 3, 1], "a": "b"} expected = '{"a":"b","c":[4,3,1]}' generator = LogicalIdGenerator(self.prefix, data_obj=data) @@ -146,7 +147,7 @@ def test_stringify_strings(self): # Strings should be returned unmodified ie. json dump is short circuited self.assertEqual(data, generator._stringify(data)) - @patch.object(json, 'dumps') + @patch.object(json, "dumps") def test_stringify_expectations(self, json_dumps_mock): data = ["foo"] expected = "bar" @@ -155,9 +156,9 @@ def test_stringify_expectations(self, json_dumps_mock): generator = LogicalIdGenerator(self.prefix, data_obj=data) self.assertEqual(expected, generator._stringify(data)) - json_dumps_mock.assert_called_with(data, separators=(',', ':'), sort_keys=True) + json_dumps_mock.assert_called_with(data, separators=(",", ":"), sort_keys=True) - @patch.object(json, 'dumps') + @patch.object(json, "dumps") def test_stringify_expectations_for_string(self, json_dumps_mock): data = "foo" diff --git a/tests/translator/test_managed_policies_translator.py b/tests/translator/test_managed_policies_translator.py index c442bf1948..3337af663e 100644 --- a/tests/translator/test_managed_policies_translator.py +++ b/tests/translator/test_managed_policies_translator.py @@ -3,47 +3,33 @@ def create_page(policies): - return { - 'Policies': map(lambda x: {'PolicyName': x[0], 'Arn': x[1]}, policies) - } + return {"Policies": map(lambda x: {"PolicyName": x[0], "Arn": x[1]}, policies)} + def test_load(): paginator = MagicMock() paginator.paginate.return_value = [ - create_page([ - ('Policy-1', 'Arn-1'), - ('Policy-2', 'Arn-2'), - ('Policy-3', 'Arn-3'), - ('Policy-4', 'Arn-4'), - ]), - create_page([ - ('Policy-a', 'Arn-a'), - ('Policy-b', 'Arn-b'), - ('Policy-c', 'Arn-c'), - ('Policy-d', 'Arn-d'), - ]), - create_page([ - ('Policy-final', 'Arn-final'), - ]), + create_page([("Policy-1", "Arn-1"), ("Policy-2", "Arn-2"), ("Policy-3", "Arn-3"), ("Policy-4", "Arn-4")]), + create_page([("Policy-a", "Arn-a"), ("Policy-b", "Arn-b"), ("Policy-c", "Arn-c"), ("Policy-d", "Arn-d")]), + create_page([("Policy-final", "Arn-final")]), ] - iam = MagicMock() iam.get_paginator.return_value = paginator actual = ManagedPolicyLoader(iam).load() expected = { - 'Policy-1': 'Arn-1', - 'Policy-2': 'Arn-2', - 'Policy-3': 'Arn-3', - 'Policy-4': 'Arn-4', - 'Policy-a': 'Arn-a', - 'Policy-b': 'Arn-b', - 'Policy-c': 'Arn-c', - 'Policy-d': 'Arn-d', - 'Policy-final': 'Arn-final' + "Policy-1": "Arn-1", + "Policy-2": "Arn-2", + "Policy-3": "Arn-3", + "Policy-4": "Arn-4", + "Policy-a": "Arn-a", + "Policy-b": "Arn-b", + "Policy-c": "Arn-c", + "Policy-d": "Arn-d", + "Policy-final": "Arn-final", } assert actual == expected - iam.get_paginator.assert_called_once_with('list_policies') - paginator.paginate.assert_called_once_with(Scope='AWS') + iam.get_paginator.assert_called_once_with("list_policies") + paginator.paginate.assert_called_once_with(Scope="AWS") diff --git a/tests/translator/test_translator.py b/tests/translator/test_translator.py index 6d8dca6cfd..e1480b183a 100644 --- a/tests/translator/test_translator.py +++ b/tests/translator/test_translator.py @@ -24,15 +24,15 @@ from mock import Mock, MagicMock, patch BASE_PATH = os.path.dirname(__file__) -INPUT_FOLDER = BASE_PATH + '/input' -OUTPUT_FOLDER = BASE_PATH + '/output' +INPUT_FOLDER = BASE_PATH + "/input" +OUTPUT_FOLDER = BASE_PATH + "/output" # Do not sort AWS::Serverless::Function Layers Property. # Order of Layers is an important attribute and shouldn't be changed. 
-DO_NOT_SORT = ['Layers'] +DO_NOT_SORT = ["Layers"] BASE_PATH = os.path.dirname(__file__) -INPUT_FOLDER = os.path.join(BASE_PATH, 'input') -OUTPUT_FOLDER = os.path.join(BASE_PATH, 'output') +INPUT_FOLDER = os.path.join(BASE_PATH, "input") +OUTPUT_FOLDER = os.path.join(BASE_PATH, "output") def deep_sort_lists(value): @@ -92,12 +92,13 @@ def custom_list_data_comparator(obj1, obj2): s1, s2 = type(obj1).__name__, type(obj2).__name__ return (s1 > s2) - (s1 < s2) + def mock_sar_service_call(self, service_call_function, logical_id, *args): """ Current implementation: args[0] is always the application_id """ application_id = args[0] - status = 'ACTIVE' + status = "ACTIVE" if application_id == "no-access": raise InvalidResourceException(logical_id, "Cannot access application: {}.".format(application_id)) elif application_id == "non-existent": @@ -105,7 +106,9 @@ def mock_sar_service_call(self, service_call_function, logical_id, *args): elif application_id == "invalid-semver": raise InvalidResourceException(logical_id, "Cannot access application: {}.".format(application_id)) elif application_id == 1: - raise InvalidResourceException(logical_id, "Type of property 'ApplicationId' is invalid.".format(application_id)) + raise InvalidResourceException( + logical_id, "Type of property 'ApplicationId' is invalid.".format(application_id) + ) elif application_id == "preparing" and self._wait_for_template_active_status < 2: self._wait_for_template_active_status += 1 self.SLEEP_TIME_SECONDS = 0 @@ -119,189 +122,195 @@ def mock_sar_service_call(self, service_call_function, logical_id, *args): elif application_id == "expired": status = "EXPIRED" message = { - 'ApplicationId': args[0], - 'CreationTime': 'x', - 'ExpirationTime': 'x', - 'SemanticVersion': '1.1.1', - 'Status': status, - 'TemplateId': 'id-xx-xx', - 'TemplateUrl': 'https://awsserverlessrepo-changesets-xxx.s3.amazonaws.com/signed-url' + "ApplicationId": args[0], + "CreationTime": "x", + "ExpirationTime": "x", + "SemanticVersion": "1.1.1", + "Status": status, + "TemplateId": "id-xx-xx", + "TemplateUrl": "https://awsserverlessrepo-changesets-xxx.s3.amazonaws.com/signed-url", } return message + # implicit_api, explicit_api, explicit_api_ref, api_cache tests currently have deployment IDs hardcoded in output file. # These ids are generated using sha1 hash of the swagger body for implicit # api and s3 location for explicit api. 
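# A short illustration of how the hardcoded deployment ids mentioned in the
# comment above are derived, mirroring the _generate_new_deployment_hash helper
# that appears later in this file: a SHA1 digest of the canonical JSON dump of
# the Swagger body (implicit APIs) or of the BodyS3Location (explicit APIs).
# The function name here is illustrative only, not part of the test suite.
import hashlib
import json


def deployment_id_hash(body_or_s3_location):
    # sort_keys plus compact separators make the dump stable for equal dicts,
    # so the same Swagger body always yields the same Deployment logical id suffix
    data_bytes = json.dumps(body_or_s3_location, separators=(",", ":"), sort_keys=True).encode("utf8")
    return hashlib.sha1(data_bytes).hexdigest()


# Example: changing the body changes the hash, and hence the Deployment logical id
print(deployment_id_hash({"a": "b"}))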
-class TestTranslatorEndToEnd(TestCase): +class TestTranslatorEndToEnd(TestCase): @parameterized.expand( - itertools.product([ - 'cognito_userpool_with_event', - 's3_with_condition', - 'function_with_condition', - 'basic_function', - 'basic_application', - 'application_preparing_state', - 'application_with_intrinsics', - 'basic_layer', - 'cloudwatchevent', - 'eventbridgerule', - 'eventbridgerule_schedule_properties', - 'cloudwatch_logs_with_ref', - 'cloudwatchlog', - 'streams', - 'sqs', - 'simpletable', - 'simpletable_with_sse', - 'implicit_api', - 'explicit_api', - 'api_endpoint_configuration', - 'api_with_auth_all_maximum', - 'api_with_auth_all_minimum', - 'api_with_auth_no_default', - 'api_with_auth_with_default_scopes', - 'api_with_auth_with_default_scopes_openapi', - 'api_with_default_aws_iam_auth', - 'api_with_default_aws_iam_auth_and_no_auth_route', - 'api_with_method_aws_iam_auth', - 'api_with_aws_iam_auth_overrides', - 'api_with_method_settings', - 'api_with_binary_media_types', - 'api_with_binary_media_types_definition_body', - 'api_with_minimum_compression_size', - 'api_with_resource_refs', - 'api_with_cors', - 'api_with_cors_and_auth_no_preflight_auth', - 'api_with_cors_and_auth_preflight_auth', - 'api_with_cors_and_only_methods', - 'api_with_cors_and_only_headers', - 'api_with_cors_and_only_origins', - 'api_with_cors_and_only_maxage', - 'api_with_cors_and_only_credentials_false', - 'api_with_cors_no_definitionbody', - 'api_with_incompatible_stage_name', - 'api_with_gateway_responses', - 'api_with_gateway_responses_all', - 'api_with_gateway_responses_minimal', - 'api_with_gateway_responses_implicit', - 'api_with_gateway_responses_string_status_code', - 'api_cache', - 'api_with_access_log_setting', - 'api_with_canary_setting', - 'api_with_xray_tracing', - 'api_request_model', - 'api_with_stage_tags', - 's3', - 's3_create_remove', - 's3_existing_lambda_notification_configuration', - 's3_existing_other_notification_configuration', - 's3_filter', - 's3_multiple_events_same_bucket', - 's3_multiple_functions', - 's3_with_dependsOn', - 'sns', - 'sns_sqs', - 'sns_existing_sqs', - 'sns_outside_sqs', - 'sns_existing_other_subscription', - 'sns_topic_outside_template', - 'alexa_skill', - 'alexa_skill_with_skill_id', - 'iot_rule', - 'layers_with_intrinsics', - 'layers_all_properties', - 'function_managed_inline_policy', - 'unsupported_resources', - 'intrinsic_functions', - 'basic_function_with_tags', - 'depends_on', - 'function_event_conditions', - 'function_with_dlq', - 'function_with_kmskeyarn', - 'function_with_alias', - 'function_with_alias_intrinsics', - 'function_with_custom_codedeploy_deployment_preference', - 'function_with_custom_conditional_codedeploy_deployment_preference', - 'function_with_disabled_deployment_preference', - 'function_with_deployment_preference', - 'function_with_deployment_preference_all_parameters', - 'function_with_deployment_preference_multiple_combinations', - 'function_with_alias_and_event_sources', - 'function_with_resource_refs', - 'function_with_deployment_and_custom_role', - 'function_with_deployment_no_service_role', - 'function_with_global_layers', - 'function_with_layers', - 'function_with_many_layers', - 'function_with_permissions_boundary', - 'function_with_policy_templates', - 'function_with_sns_event_source_all_parameters', - 'function_with_conditional_managed_policy', - 'function_with_conditional_managed_policy_and_ref_no_value', - 'function_with_conditional_policy_template', - 'function_with_conditional_policy_template_and_ref_no_value', - 
'function_with_request_parameters', - 'global_handle_path_level_parameter', - 'globals_for_function', - 'globals_for_api', - 'globals_for_simpletable', - 'all_policy_templates', - 'simple_table_ref_parameter_intrinsic', - 'simple_table_with_table_name', - 'function_concurrency', - 'simple_table_with_extra_tags', - 'explicit_api_with_invalid_events_config', - 'no_implicit_api_with_serverless_rest_api_resource', - 'implicit_api_with_serverless_rest_api_resource', - 'implicit_api_with_auth_and_conditions_max', - 'implicit_api_with_many_conditions', - 'implicit_and_explicit_api_with_conditions', - 'api_with_cors_and_conditions_no_definitionbody', - 'api_with_auth_and_conditions_all_max', - 'api_with_apikey_default_override', - 'api_with_apikey_required', - 'api_with_path_parameters', - 'function_with_event_source_mapping', - 'api_with_swagger_authorizer_none', - 'function_with_event_dest', - 'function_with_event_dest_basic', - 'function_with_event_dest_conditional' - ], - [ - ("aws", "ap-southeast-1"), - ("aws-cn", "cn-north-1"), - ("aws-us-gov", "us-gov-west-1") - ] # Run all the above tests against each of the list of partitions to test against - ) + itertools.product( + [ + "cognito_userpool_with_event", + "s3_with_condition", + "function_with_condition", + "basic_function", + "basic_application", + "application_preparing_state", + "application_with_intrinsics", + "basic_layer", + "cloudwatchevent", + "eventbridgerule", + "eventbridgerule_schedule_properties", + "cloudwatch_logs_with_ref", + "cloudwatchlog", + "streams", + "sqs", + "simpletable", + "simpletable_with_sse", + "implicit_api", + "explicit_api", + "api_endpoint_configuration", + "api_with_auth_all_maximum", + "api_with_auth_all_minimum", + "api_with_auth_no_default", + "api_with_auth_with_default_scopes", + "api_with_auth_with_default_scopes_openapi", + "api_with_default_aws_iam_auth", + "api_with_default_aws_iam_auth_and_no_auth_route", + "api_with_method_aws_iam_auth", + "api_with_aws_iam_auth_overrides", + "api_with_method_settings", + "api_with_binary_media_types", + "api_with_binary_media_types_definition_body", + "api_with_minimum_compression_size", + "api_with_resource_refs", + "api_with_cors", + "api_with_cors_and_auth_no_preflight_auth", + "api_with_cors_and_auth_preflight_auth", + "api_with_cors_and_only_methods", + "api_with_cors_and_only_headers", + "api_with_cors_and_only_origins", + "api_with_cors_and_only_maxage", + "api_with_cors_and_only_credentials_false", + "api_with_cors_no_definitionbody", + "api_with_incompatible_stage_name", + "api_with_gateway_responses", + "api_with_gateway_responses_all", + "api_with_gateway_responses_minimal", + "api_with_gateway_responses_implicit", + "api_with_gateway_responses_string_status_code", + "api_cache", + "api_with_access_log_setting", + "api_with_canary_setting", + "api_with_xray_tracing", + "api_request_model", + "api_with_stage_tags", + "s3", + "s3_create_remove", + "s3_existing_lambda_notification_configuration", + "s3_existing_other_notification_configuration", + "s3_filter", + "s3_multiple_events_same_bucket", + "s3_multiple_functions", + "s3_with_dependsOn", + "sns", + "sns_sqs", + "sns_existing_sqs", + "sns_outside_sqs", + "sns_existing_other_subscription", + "sns_topic_outside_template", + "alexa_skill", + "alexa_skill_with_skill_id", + "iot_rule", + "layers_with_intrinsics", + "layers_all_properties", + "function_managed_inline_policy", + "unsupported_resources", + "intrinsic_functions", + "basic_function_with_tags", + "depends_on", + "function_event_conditions", 
+ "function_with_dlq", + "function_with_kmskeyarn", + "function_with_alias", + "function_with_alias_intrinsics", + "function_with_custom_codedeploy_deployment_preference", + "function_with_custom_conditional_codedeploy_deployment_preference", + "function_with_disabled_deployment_preference", + "function_with_deployment_preference", + "function_with_deployment_preference_all_parameters", + "function_with_deployment_preference_multiple_combinations", + "function_with_alias_and_event_sources", + "function_with_resource_refs", + "function_with_deployment_and_custom_role", + "function_with_deployment_no_service_role", + "function_with_global_layers", + "function_with_layers", + "function_with_many_layers", + "function_with_permissions_boundary", + "function_with_policy_templates", + "function_with_sns_event_source_all_parameters", + "function_with_conditional_managed_policy", + "function_with_conditional_managed_policy_and_ref_no_value", + "function_with_conditional_policy_template", + "function_with_conditional_policy_template_and_ref_no_value", + "function_with_request_parameters", + "global_handle_path_level_parameter", + "globals_for_function", + "globals_for_api", + "globals_for_simpletable", + "all_policy_templates", + "simple_table_ref_parameter_intrinsic", + "simple_table_with_table_name", + "function_concurrency", + "simple_table_with_extra_tags", + "explicit_api_with_invalid_events_config", + "no_implicit_api_with_serverless_rest_api_resource", + "implicit_api_with_serverless_rest_api_resource", + "implicit_api_with_auth_and_conditions_max", + "implicit_api_with_many_conditions", + "implicit_and_explicit_api_with_conditions", + "api_with_cors_and_conditions_no_definitionbody", + "api_with_auth_and_conditions_all_max", + "api_with_apikey_default_override", + "api_with_apikey_required", + "api_with_path_parameters", + "function_with_event_source_mapping", + "api_with_swagger_authorizer_none", + "function_with_event_dest", + "function_with_event_dest_basic", + "function_with_event_dest_conditional", + ], + [ + ("aws", "ap-southeast-1"), + ("aws-cn", "cn-north-1"), + ("aws-us-gov", "us-gov-west-1"), + ], # Run all the above tests against each of the list of partitions to test against + ) + ) + @patch( + "samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call", + mock_sar_service_call, ) - @patch('samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call', mock_sar_service_call) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_transform_success(self, testcase, partition_with_region): partition = partition_with_region[0] region = partition_with_region[1] - manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + '.yaml'), 'r')) + manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r")) # To uncover unicode-related bugs, convert dict to JSON string and parse JSON back to dict manifest = json.loads(json.dumps(manifest)) partition_folder = partition if partition != "aws" else "" - expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + '.json') - expected = json.load(open(expected_filepath, 'r')) + expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + ".json") + expected = json.load(open(expected_filepath, "r")) - with patch('boto3.session.Session.region_name', region): + with 
patch("boto3.session.Session.region_name", region): parameter_values = get_template_parameter_values() mock_policy_loader = MagicMock() mock_policy_loader.load.return_value = { - 'AWSLambdaBasicExecutionRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'.format(partition), - 'AmazonDynamoDBFullAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess'.format(partition), - 'AmazonDynamoDBReadOnlyAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess'.format(partition), - 'AWSLambdaRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaRole'.format(partition), + "AWSLambdaBasicExecutionRole": "arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole".format( + partition + ), + "AmazonDynamoDBFullAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess".format(partition), + "AmazonDynamoDBReadOnlyAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess".format(partition), + "AWSLambdaRole": "arn:{}:iam::aws:policy/service-role/AWSLambdaRole".format(partition), } - output_fragment = transform( - manifest, parameter_values, mock_policy_loader) + output_fragment = transform(manifest, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) # Only update the deployment Logical Id hash in Py3. if sys.version_info.major >= 3: @@ -311,65 +320,70 @@ def test_transform_success(self, testcase, partition_with_region): assert deep_sort_lists(output_fragment) == deep_sort_lists(expected) @parameterized.expand( - itertools.product([ - 'explicit_api_openapi_3', - 'api_with_auth_all_maximum_openapi_3', - 'api_with_cors_openapi_3', - 'api_with_gateway_responses_all_openapi_3', - 'api_with_open_api_version', - 'api_with_open_api_version_2', - 'api_with_auth_all_minimum_openapi', - 'api_with_swagger_and_openapi_with_auth', - 'api_with_openapi_definition_body_no_flag', - 'api_request_model_openapi_3', - 'api_with_apikey_required_openapi_3', - 'api_with_basic_custom_domain', - 'api_with_basic_custom_domain_intrinsics', - 'api_with_custom_domain_route53', - 'implicit_http_api', - 'explicit_http_api_minimum', - 'implicit_http_api_auth_and_simple_case', - 'http_api_existing_openapi', - 'http_api_existing_openapi_conditions', - 'implicit_http_api_with_many_conditions', - 'http_api_explicit_stage', - 'http_api_def_uri', - 'explicit_http_api' - ], - [ - ("aws", "ap-southeast-1"), - ("aws-cn", "cn-north-1"), - ("aws-us-gov", "us-gov-west-1") - ] # Run all the above tests against each of the list of partitions to test against - ) + itertools.product( + [ + "explicit_api_openapi_3", + "api_with_auth_all_maximum_openapi_3", + "api_with_cors_openapi_3", + "api_with_gateway_responses_all_openapi_3", + "api_with_open_api_version", + "api_with_open_api_version_2", + "api_with_auth_all_minimum_openapi", + "api_with_swagger_and_openapi_with_auth", + "api_with_openapi_definition_body_no_flag", + "api_request_model_openapi_3", + "api_with_apikey_required_openapi_3", + "api_with_basic_custom_domain", + "api_with_basic_custom_domain_intrinsics", + "api_with_custom_domain_route53", + "implicit_http_api", + "explicit_http_api_minimum", + "implicit_http_api_auth_and_simple_case", + "http_api_existing_openapi", + "http_api_existing_openapi_conditions", + "implicit_http_api_with_many_conditions", + "http_api_explicit_stage", + "http_api_def_uri", + "explicit_http_api", + ], + [ + ("aws", "ap-southeast-1"), + ("aws-cn", "cn-north-1"), + ("aws-us-gov", "us-gov-west-1"), + ], # Run all the above tests against each of the list 
of partitions to test against + ) + ) + @patch( + "samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call", + mock_sar_service_call, ) - @patch('samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call', mock_sar_service_call) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_transform_success_openapi3(self, testcase, partition_with_region): partition = partition_with_region[0] region = partition_with_region[1] - manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + '.yaml'), 'r')) + manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r")) # To uncover unicode-related bugs, convert dict to JSON string and parse JSON back to dict manifest = json.loads(json.dumps(manifest)) partition_folder = partition if partition != "aws" else "" - expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + '.json') - expected = json.load(open(expected_filepath, 'r')) + expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + ".json") + expected = json.load(open(expected_filepath, "r")) - with patch('boto3.session.Session.region_name', region): + with patch("boto3.session.Session.region_name", region): parameter_values = get_template_parameter_values() mock_policy_loader = MagicMock() mock_policy_loader.load.return_value = { - 'AWSLambdaBasicExecutionRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'.format(partition), - 'AmazonDynamoDBFullAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess'.format(partition), - 'AmazonDynamoDBReadOnlyAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess'.format(partition), - 'AWSLambdaRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaRole'.format(partition), + "AWSLambdaBasicExecutionRole": "arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole".format( + partition + ), + "AmazonDynamoDBFullAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess".format(partition), + "AmazonDynamoDBReadOnlyAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess".format(partition), + "AWSLambdaRole": "arn:{}:iam::aws:policy/service-role/AWSLambdaRole".format(partition), } - output_fragment = transform( - manifest, parameter_values, mock_policy_loader) + output_fragment = transform(manifest, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) # Only update the deployment Logical Id hash in Py3. 
if sys.version_info.major >= 3: @@ -379,50 +393,55 @@ def test_transform_success_openapi3(self, testcase, partition_with_region): assert deep_sort_lists(output_fragment) == deep_sort_lists(expected) @parameterized.expand( - itertools.product([ - 'api_with_aws_account_whitelist', - 'api_with_aws_account_blacklist', - 'api_with_ip_range_whitelist', - 'api_with_ip_range_blacklist', - 'api_with_source_vpc_whitelist', - 'api_with_source_vpc_blacklist', - 'api_with_resource_policy', - 'api_with_resource_policy_global', - 'api_with_resource_policy_global_implicit' - ], - [ - ("aws", "ap-southeast-1"), - ("aws-cn", "cn-north-1"), - ("aws-us-gov", "us-gov-west-1") - ] # Run all the above tests against each of the list of partitions to test against - ) + itertools.product( + [ + "api_with_aws_account_whitelist", + "api_with_aws_account_blacklist", + "api_with_ip_range_whitelist", + "api_with_ip_range_blacklist", + "api_with_source_vpc_whitelist", + "api_with_source_vpc_blacklist", + "api_with_resource_policy", + "api_with_resource_policy_global", + "api_with_resource_policy_global_implicit", + ], + [ + ("aws", "ap-southeast-1"), + ("aws-cn", "cn-north-1"), + ("aws-us-gov", "us-gov-west-1"), + ], # Run all the above tests against each of the list of partitions to test against + ) ) - @patch('samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call', mock_sar_service_call) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch( + "samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call", + mock_sar_service_call, + ) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_transform_success_resource_policy(self, testcase, partition_with_region): partition = partition_with_region[0] region = partition_with_region[1] - manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + '.yaml'), 'r')) + manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r")) # To uncover unicode-related bugs, convert dict to JSON string and parse JSON back to dict manifest = json.loads(json.dumps(manifest)) partition_folder = partition if partition != "aws" else "" - expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + '.json') - expected = json.load(open(expected_filepath, 'r')) + expected_filepath = os.path.join(OUTPUT_FOLDER, partition_folder, testcase + ".json") + expected = json.load(open(expected_filepath, "r")) - with patch('boto3.session.Session.region_name', region): + with patch("boto3.session.Session.region_name", region): parameter_values = get_template_parameter_values() mock_policy_loader = MagicMock() mock_policy_loader.load.return_value = { - 'AWSLambdaBasicExecutionRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'.format(partition), - 'AmazonDynamoDBFullAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess'.format(partition), - 'AmazonDynamoDBReadOnlyAccess': 'arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess'.format(partition), - 'AWSLambdaRole': 'arn:{}:iam::aws:policy/service-role/AWSLambdaRole'.format(partition), + "AWSLambdaBasicExecutionRole": "arn:{}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole".format( + partition + ), + "AmazonDynamoDBFullAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBFullAccess".format(partition), + "AmazonDynamoDBReadOnlyAccess": "arn:{}:iam::aws:policy/AmazonDynamoDBReadOnlyAccess".format(partition), + "AWSLambdaRole": 
"arn:{}:iam::aws:policy/service-role/AWSLambdaRole".format(partition), } - output_fragment = transform( - manifest, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + output_fragment = transform(manifest, parameter_values, mock_policy_loader) + print (json.dumps(output_fragment, indent=2)) # Only update the deployment Logical Id hash in Py3. if sys.version_info.major >= 3: @@ -444,12 +463,14 @@ def _update_logical_id_hash(self, resources): if "AWS::ApiGateway::RestApi" == resource_dict.get("Type"): resource_properties = resource_dict.get("Properties", {}) if "Body" in resource_properties: - self._generate_new_deployment_hash(logical_id, resource_properties.get("Body"), rest_api_to_swagger_hash) + self._generate_new_deployment_hash( + logical_id, resource_properties.get("Body"), rest_api_to_swagger_hash + ) elif "BodyS3Location" in resource_dict.get("Properties"): - self._generate_new_deployment_hash(logical_id, - resource_properties.get("BodyS3Location"), - rest_api_to_swagger_hash) + self._generate_new_deployment_hash( + logical_id, resource_properties.get("BodyS3Location"), rest_api_to_swagger_hash + ) # Collect all APIGW Deployments LogicalIds and generate the new ones for logical_id, resource_dict in output_resources.items(): @@ -460,7 +481,7 @@ def _update_logical_id_hash(self, resources): data_hash = rest_api_to_swagger_hash.get(rest_id) - description = resource_properties.get("Description")[:-len(data_hash)] + description = resource_properties.get("Description")[: -len(data_hash)] resource_properties["Description"] = description + data_hash @@ -495,97 +516,103 @@ def _update_logical_id_hash(self, resources): output_value["Value"]["Ref"] = deployment_logical_id_dict[output_value.get("Value").get("Ref")] def _generate_new_deployment_hash(self, logical_id, dict_to_hash, rest_api_to_swagger_hash): - data_bytes = json.dumps(dict_to_hash, separators=(',', ':'), sort_keys=True).encode("utf8") + data_bytes = json.dumps(dict_to_hash, separators=(",", ":"), sort_keys=True).encode("utf8") data_hash = hashlib.sha1(data_bytes).hexdigest() rest_api_to_swagger_hash[logical_id] = data_hash -@pytest.mark.parametrize('testcase', [ - 'error_cognito_userpool_duplicate_trigger', - 'error_api_duplicate_methods_same_path', - 'error_api_gateway_responses_nonnumeric_status_code', - 'error_api_gateway_responses_unknown_responseparameter', - 'error_api_gateway_responses_unknown_responseparameter_property', - 'error_api_invalid_auth', - 'error_api_invalid_path', - 'error_api_invalid_definitionuri', - 'error_api_invalid_definitionbody', - 'error_api_invalid_stagename', - 'error_api_with_invalid_open_api_version', - 'error_api_invalid_restapiid', - 'error_api_invalid_request_model', - 'error_application_properties', - 'error_application_does_not_exist', - 'error_application_no_access', - 'error_application_preparing_timeout', - 'error_cors_on_external_swagger', - 'error_invalid_cors_dict', - 'error_invalid_findinmap', - 'error_invalid_getatt', - 'error_cors_credentials_true_with_wildcard_origin', - 'error_cors_credentials_true_without_explicit_origin', - 'error_function_invalid_codeuri', - 'error_function_invalid_api_event', - 'error_function_invalid_autopublishalias', - 'error_function_invalid_event_type', - 'error_function_invalid_layer', - 'error_function_no_codeuri', - 'error_function_no_handler', - 'error_function_no_runtime', - 'error_function_with_deployment_preference_missing_alias', - 'error_function_with_invalid_deployment_preference_hook_property', - 
'error_function_invalid_request_parameters', - 'error_invalid_logical_id', - 'error_layer_invalid_properties', - 'error_missing_queue', - 'error_missing_startingposition', - 'error_missing_stream', - 'error_multiple_resource_errors', - 'error_null_application_id', - 'error_s3_not_in_template', - 'error_table_invalid_attributetype', - 'error_table_primary_key_missing_name', - 'error_table_primary_key_missing_type', - 'error_invalid_resource_parameters', - 'error_reserved_sam_tag', - 'error_existing_event_logical_id', - 'error_existing_permission_logical_id', - 'error_existing_role_logical_id', - 'error_invalid_template', - 'error_resource_not_dict', - 'error_resource_properties_not_dict', - 'error_globals_is_not_dict', - 'error_globals_unsupported_type', - 'error_globals_unsupported_property', - 'error_globals_api_with_stage_name', - 'error_function_policy_template_with_missing_parameter', - 'error_function_policy_template_invalid_value', - 'error_function_with_unknown_policy_template', - 'error_function_with_invalid_policy_statement', - 'error_function_with_invalid_condition_name', - 'error_invalid_document_empty_semantic_version', - 'error_api_with_invalid_open_api_version_type', - 'error_api_with_custom_domains_invalid', - 'error_api_with_custom_domains_route53_invalid', - 'error_api_event_import_vaule_reference', - 'error_function_with_method_auth_and_no_api_auth', - 'error_function_with_no_alias_provisioned_concurrency', - 'error_http_api_def_body_uri', - 'error_http_api_event_invalid_api', - 'error_http_api_invalid_auth', - 'error_http_api_invalid_openapi', - 'error_implicit_http_api_method', - 'error_implicit_http_api_path', - 'error_http_api_event_multiple_same_path', - 'error_function_with_event_dest_invalid', - 'error_function_with_event_dest_type' - -]) -@patch('boto3.session.Session.region_name', 'ap-southeast-1') -@patch('samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call', mock_sar_service_call) -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + +@pytest.mark.parametrize( + "testcase", + [ + "error_cognito_userpool_duplicate_trigger", + "error_api_duplicate_methods_same_path", + "error_api_gateway_responses_nonnumeric_status_code", + "error_api_gateway_responses_unknown_responseparameter", + "error_api_gateway_responses_unknown_responseparameter_property", + "error_api_invalid_auth", + "error_api_invalid_path", + "error_api_invalid_definitionuri", + "error_api_invalid_definitionbody", + "error_api_invalid_stagename", + "error_api_with_invalid_open_api_version", + "error_api_invalid_restapiid", + "error_api_invalid_request_model", + "error_application_properties", + "error_application_does_not_exist", + "error_application_no_access", + "error_application_preparing_timeout", + "error_cors_on_external_swagger", + "error_invalid_cors_dict", + "error_invalid_findinmap", + "error_invalid_getatt", + "error_cors_credentials_true_with_wildcard_origin", + "error_cors_credentials_true_without_explicit_origin", + "error_function_invalid_codeuri", + "error_function_invalid_api_event", + "error_function_invalid_autopublishalias", + "error_function_invalid_event_type", + "error_function_invalid_layer", + "error_function_no_codeuri", + "error_function_no_handler", + "error_function_no_runtime", + "error_function_with_deployment_preference_missing_alias", + "error_function_with_invalid_deployment_preference_hook_property", + "error_function_invalid_request_parameters", + "error_invalid_logical_id", + 
"error_layer_invalid_properties", + "error_missing_queue", + "error_missing_startingposition", + "error_missing_stream", + "error_multiple_resource_errors", + "error_null_application_id", + "error_s3_not_in_template", + "error_table_invalid_attributetype", + "error_table_primary_key_missing_name", + "error_table_primary_key_missing_type", + "error_invalid_resource_parameters", + "error_reserved_sam_tag", + "error_existing_event_logical_id", + "error_existing_permission_logical_id", + "error_existing_role_logical_id", + "error_invalid_template", + "error_resource_not_dict", + "error_resource_properties_not_dict", + "error_globals_is_not_dict", + "error_globals_unsupported_type", + "error_globals_unsupported_property", + "error_globals_api_with_stage_name", + "error_function_policy_template_with_missing_parameter", + "error_function_policy_template_invalid_value", + "error_function_with_unknown_policy_template", + "error_function_with_invalid_policy_statement", + "error_function_with_invalid_condition_name", + "error_invalid_document_empty_semantic_version", + "error_api_with_invalid_open_api_version_type", + "error_api_with_custom_domains_invalid", + "error_api_with_custom_domains_route53_invalid", + "error_api_event_import_vaule_reference", + "error_function_with_method_auth_and_no_api_auth", + "error_function_with_no_alias_provisioned_concurrency", + "error_http_api_def_body_uri", + "error_http_api_event_invalid_api", + "error_http_api_invalid_auth", + "error_http_api_invalid_openapi", + "error_implicit_http_api_method", + "error_implicit_http_api_path", + "error_http_api_event_multiple_same_path", + "error_function_with_event_dest_invalid", + "error_function_with_event_dest_type", + ], +) +@patch("boto3.session.Session.region_name", "ap-southeast-1") +@patch( + "samtranslator.plugins.application.serverless_app_plugin.ServerlessAppPlugin._sar_service_call", + mock_sar_service_call, +) +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_transform_invalid_document(testcase): - manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + '.yaml'), 'r')) - expected = json.load(open(os.path.join(OUTPUT_FOLDER, testcase + '.json'), 'r')) + manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r")) + expected = json.load(open(os.path.join(OUTPUT_FOLDER, testcase + ".json"), "r")) mock_policy_loader = MagicMock() parameter_values = get_template_parameter_values() @@ -595,24 +622,25 @@ def test_transform_invalid_document(testcase): error_message = get_exception_error_message(e) - assert error_message == expected.get('errorMessage') + assert error_message == expected.get("errorMessage") + -@patch('boto3.session.Session.region_name', 'ap-southeast-1') -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) +@patch("boto3.session.Session.region_name", "ap-southeast-1") +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_transform_unhandled_failure_empty_managed_policy_map(): document = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'Resource': { - 'Type': 'AWS::Serverless::Function', - 'Properties': { - 'CodeUri': 's3://bucket/key', - 'Handler': 'index.handler', - 'Runtime': 'nodejs4.3', - 'Policies': 'AmazonS3FullAccess' - } + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "Resource": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "s3://bucket/key", + "Handler": "index.handler", + "Runtime": 
"nodejs4.3", + "Policies": "AmazonS3FullAccess", + }, } - } + }, } parameter_values = get_template_parameter_values() @@ -624,55 +652,49 @@ def test_transform_unhandled_failure_empty_managed_policy_map(): error_message = str(e.value) - assert error_message == 'Managed policy map is empty, but should not be.' + assert error_message == "Managed policy map is empty, but should not be." def assert_metric_call(mock, transform, transform_failure=0, invalid_document=0): - metric_dimensions = [ - { - 'Name': 'Transform', - 'Value': transform - } - ] + metric_dimensions = [{"Name": "Transform", "Value": transform}] mock.put_metric_data.assert_called_once_with( - Namespace='ServerlessTransform', + Namespace="ServerlessTransform", MetricData=[ { - 'MetricName': 'TransformFailure', - 'Value': transform_failure, - 'Unit': 'Count', - 'Dimensions': metric_dimensions + "MetricName": "TransformFailure", + "Value": transform_failure, + "Unit": "Count", + "Dimensions": metric_dimensions, }, { - 'MetricName': 'InvalidDocument', - 'Value': invalid_document, - 'Unit': 'Count', - 'Dimensions': metric_dimensions - } - ] + "MetricName": "InvalidDocument", + "Value": invalid_document, + "Unit": "Count", + "Dimensions": metric_dimensions, + }, + ], ) -@patch('boto3.session.Session.region_name', 'ap-southeast-1') -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) +@patch("boto3.session.Session.region_name", "ap-southeast-1") +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_swagger_body_sha_gets_recomputed(): document = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'Resource': { - 'Type': 'AWS::Serverless::Api', - 'Properties': { + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "Resource": { + "Type": "AWS::Serverless::Api", + "Properties": { "StageName": "Prod", "DefinitionBody": { # Some body property will do "a": "b" - } - } + }, + }, } - } - + }, } mock_policy_loader = get_policy_mock() @@ -680,7 +702,7 @@ def test_swagger_body_sha_gets_recomputed(): output_fragment = transform(document, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) deployment_key = get_deployment_key(output_fragment) assert deployment_key @@ -697,22 +719,18 @@ def test_swagger_body_sha_gets_recomputed(): assert get_deployment_key(output_fragment) == deployment_key_changed -@patch('boto3.session.Session.region_name', 'ap-southeast-1') -@patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) +@patch("boto3.session.Session.region_name", "ap-southeast-1") +@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_swagger_definitionuri_sha_gets_recomputed(): document = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'Resource': { - 'Type': 'AWS::Serverless::Api', - 'Properties': { - "StageName": "Prod", - "DefinitionUri": "s3://bucket/key" - } + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "Resource": { + "Type": "AWS::Serverless::Api", + "Properties": {"StageName": "Prod", "DefinitionUri": "s3://bucket/key"}, } - } - + }, } mock_policy_loader = get_policy_mock() @@ -720,7 +738,7 @@ def test_swagger_definitionuri_sha_gets_recomputed(): output_fragment = transform(document, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) deployment_key = 
get_deployment_key(output_fragment) assert deployment_key @@ -736,6 +754,7 @@ def test_swagger_definitionuri_sha_gets_recomputed(): output_fragment = transform(document, parameter_values, mock_policy_loader) assert get_deployment_key(output_fragment) == deployment_key_changed + class TestFunctionVersionWithParameterReferences(TestCase): """ Test how Lambda Function Version gets created when intrinsic functions @@ -743,29 +762,24 @@ class TestFunctionVersionWithParameterReferences(TestCase): def setUp(self): self.document = { - 'Transform': 'AWS::Serverless-2016-10-31', - 'Resources': { - 'MyFunction': { - 'Type': 'AWS::Serverless::Function', - 'Properties': { + "Transform": "AWS::Serverless-2016-10-31", + "Resources": { + "MyFunction": { + "Type": "AWS::Serverless::Function", + "Properties": { "Runtime": "nodejs4.3", "Handler": "index.handler", - "CodeUri": { - "Bucket": {"Ref": "SomeBucket"}, - "Key": {"Ref": "CodeKeyParam"} - }, - "AutoPublishAlias": "live" - } + "CodeUri": {"Bucket": {"Ref": "SomeBucket"}, "Key": {"Ref": "CodeKeyParam"}}, + "AutoPublishAlias": "live", + }, } - } + }, } - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_logical_id_change_with_parameters(self): - parameter_values = { - 'CodeKeyParam': 'value1' - } + parameter_values = {"CodeKeyParam": "value1"} first_transformed_template = self._do_transform(self.document, parameter_values) parameter_values["CodeKeyParam"] = "value2" @@ -776,12 +790,10 @@ def test_logical_id_change_with_parameters(self): assert first_version_id != second_version_id - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_logical_id_remains_same_without_parameter_change(self): - parameter_values = { - 'CodeKeyParam': 'value1' - } + parameter_values = {"CodeKeyParam": "value1"} first_transformed_template = self._do_transform(self.document, parameter_values) second_transformed_template = self._do_transform(self.document, parameter_values) @@ -791,8 +803,8 @@ def test_logical_id_remains_same_without_parameter_change(self): assert first_version_id == second_version_id - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_logical_id_without_resolving_reference(self): # Now value of `CodeKeyParam` is not present in document @@ -808,71 +820,58 @@ def _do_transform(self, document, parameter_values=get_template_parameter_values mock_policy_loader = get_policy_mock() output_fragment = transform(document, parameter_values, mock_policy_loader) - print(json.dumps(output_fragment, indent=2)) + print (json.dumps(output_fragment, indent=2)) return output_fragment class TestTemplateValidation(TestCase): - - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + 
@patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_throws_when_resource_not_found(self): - template = { - "foo": "bar" - } + template = {"foo": "bar"} with self.assertRaises(InvalidDocumentException): sam_parser = Parser() translator = Translator({}, sam_parser) translator.translate(template, {}) - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_throws_when_resource_is_empty(self): - template = { - "Resources": {} - } + template = {"Resources": {}} with self.assertRaises(InvalidDocumentException): sam_parser = Parser() translator = Translator({}, sam_parser) translator.translate(template, {}) - - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_throws_when_resource_is_not_dict(self): - template = { - "Resources": [1,2,3] - } + template = {"Resources": [1, 2, 3]} with self.assertRaises(InvalidDocumentException): sam_parser = Parser() translator = Translator({}, sam_parser) translator.translate(template, {}) - - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("boto3.session.Session.region_name", "ap-southeast-1") + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_throws_when_resources_not_all_dicts(self): - template = { - "Resources": { - "notadict": None, - "MyResource": {} - } - } + template = {"Resources": {"notadict": None, "MyResource": {}}} with self.assertRaises(InvalidDocumentException): sam_parser = Parser() translator = Translator({}, sam_parser) translator.translate(template, {}) + class TestPluginsUsage(TestCase): # Tests if plugins are properly injected into the translator @patch("samtranslator.translator.translator.make_policy_template_for_function_plugin") - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_prepare_plugins_must_add_required_plugins(self, make_policy_template_for_function_plugin_mock): # This is currently the only required plugin @@ -883,7 +882,7 @@ def test_prepare_plugins_must_add_required_plugins(self, make_policy_template_fo self.assertEqual(6, len(sam_plugins)) @patch("samtranslator.translator.translator.make_policy_template_for_function_plugin") - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + @patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_prepare_plugins_must_merge_input_plugins(self, make_policy_template_for_function_plugin_mock): required_plugin = BasePlugin("something") @@ -893,7 +892,7 @@ def test_prepare_plugins_must_merge_input_plugins(self, make_policy_template_for sam_plugins = prepare_plugins([custom_plugin]) self.assertEqual(7, len(sam_plugins)) - @patch('botocore.client.ClientEndpointBridge._check_default_region', mock_get_region) + 
@patch("botocore.client.ClientEndpointBridge._check_default_region", mock_get_region) def test_prepare_plugins_must_handle_empty_input(self): sam_plugins = prepare_plugins(None) @@ -901,9 +900,9 @@ def test_prepare_plugins_must_handle_empty_input(self): @patch("samtranslator.translator.translator.PolicyTemplatesProcessor") @patch("samtranslator.translator.translator.PolicyTemplatesForFunctionPlugin") - def test_make_policy_template_for_function_plugin_must_work(self, - policy_templates_for_function_plugin_mock, - policy_templates_processor_mock): + def test_make_policy_template_for_function_plugin_must_work( + self, policy_templates_for_function_plugin_mock, policy_templates_processor_mock + ): default_templates = {"some": "value"} policy_templates_processor_mock.get_default_policy_templates_json.return_value = default_templates @@ -927,50 +926,46 @@ def test_make_policy_template_for_function_plugin_must_work(self, @patch.object(Resource, "from_dict") @patch("samtranslator.translator.translator.SamPlugins") @patch("samtranslator.translator.translator.prepare_plugins") - @patch('boto3.session.Session.region_name', 'ap-southeast-1') - def test_transform_method_must_inject_plugins_when_creating_resources(self, - prepare_plugins_mock, - sam_plugins_class_mock, - resource_from_dict_mock): - manifest = { - 'Resources': { - 'MyTable': { - 'Type': 'AWS::Serverless::SimpleTable', - 'Properties': { - } - } - } - } + @patch("boto3.session.Session.region_name", "ap-southeast-1") + def test_transform_method_must_inject_plugins_when_creating_resources( + self, prepare_plugins_mock, sam_plugins_class_mock, resource_from_dict_mock + ): + manifest = {"Resources": {"MyTable": {"Type": "AWS::Serverless::SimpleTable", "Properties": {}}}} sam_plugins_object_mock = Mock() sam_plugins_class_mock.return_value = sam_plugins_object_mock prepare_plugins_mock.return_value = sam_plugins_object_mock resource_from_dict_mock.return_value = SamSimpleTable("MyFunction") - initial_plugins = [1,2,3] + initial_plugins = [1, 2, 3] sam_parser = Parser() translator = Translator({}, sam_parser, plugins=initial_plugins) translator.translate(manifest, {}) - resource_from_dict_mock.assert_called_with("MyTable", - manifest["Resources"]["MyTable"], - sam_plugins=sam_plugins_object_mock) - prepare_plugins_mock.assert_called_once_with(initial_plugins, {"AWS::Region": "ap-southeast-1", "AWS::Partition": "aws"}) + resource_from_dict_mock.assert_called_with( + "MyTable", manifest["Resources"]["MyTable"], sam_plugins=sam_plugins_object_mock + ) + prepare_plugins_mock.assert_called_once_with( + initial_plugins, {"AWS::Region": "ap-southeast-1", "AWS::Partition": "aws"} + ) + def get_policy_mock(): mock_policy_loader = MagicMock() mock_policy_loader.load.return_value = { - 'AmazonDynamoDBFullAccess': 'arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess', - 'AmazonDynamoDBReadOnlyAccess': 'arn:aws:iam::aws:policy/AmazonDynamoDBReadOnlyAccess', - 'AWSLambdaRole': 'arn:aws:iam::aws:policy/service-role/AWSLambdaRole', + "AmazonDynamoDBFullAccess": "arn:aws:iam::aws:policy/AmazonDynamoDBFullAccess", + "AmazonDynamoDBReadOnlyAccess": "arn:aws:iam::aws:policy/AmazonDynamoDBReadOnlyAccess", + "AWSLambdaRole": "arn:aws:iam::aws:policy/service-role/AWSLambdaRole", } return mock_policy_loader + def get_deployment_key(fragment): logical_id, value = get_resource_by_type(fragment, "AWS::ApiGateway::Deployment") return logical_id + def get_resource_by_type(template, type): resources = template["Resources"] for key in resources: @@ -978,5 +973,6 @@ def 
get_resource_by_type(template, type): if "Type" in value and value.get("Type") == type: return key, value + def get_exception_error_message(e): - return reduce(lambda message, error: message + ' ' + error.message, e.value.causes, e.value.message) + return reduce(lambda message, error: message + " " + error.message, e.value.causes, e.value.message) diff --git a/tests/translator/validator/test_validator.py b/tests/translator/validator/test_validator.py index e1fa43a6dc..ceff650a8b 100644 --- a/tests/translator/validator/test_validator.py +++ b/tests/translator/validator/test_validator.py @@ -5,112 +5,116 @@ from samtranslator.validator.validator import SamTemplateValidator BASE_PATH = os.path.dirname(__file__) -INPUT_FOLDER = os.path.join(BASE_PATH, os.pardir, 'input') +INPUT_FOLDER = os.path.join(BASE_PATH, os.pardir, "input") -@pytest.mark.parametrize('testcase', [ - 'cloudwatchevent', - 'eventbridgerule', - 'cloudwatch_logs_with_ref', - 'cloudwatchlog', - 'streams', - 'sqs', - 'simpletable', - 'simpletable_with_sse', - 'implicit_api', - 'explicit_api', - 'api_endpoint_configuration', - 'api_with_method_settings', - 'api_with_binary_media_types', - 'api_with_minimum_compression_size', - 'api_with_resource_refs', - 'api_with_cors', - 'api_with_cors_and_only_methods', - 'api_with_cors_and_only_headers', - 'api_with_cors_and_only_origins', - 'api_with_cors_and_only_maxage', - 'api_cache', - 's3', - 's3_create_remove', - 's3_existing_lambda_notification_configuration', - 's3_existing_other_notification_configuration', - 's3_filter', - 's3_multiple_events_same_bucket', - 's3_multiple_functions', - 'sns', - 'sns_existing_other_subscription', - 'sns_topic_outside_template', - 'alexa_skill', - 'iot_rule', - 'function_managed_inline_policy', - 'unsupported_resources', - 'intrinsic_functions', - 'basic_function_with_tags', - 'depends_on', - 'function_with_dlq', - 'function_with_kmskeyarn', - 'function_with_alias', - 'function_with_alias_intrinsics', - 'function_with_disabled_deployment_preference', - 'function_with_deployment_preference', - 'function_with_deployment_preference_all_parameters', - 'function_with_deployment_preference_multiple_combinations', - 'function_with_alias_and_event_sources', - 'function_with_resource_refs', - 'function_with_deployment_and_custom_role', - 'function_with_deployment_no_service_role', - 'function_with_permissions_boundary', - 'function_with_policy_templates', - 'function_with_sns_event_source_all_parameters', - 'globals_for_function', - 'globals_for_api', - 'globals_for_simpletable', - 'all_policy_templates', - 'simple_table_ref_parameter_intrinsic', - 'simple_table_with_table_name', - 'function_concurrency', - 'simple_table_with_extra_tags', - 'explicit_api_with_invalid_events_config' -]) + +@pytest.mark.parametrize( + "testcase", + [ + "cloudwatchevent", + "eventbridgerule", + "cloudwatch_logs_with_ref", + "cloudwatchlog", + "streams", + "sqs", + "simpletable", + "simpletable_with_sse", + "implicit_api", + "explicit_api", + "api_endpoint_configuration", + "api_with_method_settings", + "api_with_binary_media_types", + "api_with_minimum_compression_size", + "api_with_resource_refs", + "api_with_cors", + "api_with_cors_and_only_methods", + "api_with_cors_and_only_headers", + "api_with_cors_and_only_origins", + "api_with_cors_and_only_maxage", + "api_cache", + "s3", + "s3_create_remove", + "s3_existing_lambda_notification_configuration", + "s3_existing_other_notification_configuration", + "s3_filter", + "s3_multiple_events_same_bucket", + "s3_multiple_functions", + 
"sns", + "sns_existing_other_subscription", + "sns_topic_outside_template", + "alexa_skill", + "iot_rule", + "function_managed_inline_policy", + "unsupported_resources", + "intrinsic_functions", + "basic_function_with_tags", + "depends_on", + "function_with_dlq", + "function_with_kmskeyarn", + "function_with_alias", + "function_with_alias_intrinsics", + "function_with_disabled_deployment_preference", + "function_with_deployment_preference", + "function_with_deployment_preference_all_parameters", + "function_with_deployment_preference_multiple_combinations", + "function_with_alias_and_event_sources", + "function_with_resource_refs", + "function_with_deployment_and_custom_role", + "function_with_deployment_no_service_role", + "function_with_permissions_boundary", + "function_with_policy_templates", + "function_with_sns_event_source_all_parameters", + "globals_for_function", + "globals_for_api", + "globals_for_simpletable", + "all_policy_templates", + "simple_table_ref_parameter_intrinsic", + "simple_table_with_table_name", + "function_concurrency", + "simple_table_with_extra_tags", + "explicit_api_with_invalid_events_config", + ], +) def test_validate_template_success(testcase): # These templates are failing validation, will fix schema one at a time excluded = [ - 'api_endpoint_configuration', - 'api_with_binary_media_types', - 'api_with_minimum_compression_size', - 'api_with_cors', - 'cloudwatch_logs_with_ref', - 'sns', - 'sns_existing_other_subscription', - 'sns_topic_outside_template', - 'alexa_skill', - 'iot_rule', - 'function_managed_inline_policy', - 'unsupported_resources', - 'intrinsic_functions', - 'basic_function_with_tags', - 'function_with_kmskeyarn', - 'function_with_alias', - 'function_with_alias_intrinsics', - 'function_with_disabled_deployment_preference', - 'function_with_deployment_preference', - 'function_with_deployment_preference_all_parameters', - 'function_with_deployment_preference_multiple_combinations', - 'function_with_alias_and_event_sources', - 'function_with_resource_refs', - 'function_with_policy_templates', - 'globals_for_function', - 'all_policy_templates', - 'simple_table_ref_parameter_intrinsic', - 'simple_table_with_table_name', - 'function_concurrency', - 'simple_table_with_extra_tags' + "api_endpoint_configuration", + "api_with_binary_media_types", + "api_with_minimum_compression_size", + "api_with_cors", + "cloudwatch_logs_with_ref", + "sns", + "sns_existing_other_subscription", + "sns_topic_outside_template", + "alexa_skill", + "iot_rule", + "function_managed_inline_policy", + "unsupported_resources", + "intrinsic_functions", + "basic_function_with_tags", + "function_with_kmskeyarn", + "function_with_alias", + "function_with_alias_intrinsics", + "function_with_disabled_deployment_preference", + "function_with_deployment_preference", + "function_with_deployment_preference_all_parameters", + "function_with_deployment_preference_multiple_combinations", + "function_with_alias_and_event_sources", + "function_with_resource_refs", + "function_with_policy_templates", + "globals_for_function", + "all_policy_templates", + "simple_table_ref_parameter_intrinsic", + "simple_table_with_table_name", + "function_concurrency", + "simple_table_with_extra_tags", ] if testcase in excluded: return - manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + '.yaml'), 'r')) + manifest = yaml_parse(open(os.path.join(INPUT_FOLDER, testcase + ".yaml"), "r")) validation_errors = SamTemplateValidator.validate(manifest) has_errors = len(validation_errors) if has_errors: - 
print("\nFailing template: {0}\n".format(testcase)) - print(validation_errors) + print ("\nFailing template: {0}\n".format(testcase)) + print (validation_errors) assert len(validation_errors) == 0 diff --git a/tox.ini b/tox.ini index 6fc4e4feb8..62a8b9fe70 100644 --- a/tox.ini +++ b/tox.ini @@ -25,4 +25,4 @@ commands = make pr codecov deps = codecov>=1.4.0 passenv = AWS* TONXENV CI TRAVIS TRAVIS_* -whitelist_externals = make +whitelist_externals = make, black