diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 1122bea4c03..bc089347c8c 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -15,12 +15,14 @@ from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.data_classes.sns_event import SNSEventRecord logger = logging.getLogger(__name__) class EventType(Enum): SQS = "SQS" + SNS = "SNS" KinesisDataStreams = "KinesisDataStreams" DynamoDBStreams = "DynamoDBStreams" @@ -330,11 +332,13 @@ def __init__(self, event_type: EventType, model: Optional["BatchTypeModels"] = N self.batch_response = copy.deepcopy(self.DEFAULT_RESPONSE) self._COLLECTOR_MAPPING = { EventType.SQS: self._collect_sqs_failures, + EventType.SNS: self._collect_sns_failures, EventType.KinesisDataStreams: self._collect_kinesis_failures, EventType.DynamoDBStreams: self._collect_dynamodb_failures, } self._DATA_CLASS_MAPPING = { EventType.SQS: SQSRecord, + EventType.SNS: SNSEventRecord, EventType.KinesisDataStreams: KinesisStreamRecord, EventType.DynamoDBStreams: DynamoDBRecord, } @@ -413,6 +417,13 @@ def _collect_sqs_failures(self): failures.append({"itemIdentifier": msg_id}) return failures + def _collect_sns_failures(self): + failures = [] + for msg in self.fail_messages: + msg_id = msg.sns.MessageId if self.model else msg.sns.message_id + failures.append({"itemIdentifier": msg_id}) + return failures + def _collect_kinesis_failures(self): failures = [] for msg in self.fail_messages: diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index a5e1e706437..56c4307ca8a 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -50,6 +50,34 @@ def factory(body: str): return factory +@pytest.fixture(scope="module") +def sns_event_factory() -> Callable: + def factory(body: str): + return { + "EventVersion": "1.0", + "EventSubscriptionArn": "arn:aws:sns:us-east-2:123456789012:sns-la ...", + "EventSource": "aws:sns", + "Sns": { + "SignatureVersion": "1", + "Timestamp": "2019-01-02T12:45:07.000Z", + "Signature": "tcc6faL2yUC6dgZdmrwh1Y4cGa/ebXEkAi6RibDsvpi+tE/1+82j...65r==", + "SigningCertUrl": "https://sns.us-east-2.amazonaws.com/SimpleNotification", + "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", + "Message": "Hello from SNS!", + "MessageAttributes": { + "Test": {"Type": "String", "Value": "TestString"}, + "TestBinary": {"Type": "Binary", "Value": "TestBinary"}, + }, + "Type": "Notification", + "UnsubscribeUrl": "https://sns.us-east-2.amazonaws.com/?Action=Unsubscribe", + "TopicArn": "arn:aws:sns:us-east-2:123456789012:sns-lambda", + "Subject": "TestInvoke", + }, + } + + return factory + + @pytest.fixture(scope="module") def kinesis_event_factory() -> Callable: def factory(body: str):