Skip to content

Commit 1bb186c

Browse files
beck3905keetonian
authored andcommitted
feat: add API Request Model support (#948)
1 parent a7291d4 commit 1bb186c

19 files changed

+1651
-7
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
exports.handler = function(event, context, callback) {
2+
callback(null, {
3+
"statusCode": 200,
4+
"body": "hello world"
5+
});
6+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
AWSTemplateFormatVersion: '2010-09-09'
2+
Transform: AWS::Serverless-2016-10-31
3+
Description: Simple API Endpoint configured using Swagger specified inline and backed by a Lambda function
4+
Resources:
5+
MyApi:
6+
Type: AWS::Serverless::Api
7+
Properties:
8+
StageName: prod
9+
Models:
10+
User:
11+
type: object
12+
required:
13+
- grant_type
14+
- username
15+
- password
16+
properties:
17+
grant_type:
18+
type: string
19+
username:
20+
type: string
21+
password:
22+
type: string
23+
24+
MyLambdaFunction:
25+
Type: AWS::Serverless::Function
26+
Properties:
27+
Handler: index.handler
28+
Runtime: nodejs6.10
29+
CodeUri: src/
30+
Events:
31+
GetApi:
32+
Type: Api
33+
Properties:
34+
Path: /post
35+
Method: POST
36+
RestApiId:
37+
Ref: MyApi
38+
RequestModel:
39+
Model: User
40+
Required: true
41+
42+
Outputs:
43+
44+
ApiURL:
45+
Description: "API endpoint URL for Prod environment"
46+
Value: !Sub "https://${MyApi}.execute-api.${AWS::Region}.amazonaws.com/prod/get"

samtranslator/model/api/api_generator.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
3333
method_settings=None, binary_media=None, minimum_compression_size=None, cors=None,
3434
auth=None, gateway_responses=None, access_log_setting=None, canary_setting=None,
3535
tracing_enabled=None, resource_attributes=None, passthrough_resource_attributes=None,
36-
open_api_version=None):
36+
open_api_version=None, models=None):
3737
"""Constructs an API Generator class that generates API Gateway resources
3838
3939
:param logical_id: Logical id of the SAM API Resource
@@ -50,6 +50,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
5050
:param tracing_enabled: Whether active tracing with X-ray is enabled
5151
:param resource_attributes: Resource attributes to add to API resources
5252
:param passthrough_resource_attributes: Attributes such as `Condition` that are added to derived resources
53+
:param models: Model definitions to be used by API methods
5354
"""
5455
self.logical_id = logical_id
5556
self.cache_cluster_enabled = cache_cluster_enabled
@@ -73,6 +74,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
7374
self.resource_attributes = resource_attributes
7475
self.passthrough_resource_attributes = passthrough_resource_attributes
7576
self.open_api_version = open_api_version
77+
self.models = models
7678

