Skip to content

feat(idempotency): Remove deadlock after lambda handler timeout #1198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion aws_lambda_powertools/utilities/idempotency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def _process_idempotency(self):
try:
# We call save_inprogress first as an optimization for the most common case where no idempotent record
# already exists. If it succeeds, there's no need to call get_record.
self.persistence_store.save_inprogress(data=self.data)
self.persistence_store.save_inprogress(
data=self.data,
remaining_time_in_seconds=self._get_remaining_time_in_seconds(),
)
except IdempotencyKeyError:
raise
except IdempotencyItemAlreadyExistsError:
Expand All @@ -113,6 +116,14 @@ def _process_idempotency(self):

return self._get_function_response()

def _get_remaining_time_in_seconds(self) -> Optional[int]:
"""
Try to get the time remaining in seconds from the lambda context
"""
if self.fn_args and len(self.fn_args) == 2 and getattr(self.fn_args[1], "get_remaining_time_in_millis", None):
return self.fn_args[1].get_remaining_time_in_millis() / 1000
return None

def _get_idempotency_record(self) -> DataRecord:
"""
Retrieve the idempotency record from the persistence layer.
Expand Down
4 changes: 4 additions & 0 deletions aws_lambda_powertools/utilities/idempotency/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def __init__(
jmespath_options: Optional[Dict] = None,
raise_on_no_idempotency_key: bool = False,
expires_after_seconds: int = 60 * 60, # 1 hour default
function_timeout_clean_up: bool = False,
use_local_cache: bool = False,
local_cache_max_items: int = 256,
hash_function: str = "md5",
Expand All @@ -26,6 +27,8 @@ def __init__(
Raise exception if no idempotency key was found in the request, by default False
expires_after_seconds: int
The number of seconds to wait before a record is expired
function_timeout_clean_up: bool
Whether to clean up "INPROGRESS" record after a function has timed out
use_local_cache: bool, optional
Whether to locally cache idempotency results, by default False
local_cache_max_items: int, optional
Expand All @@ -38,6 +41,7 @@ def __init__(
self.jmespath_options = jmespath_options
self.raise_on_no_idempotency_key = raise_on_no_idempotency_key
self.expires_after_seconds = expires_after_seconds
self.function_timeout_clean_up = function_timeout_clean_up
self.use_local_cache = use_local_cache
self.local_cache_max_items = local_cache_max_items
self.hash_function = hash_function
31 changes: 29 additions & 2 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
idempotency_key,
status: str = "",
expiry_timestamp: Optional[int] = None,
function_timeout: Optional[int] = None,
response_data: Optional[str] = "",
payload_hash: Optional[str] = None,
) -> None:
Expand All @@ -53,6 +54,8 @@ def __init__(
status of the idempotent record
expiry_timestamp: int, optional
unix timestamp at which the record should expire
function_timeout: int, optional
unix timestamp at which the function is expected to time out
payload_hash: str, optional
hashed representation of payload
response_data: str, optional
Expand All @@ -61,6 +64,7 @@ def __init__(
self.idempotency_key = idempotency_key
self.payload_hash = payload_hash
self.expiry_timestamp = expiry_timestamp
self.function_timeout = function_timeout
self._status = status
self.response_data = response_data

Expand Down Expand Up @@ -120,6 +124,7 @@ def __init__(self):
self.validation_key_jmespath = None
self.raise_on_no_idempotency_key = False
self.expires_after_seconds: int = 60 * 60 # 1 hour default
self.function_timeout_clean_up = False
self.use_local_cache = False
self.hash_function = None

Expand Down Expand Up @@ -152,6 +157,7 @@ def configure(self, config: IdempotencyConfig, function_name: Optional[str] = No
self.payload_validation_enabled = True
self.raise_on_no_idempotency_key = config.raise_on_no_idempotency_key
self.expires_after_seconds = config.expires_after_seconds
self.function_timeout_clean_up = config.function_timeout_clean_up
self.use_local_cache = config.use_local_cache
if self.use_local_cache:
self._cache = LRUDict(max_items=config.local_cache_max_items)
Expand Down Expand Up @@ -257,9 +263,21 @@ def _get_expiry_timestamp(self) -> int:
int
unix timestamp of expiry date for idempotency record

"""
return self._get_timestamp_after_seconds(self.expires_after_seconds)

@staticmethod
def _get_timestamp_after_seconds(seconds: int) -> int:
"""

Returns
-------
int
unix timestamp after the specified seconds

"""
now = datetime.datetime.now()
period = datetime.timedelta(seconds=self.expires_after_seconds)
period = datetime.timedelta(seconds=seconds)
return int((now + period).timestamp())

def _save_to_cache(self, data_record: DataRecord):
Expand Down Expand Up @@ -317,6 +335,7 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None:
idempotency_key=self._get_hashed_idempotency_key(data=data),
status=STATUS_CONSTANTS["COMPLETED"],
expiry_timestamp=self._get_expiry_timestamp(),
function_timeout=None,
response_data=response_data,
payload_hash=self._get_hashed_payload(data=data),
)
Expand All @@ -328,19 +347,27 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None:

