3
3
from samtranslator .model .intrinsics import ref
4
4
from samtranslator .model .apigateway import (ApiGatewayDeployment , ApiGatewayRestApi ,
5
5
ApiGatewayStage , ApiGatewayAuthorizer ,
6
- ApiGatewayResponse )
6
+ ApiGatewayResponse , ApiGatewayDomainName ,
7
+ ApiGatewayBasePathMapping )
7
8
from samtranslator .model .exceptions import InvalidResourceException
8
9
from samtranslator .model .s3_utils .uri_parser import parse_s3_uri
9
10
from samtranslator .region_configuration import RegionConfiguration
@@ -35,7 +36,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
35
36
method_settings = None , binary_media = None , minimum_compression_size = None , cors = None ,
36
37
auth = None , gateway_responses = None , access_log_setting = None , canary_setting = None ,
37
38
tracing_enabled = None , resource_attributes = None , passthrough_resource_attributes = None ,
38
- open_api_version = None , models = None ):
39
+ open_api_version = None , models = None , domain = None ):
39
40
"""Constructs an API Generator class that generates API Gateway resources
40
41
41
42
:param logical_id: Logical id of the SAM API Resource
@@ -80,6 +81,7 @@ def __init__(self, logical_id, cache_cluster_enabled, cache_cluster_size, variab
80
81
self .open_api_version = open_api_version
81
82
self .remove_extra_stage = open_api_version
82
83
self .models = models
84
+ self .domain = domain
83
85
84
86
def _construct_rest_api (self ):
85
87
"""Constructs and returns the ApiGateway RestApi.
@@ -204,20 +206,86 @@ def _construct_stage(self, deployment, swagger):
204
206
stage .TracingEnabled = self .tracing_enabled
205
207
206
208
if swagger is not None :
207
- deployment .make_auto_deployable (stage , self .remove_extra_stage , swagger )
209
+ deployment .make_auto_deployable (stage , self .remove_extra_stage , swagger , self . domain )
208
210
209
211
if self .tags is not None :
210
212
stage .Tags = get_tag_list (self .tags )
211
213
212
214
return stage
213
215
216
+ def _construct_api_domain (self , rest_api ):
217
+ """
218
+ Constructs and returns the ApiGateway Domain and BasepathMapping
219
+ """
220
+ if self .domain is None :
221
+ return None , None
222
+
223
+ if self .domain .get ('DomainName' ) is None or \
224
+ self .domain .get ('CertificateArn' ) is None :
225
+ raise InvalidResourceException (self .logical_id ,
226
+ "Custom Domains only works if both DomainName and CertificateArn"
227
+ " are provided" )
228
+
229
+ logical_id = logical_id_generator .LogicalIdGenerator ("" , self .domain ).gen ()
230
+
231
+ domain = ApiGatewayDomainName ('ApiGatewayDomainName' + logical_id ,
232
+ attributes = self .passthrough_resource_attributes )
233
+ domain .DomainName = self .domain .get ('DomainName' )
234
+ endpoint = self .domain .get ('EndpointConfiguration' )
235
+
236
+ if endpoint is None :
237
+ endpoint = 'REGIONAL'
238
+ elif endpoint not in ['EDGE' , 'REGIONAL' ]:
239
+ raise InvalidResourceException (self .logical_id ,
240
+ "EndpointConfiguration for Custom Domains must be"
241
+ " one of {}" .format (['EDGE' , 'REGIONAL' ]))
242
+
243
+ if endpoint == 'REGIONAL' :
244
+ domain .RegionalCertificateArn = self .domain .get ('CertificateArn' )
245
+ else :
246
+ domain .CertificateArn = self .domain .get ('CertificateArn' )
247
+
248
+ domain .EndpointConfiguration = {"Types" : [endpoint ]}
249
+
250
+ # Create BasepathMappings
251
+ if self .domain .get ('BasePath' ) and isinstance (self .domain .get ('BasePath' ), string_types ):
252
+ basepaths = [self .domain .get ('BasePath' )]
253
+ elif self .domain .get ('BasePath' ) and isinstance (self .domain .get ('BasePath' ), list ):
254
+ basepaths = self .domain .get ('BasePath' )
255
+ else :
256
+ basepaths = None
257
+
258
+ basepath_resource_list = []
259
+
260
+ if basepaths is None :
261
+ basepath_mapping = ApiGatewayBasePathMapping (self .logical_id + 'BasePathMapping' ,
262
+ attributes = self .passthrough_resource_attributes )
263
+ basepath_mapping .DomainName = self .domain .get ('DomainName' )
264
+ basepath_mapping .RestApiId = ref (rest_api .logical_id )
265
+ basepath_mapping .Stage = ref (rest_api .logical_id + '.Stage' )
266
+ basepath_resource_list .extend ([basepath_mapping ])
267
+ else :
268
+ for path in basepaths :
269
+ path = '' .join (e for e in path if e .isalnum ())
270
+ logical_id = "{}{}{}" .format (self .logical_id , path , 'BasePathMapping' )
271
+ basepath_mapping = ApiGatewayBasePathMapping (logical_id ,
272
+ attributes = self .passthrough_resource_attributes )
273
+ basepath_mapping .DomainName = self .domain .get ('DomainName' )
274
+ basepath_mapping .RestApiId = ref (rest_api .logical_id )
275
+ basepath_mapping .Stage = ref (rest_api .logical_id + '.Stage' )
276
+ basepath_mapping .BasePath = path
277
+ basepath_resource_list .extend ([basepath_mapping ])
278
+
279
+ return domain , basepath_resource_list
280
+
214
281
def to_cloudformation (self ):
215
282
"""Generates CloudFormation resources from a SAM API resource
216
283
217
284
:returns: a tuple containing the RestApi, Deployment, and Stage for an empty Api.
218
285
:rtype: tuple
219
286
"""
220
287
rest_api = self ._construct_rest_api ()
288
+ domain , basepath_mapping = self ._construct_api_domain (rest_api )
221
289
deployment = self ._construct_deployment (rest_api )
222
290
223
291
swagger = None
@@ -229,7 +297,7 @@ def to_cloudformation(self):
229
297
stage = self ._construct_stage (deployment , swagger )
230
298
permissions = self ._construct_authorizer_lambda_permission ()
231
299
232
- return rest_api , deployment , stage , permissions
300
+ return rest_api , deployment , stage , permissions , domain , basepath_mapping
233
301
234
302
def _add_cors (self ):
235
303
"""
0 commit comments