Skip to content

Commit 16aefe7

Browse files
authored
feat: Allow Boto3 Session to be configured (#1686)
Co-authored-by: Jacob Fuss <[email protected]>
1 parent afcb3d3 commit 16aefe7

File tree

5 files changed

+80
-16
lines changed

5 files changed

+80
-16
lines changed

samtranslator/sdk/parameter.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import boto3
22
import copy
33

4+
from samtranslator.translator.arn_generator import ArnGenerator, NoRegionFound
5+
46

57
class SamParameterValues(object):
68
"""
@@ -58,21 +60,20 @@ def add_default_parameter_values(self, sam_template):
5860
if param_name not in self.parameter_values and isinstance(value, dict) and "Default" in value:
5961
self.parameter_values[param_name] = value["Default"]
6062

61-
def add_pseudo_parameter_values(self):
63+
def add_pseudo_parameter_values(self, session=None):
6264
"""
6365
Add pseudo parameter values
6466
:return: parameter values that have pseudo parameter in it
6567
"""
68+
69+
if session is None:
70+
session = boto3.session.Session()
71+
72+
if not session.region_name:
73+
raise NoRegionFound("AWS Region cannot be found")
74+
6675
if "AWS::Region" not in self.parameter_values:
67-
self.parameter_values["AWS::Region"] = boto3.session.Session().region_name
76+
self.parameter_values["AWS::Region"] = session.region_name
6877

6978
if "AWS::Partition" not in self.parameter_values:
70-
region = boto3.session.Session().region_name
71-
72-
# neither boto nor botocore has any way of returning the partition value yet
73-
if region.startswith("cn-"):
74-
self.parameter_values["AWS::Partition"] = "aws-cn"
75-
elif region.startswith("us-gov-"):
76-
self.parameter_values["AWS::Partition"] = "aws-us-gov"
77-
else:
78-
self.parameter_values["AWS::Partition"] = "aws"
79+
self.parameter_values["AWS::Partition"] = ArnGenerator.get_partition_name(session.region_name)

samtranslator/translator/arn_generator.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import boto3
22

33

4+
class NoRegionFound(Exception):
5+
pass
6+
7+
48
class ArnGenerator(object):
9+
class_boto_session = None
10+
511
@classmethod
612
def generate_arn(cls, partition, service, resource, include_account_id=True):
713
if not service or not resource:
@@ -43,7 +49,17 @@ def get_partition_name(cls, region=None):
4349
if region is None:
4450
# Use Boto3 to get the region where code is running. This uses Boto's regular region resolution
4551
# mechanism, starting from AWS_DEFAULT_REGION environment variable.
46-
region = boto3.session.Session().region_name
52+
53+
if ArnGenerator.class_boto_session is None:
54+
region = boto3.session.Session().region_name
55+
else:
56+
region = ArnGenerator.class_boto_session.region_name
57+
58+
# If region is still None, then we could not find the region. This will only happen
59+
# in the local context. When this is deployed, we will be able to find the region like
60+
# we did before.
61+
if region is None:
62+
raise NoRegionFound("AWS Region cannot be found")
4763

4864
# setting default partition to aws, this will be overwritten by checking the region below
4965
partition = "aws"

samtranslator/translator/translator.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525
from samtranslator.plugins.policies.policy_templates_plugin import PolicyTemplatesForResourcePlugin
2626
from samtranslator.policy_template_processor.processor import PolicyTemplatesProcessor
2727
from samtranslator.sdk.parameter import SamParameterValues
28+
from samtranslator.translator.arn_generator import ArnGenerator
2829

2930

3031
class Translator:
3132
"""Translates SAM templates into CloudFormation templates"""
3233

33-
def __init__(self, managed_policy_map, sam_parser, plugins=None):
34+
def __init__(self, managed_policy_map, sam_parser, plugins=None, boto_session=None):
3435
"""
3536
:param dict managed_policy_map: Map of managed policy names to the ARNs
3637
:param sam_parser: Instance of a SAM Parser
@@ -41,6 +42,9 @@ def __init__(self, managed_policy_map, sam_parser, plugins=None):
4142
self.plugins = plugins
4243
self.sam_parser = sam_parser
4344
self.feature_toggle = None
45+
self.boto_session = boto_session
46+
47+
ArnGenerator.class_boto_session = self.boto_session
4448

4549
def _get_function_names(self, resource_dict, intrinsics_resolver):
4650
"""
@@ -92,7 +96,7 @@ def translate(self, sam_template, parameter_values, feature_toggle=None):
9296
self.redeploy_restapi_parameters = dict()
9397
sam_parameter_values = SamParameterValues(parameter_values)
9498
sam_parameter_values.add_default_parameter_values(sam_template)
95-
sam_parameter_values.add_pseudo_parameter_values()
99+
sam_parameter_values.add_pseudo_parameter_values(self.boto_session)
96100
parameter_values = sam_parameter_values.parameter_values
97101
# Create & Install plugins
98102
sam_plugins = prepare_plugins(self.plugins, parameter_values)

tests/sdk/test_parameter.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from parameterized import parameterized, param
22

3-
import pytest
43
from unittest import TestCase
54
from samtranslator.sdk.parameter import SamParameterValues
6-
from mock import patch
5+
from mock import patch, Mock
6+
7+
from samtranslator.translator.arn_generator import NoRegionFound
78

89

910
class TestSAMParameterValues(TestCase):
@@ -101,3 +102,10 @@ def test_add_pseudo_parameter_values_aws_partition_not_override(self):
101102
sam_parameter_values = SamParameterValues(parameter_values)
102103
sam_parameter_values.add_pseudo_parameter_values()
103104
self.assertEqual(expected, sam_parameter_values.parameter_values)
105+
106+
def test_add_pseudo_parameter_values_raises_NoRegionFound(self):
107+
boto_session_mock = Mock()
108+
boto_session_mock.region_name = None
109+
sam_parameter_values = SamParameterValues({})
110+
with self.assertRaises(NoRegionFound):
111+
sam_parameter_values.add_pseudo_parameter_values(session=boto_session_mock)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from unittest import TestCase
2+
from parameterized import parameterized
3+
from mock import Mock, patch
4+
5+
from samtranslator.translator.arn_generator import ArnGenerator, NoRegionFound
6+
7+
8+
class TestArnGenerator(TestCase):
9+
def setUp(self):
10+
ArnGenerator.class_boto_session = None
11+
12+
@parameterized.expand(
13+
[("us-east-1", "aws"), ("cn-east-1", "aws-cn"), ("us-gov-west-1", "aws-us-gov"), ("US-EAST-1", "aws")]
14+
)
15+
def test_get_partition_name(self, region, expected):
16+
actual = ArnGenerator.get_partition_name(region)
17+
18+
self.assertEqual(actual, expected)
19+
20+
@patch("boto3.session.Session.region_name", None)
21+
def test_get_partition_name_raise_NoRegionFound(self):
22+
with self.assertRaises(NoRegionFound):
23+
ArnGenerator.get_partition_name(None)
24+
25+
def test_get_partition_name_from_boto_session(self):
26+
boto_session_mock = Mock()
27+
boto_session_mock.region_name = "us-east-1"
28+
29+
ArnGenerator.class_boto_session = boto_session_mock
30+
31+
actual = ArnGenerator.get_partition_name()
32+
33+
self.assertEqual(actual, "aws")
34+
35+
ArnGenerator.class_boto_session = None

0 commit comments

Comments
 (0)