self._save_to_cache(data_record=data_record)

def save_inprogress(self, data: Dict[str, Any]) -> None:
def save_inprogress(self, data: Dict[str, Any], remaining_time_in_seconds: Optional[int] = None) -> None:
"""
Save record of function's execution being in progress

Parameters
----------
data: Dict[str, Any]
Payload
remaining_time_in_seconds: int, optional
Function remaining time in seconds
"""
function_timeout = (
self._get_timestamp_after_seconds(remaining_time_in_seconds)
if remaining_time_in_seconds and self.function_timeout_clean_up
else None
)
data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(data=data),
status=STATUS_CONSTANTS["INPROGRESS"],
expiry_timestamp=self._get_expiry_timestamp(),
function_timeout=function_timeout,
payload_hash=self._get_hashed_payload(data=data),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
static_pk_value: Optional[str] = None,
sort_key_attr: Optional[str] = None,
expiry_attr: str = "expiration",
function_timeout_attr: str = "function_timeout",
status_attr: str = "status",
data_attr: str = "data",
validation_key_attr: str = "validation",
Expand Down Expand Up @@ -85,6 +86,7 @@ def __init__(
self.static_pk_value = static_pk_value
self.sort_key_attr = sort_key_attr
self.expiry_attr = expiry_attr
self.function_timeout_attr = function_timeout_attr
self.status_attr = status_attr
self.data_attr = data_attr
self.validation_key_attr = validation_key_attr
Expand Down Expand Up @@ -133,6 +135,7 @@ def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
idempotency_key=item[self.key_attr],
status=item[self.status_attr],
expiry_timestamp=item[self.expiry_attr],
function_timeout=item.get(self.function_timeout_attr),
response_data=item.get(self.data_attr),
payload_hash=item.get(self.validation_key_attr),
)
Expand All @@ -150,6 +153,7 @@ def _put_record(self, data_record: DataRecord) -> None:
item = {
**self._get_key(data_record.idempotency_key),
self.expiry_attr: data_record.expiry_timestamp,
self.function_timeout_attr: data_record.function_timeout,
self.status_attr: data_record.status,
}

Expand All @@ -161,8 +165,12 @@ def _put_record(self, data_record: DataRecord) -> None:
logger.debug(f"Putting record for idempotency key: {data_record.idempotency_key}")
self.table.put_item(
Item=item,
ConditionExpression="attribute_not_exists(#id) OR #now < :now",
ExpressionAttributeNames={"#id": self.key_attr, "#now": self.expiry_attr},
ConditionExpression="attribute_not_exists(#id) OR #now < :now OR #function_timeout < :now",
ExpressionAttributeNames={
"#id": self.key_attr,
"#now": self.expiry_attr,
"#function_timeout": self.function_timeout_attr,
},
ExpressionAttributeValues={":now": int(now.timestamp())},
)
except self.table.meta.client.exceptions.ConditionalCheckFailedException:
Expand All @@ -171,15 +179,20 @@ def _put_record(self, data_record: DataRecord) -> None:

def _update_record(self, data_record: DataRecord):
logger.debug(f"Updating record for idempotency key: {data_record.idempotency_key}")
update_expression = "SET #response_data = :response_data, #expiry = :expiry, #status = :status"
update_expression = (
"SET #response_data = :response_data, #expiry = :expiry, "
"#function_timeout = :function_timeout, #status = :status"
)
expression_attr_values = {
":expiry": data_record.expiry_timestamp,
":function_timeout": data_record.function_timeout,
":response_data": data_record.response_data,
":status": data_record.status,
}
expression_attr_names = {
"#response_data": self.data_attr,
"#expiry": self.expiry_attr,
"#function_timeout": self.function_timeout_attr,
"#status": self.status_attr,
}

Expand Down
51 changes: 35 additions & 16 deletions tests/functional/idempotency/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import json
from collections import namedtuple
from decimal import Decimal
from unittest import mock

Expand Down Expand Up @@ -32,14 +31,19 @@ def lambda_apigw_event():

@pytest.fixture
def lambda_context():
    """Provide a minimal stand-in for the Lambda context object handed to handlers under test."""

    class LambdaContext:
        def __init__(self):
            self.function_name = "test-func"
            self.memory_limit_in_mb = 128
            self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func"
            self.aws_request_id = "52fdfc07-2182-154f-163f-5f0f9a621d72"

        @staticmethod
        def get_remaining_time_in_millis() -> int:
            """Returns the number of milliseconds left before the execution times out."""
            # Fixed value so tests are deterministic
            return 3000

    return LambdaContext()


@pytest.fixture
Expand Down Expand Up @@ -77,15 +81,22 @@ def default_jmespath():
@pytest.fixture
def expected_params_update_item(serialized_lambda_response, hashed_idempotency_key):
    """Parameters the DynamoDB stubber should expect for the ``update_item`` call of a COMPLETED record."""
    return {
        "ExpressionAttributeNames": {
            "#expiry": "expiration",
            "#function_timeout": "function_timeout",
            "#response_data": "data",
            "#status": "status",
        },
        "ExpressionAttributeValues": {
            ":expiry": stub.ANY,
            # This fixture expects the function timeout to be None on the completed record
            ":function_timeout": None,
            ":response_data": serialized_lambda_response,
            ":status": "COMPLETED",
        },
        "Key": {"id": hashed_idempotency_key},
        "TableName": "TEST_TABLE",
        "UpdateExpression": "SET #response_data = :response_data, "
        "#expiry = :expiry, #function_timeout = :function_timeout, #status = :status",
    }