7779
def _construct_rest_api(self):
7880
"""Constructs and returns the ApiGateway RestApi.
@@ -107,6 +109,7 @@ def _construct_rest_api(self):
107109
self._add_auth()
108110
self._add_gateway_responses()
109111
self._add_binary_media_types()
112+
self._add_models()
110113

111114
if self.definition_uri:
112115
rest_api.BodyS3Location = self._construct_body_s3_dict()
@@ -327,8 +330,9 @@ def _openapi_auth_postprocess(self, definition_body):
327330

328331
if self.open_api_version and re.match(SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version):
329332
if definition_body.get('securityDefinitions'):
330-
definition_body['components'] = {}
331-
definition_body['components']['securitySchemes'] = definition_body['securityDefinitions']
333+
components = definition_body.get('components', {})
334+
components['securitySchemes'] = definition_body['securityDefinitions']
335+
definition_body['components'] = components
332336
del definition_body['securityDefinitions']
333337
return definition_body
334338

@@ -375,6 +379,56 @@ def _add_gateway_responses(self):
375379
# Assign the Swagger back to template
376380
self.definition_body = swagger_editor.swagger
377381

382+
def _add_models(self):
383+
"""
384+
Add Model definitions to the Swagger file, if necessary
385+
:return:
386+
"""
387+
388+
if not self.models:
389+
return
390+
391+
if self.models and not self.definition_body:
392+
raise InvalidResourceException(self.logical_id,
393+
"Models works only with inline Swagger specified in "
394+
"'DefinitionBody' property")
395+
396+
if not SwaggerEditor.is_valid(self.definition_body):
397+
raise InvalidResourceException(self.logical_id, "Unable to add Models definitions because "
398+
"'DefinitionBody' does not contain a valid Swagger")
399+
400+
if not all(isinstance(model, dict) for model in self.models.values()):
401+
raise InvalidResourceException(self.logical_id, "Invalid value for 'Models' property")
402+
403+
swagger_editor = SwaggerEditor(self.definition_body)
404+
swagger_editor.add_models(self.models)
405+
406+
# Assign the Swagger back to template
407+
408+
self.definition_body = self._openapi_models_postprocess(swagger_editor.swagger)
409+
410+
def _openapi_models_postprocess(self, definition_body):
411+
"""
412+
Convert definitions to openapi 3 in definition body if OpenApiVersion flag is specified.
413+
414+
If the is swagger defined in the definition body, we treat it as a swagger spec and dod not
415+
make any openapi 3 changes to it
416+
"""
417+
if definition_body.get('swagger') is not None:
418+
return definition_body
419+
420+
if definition_body.get('openapi') is not None:
421+
if self.open_api_version is None:
422+
self.open_api_version = definition_body.get('openapi')
423+
424+
if self.open_api_version and re.match(SwaggerEditor.get_openapi_version_3_regex(), self.open_api_version):
425+
if definition_body.get('definitions'):
426+
components = definition_body.get('components', {})
427+
components['schemas'] = definition_body['definitions']
428+
definition_body['components'] = components
429+
del definition_body['definitions']
430+
return definition_body
431+
378432
def _get_authorizers(self, authorizers_config, default_authorizer=None):
379433
authorizers = {}
380434
if default_authorizer == 'AWS_IAM':

samtranslator/model/eventsources/push.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ class Api(PushEventSource):
386386

387387
# Api Event sources must "always" be paired with a Serverless::Api
388388
'RestApiId': PropertyType(True, is_str()),
389-
'Auth': PropertyType(False, is_type(dict))
389+
'Auth': PropertyType(False, is_type(dict)),
390+
'RequestModel': PropertyType(False, is_type(dict))
390391
}
391392

392393
def resources_to_link(self, resources):
@@ -564,6 +565,28 @@ def _add_swagger_integration(self, api, function):
564565

565566
editor.add_auth_to_method(api=api, path=self.Path, method_name=self.Method, auth=self.Auth)
566567

568+
if self.RequestModel:
569+
method_model = self.RequestModel.get('Model')
570+
571+
if method_model:
572+
api_models = api.get('Models')
573+
if not api_models:
574+
raise InvalidEventException(
575+
self.relative_id,
576+
'Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] '
577+
'because the related API does not define any Models.'.format(
578+
model=method_model, method=self.Method, path=self.Path))
579+
580+
if not api_models.get(method_model):
581+
raise InvalidEventException(
582+
self.relative_id,
583+
'Unable to set RequestModel [{model}] on API method [{method}] for path [{path}] '
584+
'because it wasn\'t defined in the API\'s Models.'.format(
585+
model=method_model, method=self.Method, path=self.Path))
586+
587+
editor.add_request_model_to_method(path=self.Path, method_name=self.Method,
588+
request_model=self.RequestModel)
589+
567590
api["DefinitionBody"] = editor.swagger
568591

569592

samtranslator/model/sam_resources.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,8 @@ class SamApi(SamResourceMacro):
446446
'AccessLogSetting': PropertyType(False, is_type(dict)),
447447
'CanarySetting': PropertyType(False, is_type(dict)),
448448
'TracingEnabled': PropertyType(False, is_type(bool)),
449-
'OpenApiVersion': PropertyType(False, is_str())
449+
'OpenApiVersion': PropertyType(False, is_str()),
450+
'Models': PropertyType(False, is_type(dict))
450451
}
451452

452453
referable_properties = {
@@ -485,7 +486,8 @@ def to_cloudformation(self, **kwargs):
485486
tracing_enabled=self.TracingEnabled,
486487
resource_attributes=self.resource_attributes,
487488
passthrough_resource_attributes=self.get_passthrough_resource_attributes(),
488-
open_api_version=self.OpenApiVersion)
489+
open_api_version=self.OpenApiVersion,
490+
models=self.Models)
489491

490492
rest_api, deployment, stage, permissions = api_generator.to_cloudformation()
491493

samtranslator/swagger/swagger.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(self, doc):
3838
self.paths = self._doc["paths"]
3939
self.security_definitions = self._doc.get("securityDefinitions", {})
4040
self.gateway_responses = self._doc.get(self._X_APIGW_GATEWAY_RESPONSES, {})
41+
self.definitions = self._doc.get('definitions', {})
4142

4243
def get_path(self, path):
4344
path_dict = self.paths.get(path)
@@ -522,6 +523,61 @@ def set_method_authorizer(self, path, method_name, authorizer_name, authorizers,
522523
elif 'AWS_IAM' not in self.security_definitions:
523524
self.security_definitions.update(aws_iam_security_definition)
524525

526+
def add_request_model_to_method(self, path, method_name, request_model):
527+
"""
528+
Adds request model body parameter for this path/method.
529+
530+
:param string path: Path name
531+
:param string method_name: Method name
532+
:param dict request_model: Model name
533+
"""
534+
model_name = request_model and request_model.get('Model').lower()
535+
model_required = request_model and request_model.get('Required')
536+
537+
normalized_method_name = self._normalize_method_name(method_name)
538+
# It is possible that the method could have two definitions in a Fn::If block.
539+
for method_definition in self.get_method_contents(self.get_path(path)[normalized_method_name]):
540+
541+
# If no integration given, then we don't need to process this definition (could be AWS::NoValue)
542+
if not self.method_definition_has_integration(method_definition):
543+
continue
544+
545+
if self._doc.get('swagger') is not None:
546+
547+
existing_parameters = method_definition.get('parameters', [])
548+
549+
parameter = {
550+
'in': 'body',
551+
'name': model_name,
552+
'schema': {
553+
'$ref': '#/definitions/{}'.format(model_name)
554+
}
555+
}
556+
557+
if model_required is not None:
558+
parameter['required'] = model_required
559+
560+
existing_parameters.append(parameter)
561+
562+
method_definition['parameters'] = existing_parameters
563+
564+
elif self._doc.get("openapi") and \
565+
re.search(SwaggerEditor.get_openapi_version_3_regex(), self._doc["openapi"]) is not None:
566+
567+
method_definition['requestBody'] = {
568+
'content': {
569+
"application/json": {
570+
"schema": {
571+
"$ref": "#/components/schemas/{}".format(model_name)
572+
}
573+
}
574+
575+
}
576+
}
577+
578+
if model_required is not None:
579+
method_definition['requestBody']['required'] = model_required
580+
525581
def add_gateway_responses(self, gateway_responses):
526582
"""
527583
Add Gateway Response definitions to Swagger.
@@ -533,6 +589,29 @@ def add_gateway_responses(self, gateway_responses):
533589
for response_type, response in gateway_responses.items():
534590
self.gateway_responses[response_type] = response.generate_swagger()
535591

592+
def add_models(self, models):
593+
"""
594+
Add Model definitions to Swagger.
595+
596+
:param dict models: Dictionary of Model schemas which gets translated
597+
:return:
598+
"""
599+
600+
self.definitions = self.definitions or {}
601+
602+
for model_name, schema in models.items():
603+
604+
model_type = schema.get('type')
605+
model_properties = schema.get('properties')
606+
607+
if not model_type:
608+
raise ValueError("Invalid input. Value for type is required")
609+
610+
if not model_properties:
611+
raise ValueError("Invalid input. Value for properties is required")
612+
613+
self.definitions[model_name.lower()] = schema
614+
536615
@property
537616
def swagger(self):
538617
"""
@@ -548,6 +627,8 @@ def swagger(self):
548627
self._doc["securityDefinitions"] = self.security_definitions
549628
if self.gateway_responses:
550629
self._doc[self._X_APIGW_GATEWAY_RESPONSES] = self.gateway_responses
630+
if self.definitions:
631+
self._doc['definitions'] = self.definitions
551632

552633
return copy.deepcopy(self._doc)
553634

0 commit comments

Comments
 (0)