|
3 | 3 |
|
4 | 4 | from samtranslator.model.intrinsics import ref
|
5 | 5 | from samtranslator.model.apigateway import (ApiGatewayDeployment, ApiGatewayRestApi,
|
6 |
| - ApiGatewayStage, ApiGatewayAuthorizer) |
| 6 | + ApiGatewayStage, ApiGatewayAuthorizer, |
| 7 | + ApiGatewayResponse) |
7 | 8 | from samtranslator.model.exceptions import InvalidResourceException
|
8 | 9 | from samtranslator.model.s3_utils.uri_parser import parse_s3_uri
|
9 | 10 | from samtranslator.region_configuration import RegionConfiguration
|
|
21 | 22 | AuthProperties = namedtuple("_AuthProperties", ["Authorizers", "DefaultAuthorizer", "InvokeRole"])
|
22 | 23 | AuthProperties.__new__.__defaults__ = (None, None, None)
|
23 | 24 |
|
| 25 | +GatewayResponseProperties = ["ResponseParameters", "ResponseTemplates", "StatusCode"] |
| 26 | + |
24 | 27 |
|
25 | 28 | class ApiGenerator(object):
|
26 | 29 |
|
27 | 30 | def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variables, depends_on,
|
28 | 31 | definition_body, definition_uri, name, stage_name, endpoint_configuration=None,
|
29 | 32 | method_settings=None, binary_media=None, minimum_compression_size=None, cors=None,
|
30 |
| - auth=None, access_log_setting=None, canary_setting=None, tracing_enabled=None, |
31 |
| - resource_attributes=None, passthrough_resource_attributes=None): |
| 33 | + auth=None, gateway_responses=None, access_log_setting=None, canary_setting=None, |
| 34 | + tracing_enabled=None, resource_attributes=None, passthrough_resource_attributes=None): |
32 | 35 | """Constructs an API Generator class that generates API Gateway resources
|
33 | 36 |
|
34 | 37 | :param logical_id: Logical id of the SAM API Resource
|
@@ -61,6 +64,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
|
61 | 64 | self.minimum_compression_size = minimum_compression_size
|
62 | 65 | self.cors = cors
|
63 | 66 | self.auth = auth
|
| 67 | + self.gateway_responses = gateway_responses |
64 | 68 | self.access_log_setting = access_log_setting
|
65 | 69 | self.canary_setting = canary_setting
|
66 | 70 | self.tracing_enabled = tracing_enabled
|
@@ -91,6 +95,7 @@ def _construct_rest_api(self):
|
91 | 95 |
|
92 | 96 | self._add_cors()
|
93 | 97 | self._add_auth()
|
| 98 | + self._add_gateway_responses() |
94 | 99 |
|
95 | 100 | if self.definition_uri:
|
96 | 101 | rest_api.BodyS3Location = self._construct_body_s3_dict()
|
@@ -275,6 +280,49 @@ def _add_auth(self):
|
275 | 280 | # Assign the Swagger back to template
|
276 | 281 | self.definition_body = swagger_editor.swagger
|
277 | 282 |
|
| 283 | + def _add_gateway_responses(self): |
| 284 | + """ |
| 285 | + Add Gateway Response configuration to the Swagger file, if necessary |
| 286 | + """ |
| 287 | + |
| 288 | + if not self.gateway_responses: |
| 289 | + return |
| 290 | + |
| 291 | + if self.gateway_responses and not self.definition_body: |
| 292 | + raise InvalidResourceException( |
| 293 | + self.logical_id, "GatewayResponses works only with inline Swagger specified in " |
| 294 | + "'DefinitionBody' property") |
| 295 | + |
| 296 | + # Make sure keys in the dict are recognized |
| 297 | + for responses_key, responses_value in self.gateway_responses.items(): |
| 298 | + for response_key in responses_value.keys(): |
| 299 | + if response_key not in GatewayResponseProperties: |
| 300 | + raise InvalidResourceException( |
| 301 | + self.logical_id, |
| 302 | + "Invalid property '{}' in 'GatewayResponses' property '{}'".format(response_key, responses_key)) |
| 303 | + |
| 304 | + if not SwaggerEditor.is_valid(self.definition_body): |
| 305 | + raise InvalidResourceException( |
| 306 | + self.logical_id, "Unable to add Auth configuration because " |
| 307 | + "'DefinitionBody' does not contain a valid Swagger") |
| 308 | + |
| 309 | + swagger_editor = SwaggerEditor(self.definition_body) |
| 310 | + |
| 311 | + gateway_responses = {} |
| 312 | + for response_type, response in self.gateway_responses.items(): |
| 313 | + gateway_responses[response_type] = ApiGatewayResponse( |
| 314 | + api_logical_id=self.logical_id, |
| 315 | + response_parameters=response.get('ResponseParameters', {}), |
| 316 | + response_templates=response.get('ResponseTemplates', {}), |
| 317 | + status_code=response.get('StatusCode', None) |
| 318 | + ) |
| 319 | + |
| 320 | + if gateway_responses: |
| 321 | + swagger_editor.add_gateway_responses(gateway_responses) |
| 322 | + |
| 323 | + # Assign the Swagger back to template |
| 324 | + self.definition_body = swagger_editor.swagger |
| 325 | + |
278 | 326 | def _get_authorizers(self, authorizers_config, default_authorizer=None):
|
279 | 327 | authorizers = {}
|
280 | 328 | if default_authorizer == 'AWS_IAM':
|
|
0 commit comments