Expand All @@ -96,46 +107,54 @@ def expected_params_update_item_with_validation(
return {
"ExpressionAttributeNames": {
"#expiry": "expiration",
"#function_timeout": "function_timeout",
"#response_data": "data",
"#status": "status",
"#validation_key": "validation",
},
"ExpressionAttributeValues": {
":expiry": stub.ANY,
":function_timeout": None,
":response_data": serialized_lambda_response,
":status": "COMPLETED",
":validation_key": hashed_validation_key,
},
"Key": {"id": hashed_idempotency_key},
"TableName": "TEST_TABLE",
"UpdateExpression": "SET #response_data = :response_data, "
"#expiry = :expiry, #status = :status, "
"#expiry = :expiry, #function_timeout = :function_timeout, #status = :status, "
"#validation_key = :validation_key",
}


@pytest.fixture
def expected_params_put_item(hashed_idempotency_key):
    """Parameters the DynamoDB stubber should expect for the ``put_item`` call of an INPROGRESS record."""
    return {
        "ConditionExpression": "attribute_not_exists(#id) OR #now < :now OR #function_timeout < :now",
        "ExpressionAttributeNames": {"#id": "id", "#now": "expiration", "#function_timeout": "function_timeout"},
        "ExpressionAttributeValues": {":now": stub.ANY},
        "Item": {
            "expiration": stub.ANY,
            "id": hashed_idempotency_key,
            "status": "INPROGRESS",
            # This fixture expects no function timeout on the stored item
            "function_timeout": None,
        },
        "TableName": "TEST_TABLE",
    }


@pytest.fixture
def expected_params_put_item_with_validation(hashed_idempotency_key, hashed_validation_key):
    """Parameters the DynamoDB stubber should expect for ``put_item`` when a validation hash is stored too."""
    return {
        "ConditionExpression": "attribute_not_exists(#id) OR #now < :now OR #function_timeout < :now",
        "ExpressionAttributeNames": {"#id": "id", "#now": "expiration", "#function_timeout": "function_timeout"},
        "ExpressionAttributeValues": {":now": stub.ANY},
        "Item": {
            "expiration": stub.ANY,
            "id": hashed_idempotency_key,
            "status": "INPROGRESS",
            "validation": hashed_validation_key,
            # This fixture expects no function timeout on the stored item
            "function_timeout": None,
        },
        "TableName": "TEST_TABLE",
    }
Expand Down
44 changes: 44 additions & 0 deletions tests/functional/idempotency/test_idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,3 +1256,47 @@ def lambda_handler(event, context):

stubber.assert_no_pending_responses()
stubber.deactivate()


def test_idempotent_lambda_cleanup(
    persistence_store: DynamoDBPersistenceLayer,
    hashed_idempotency_key,
    lambda_apigw_event,
    expected_params_update_item,
    lambda_response,
    lambda_context,
):
    """With ``function_timeout_clean_up`` enabled, the INPROGRESS item is put with a ``function_timeout`` entry."""
    # GIVEN a config with timeout clean up enabled and a stubbed DynamoDB client
    config_with_clean_up = IdempotencyConfig(
        function_timeout_clean_up=True,
        event_key_jmespath="[body, queryStringParameters]",
    )

    inprogress_item = {
        "expiration": stub.ANY,
        "id": hashed_idempotency_key,
        "status": "INPROGRESS",
        # Any value is accepted here; the exact timestamp depends on the current time
        "function_timeout": stub.ANY,
    }
    put_item_params = {
        "ConditionExpression": "attribute_not_exists(#id) OR #now < :now OR #function_timeout < :now",
        "ExpressionAttributeNames": {"#id": "id", "#now": "expiration", "#function_timeout": "function_timeout"},
        "ExpressionAttributeValues": {":now": stub.ANY},
        "Item": inprogress_item,
        "TableName": "TEST_TABLE",
    }
    empty_response = {}

    ddb_stubber = stub.Stubber(persistence_store.table.meta.client)
    # Responses are consumed in order: first the in-progress put, then the completed update
    ddb_stubber.add_response("put_item", empty_response, put_item_params)
    ddb_stubber.add_response("update_item", empty_response, expected_params_update_item)
    ddb_stubber.activate()

    @idempotent(config=config_with_clean_up, persistence_store=persistence_store)
    def lambda_handler(event, context):
        return lambda_response

    # WHEN the handler is invoked with a context exposing the remaining time
    lambda_handler(lambda_apigw_event, lambda_context)

    # THEN both stubbed calls have been made
    ddb_stubber.assert_no_pending_responses()
    ddb_stubber.deactivate()
Loading