@@ -33,7 +33,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
33
33
method_settings = None , binary_media = None , minimum_compression_size = None , cors = None ,
34
34
auth = None , gateway_responses = None , access_log_setting = None , canary_setting = None ,
35
35
tracing_enabled = None , resource_attributes = None , passthrough_resource_attributes = None ,
36
- open_api_version = None ):
36
+ open_api_version = None , models = None ):
37
37
"""Constructs an API Generator class that generates API Gateway resources
38
38
39
39
:param logical_id: Logical id of the SAM API Resource
@@ -50,6 +50,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
50
50
:param tracing_enabled: Whether active tracing with X-ray is enabled
51
51
:param resource_attributes: Resource attributes to add to API resources
52
52
:param passthrough_resource_attributes: Attributes such as `Condition` that are added to derived resources
53
+ :param models: Model definitions to be used by API methods
53
54
"""
54
55
self .logical_id = logical_id
55
56
self .cache_cluster_enabled = cache_cluster_enabled
@@ -73,6 +74,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
73
74
self .resource_attributes = resource_attributes
74
75
self .passthrough_resource_attributes = passthrough_resource_attributes
75
76
self .open_api_version = open_api_version
77
+ self .models = models
76
78
77
79
def _construct_rest_api (self ):
78
80
"""Constructs and returns the ApiGateway RestApi.
@@ -107,6 +109,7 @@ def _construct_rest_api(self):
107
109
self ._add_auth ()
108
110
self ._add_gateway_responses ()
109
111
self ._add_binary_media_types ()
112
+ self ._add_models ()
110
113
111
114
if self .definition_uri :
112
115
rest_api .BodyS3Location = self ._construct_body_s3_dict ()
@@ -327,8 +330,9 @@ def _openapi_auth_postprocess(self, definition_body):
327
330
328
331
if self .open_api_version and re .match (SwaggerEditor .get_openapi_version_3_regex (), self .open_api_version ):
329
332
if definition_body .get ('securityDefinitions' ):
330
- definition_body ['components' ] = {}
331
- definition_body ['components' ]['securitySchemes' ] = definition_body ['securityDefinitions' ]
333
+ components = definition_body .get ('components' , {})
334
+ components ['securitySchemes' ] = definition_body ['securityDefinitions' ]
335
+ definition_body ['components' ] = components
332
336
del definition_body ['securityDefinitions' ]
333
337
return definition_body
334
338
@@ -375,6 +379,56 @@ def _add_gateway_responses(self):
375
379
# Assign the Swagger back to template
376
380
self .definition_body = swagger_editor .swagger
377
381
382
+ def _add_models (self ):
383
+ """
384
+ Add Model definitions to the Swagger file, if necessary
385
+ :return:
386
+ """
387
+
388
+ if not self .models :
389
+ return
390
+
391
+ if self .models and not self .definition_body :
392
+ raise InvalidResourceException (self .logical_id ,
393
+ "Models works only with inline Swagger specified in "
394
+ "'DefinitionBody' property" )
395
+
396
+ if not SwaggerEditor .is_valid (self .definition_body ):
397
+ raise InvalidResourceException (self .logical_id , "Unable to add Models definitions because "
398
+ "'DefinitionBody' does not contain a valid Swagger" )
399
+
400
+ if not all (isinstance (model , dict ) for model in self .models .values ()):
401
+ raise InvalidResourceException (self .logical_id , "Invalid value for 'Models' property" )
402
+
403
+ swagger_editor = SwaggerEditor (self .definition_body )
404
+ swagger_editor .add_models (self .models )
405
+
406
+ # Assign the Swagger back to template
407
+
408
+ self .definition_body = self ._openapi_models_postprocess (swagger_editor .swagger )
409
+
410
+ def _openapi_models_postprocess (self , definition_body ):
411
+ """
412
+ Convert definitions to openapi 3 in definition body if OpenApiVersion flag is specified.
413
+
414
+ If the is swagger defined in the definition body, we treat it as a swagger spec and dod not
415
+ make any openapi 3 changes to it
416
+ """
417
+ if definition_body .get ('swagger' ) is not None :
418
+ return definition_body
419
+
420
+ if definition_body .get ('openapi' ) is not None :
421
+ if self .open_api_version is None :
422
+ self .open_api_version = definition_body .get ('openapi' )
423
+
424
+ if self .open_api_version and re .match (SwaggerEditor .get_openapi_version_3_regex (), self .open_api_version ):
425
+ if definition_body .get ('definitions' ):
426
+ components = definition_body .get ('components' , {})
427
+ components ['schemas' ] = definition_body ['definitions' ]
428
+ definition_body ['components' ] = components
429
+ del definition_body ['definitions' ]
430
+ return definition_body
431
+
378
432
def _get_authorizers (self , authorizers_config , default_authorizer = None ):
379
433
authorizers = {}
380
434
if default_authorizer == 'AWS_IAM' :
0 commit comments