From 016ce9156f24a93ed197e190aaa5f6e7ae08731d Mon Sep 17 00:00:00 2001 From: Tyler Dodge Date: Thu, 9 Jun 2022 09:18:42 -0700 Subject: [PATCH 1/3] Bug Fix: Defensively copy context entities Before this change, concurrent async tasks would all share the same instance of the entities list. This change makes it so they each get their own copy of the list. This matters because the recorder modifies the list in place, which makes it so concurrent subtasks have the wrong parent subsegment. --- aws_xray_sdk/core/async_context.py | 10 +++++++++- tests/test_async_recorder.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/aws_xray_sdk/core/async_context.py b/aws_xray_sdk/core/async_context.py index b287a42f..97624572 100644 --- a/aws_xray_sdk/core/async_context.py +++ b/aws_xray_sdk/core/async_context.py @@ -1,5 +1,6 @@ import asyncio import sys +import copy from .context import Context as _Context @@ -108,6 +109,13 @@ def task_factory(loop, coro): else: current_task = asyncio.Task.current_task(loop=loop) if current_task is not None and hasattr(current_task, 'context'): - setattr(task, 'context', current_task.context) + if current_task.context.get('entities'): + # Defensively copying because recorder modifies the list in place. + new_context = copy.copy(current_task.context) + new_context['entities'] = [item for item in current_task.context['entities']] + else: + # no reason to copy if there's no entities list. 
+ new_context = current_task.context + setattr(task, 'context', new_context) return task diff --git a/tests/test_async_recorder.py b/tests/test_async_recorder.py index eba147f7..7d4cd27c 100644 --- a/tests/test_async_recorder.py +++ b/tests/test_async_recorder.py @@ -3,6 +3,7 @@ from .util import get_new_stubbed_recorder from aws_xray_sdk.version import VERSION from aws_xray_sdk.core.async_context import AsyncContext +import asyncio xray_recorder = get_new_stubbed_recorder() @@ -43,6 +44,28 @@ async def test_capture(loop): assert platform.python_implementation() == service.get('runtime') assert platform.python_version() == service.get('runtime_version') +async def test_concurrent_calls(loop): + xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop)) + async with xray_recorder.in_segment_async('segment') as segment: + global counter + counter = 0 + total_tasks = 10 + event = asyncio.Event() + async def assert_task(): + async with xray_recorder.in_subsegment_async('segment') as subsegment: + global counter + counter += 1 + # Ensure that the task subsegments overlap + if counter < total_tasks: + await event.wait() + else: + event.set() + return subsegment.parent_id + tasks = [assert_task() for task in range(total_tasks)] + results = await asyncio.gather(*tasks) + for result in results: + assert result == segment.id + async def test_async_context_managers(loop): xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop)) From 05e681b6c168f4ed8bcd4b11b71ae61b7f305c3e Mon Sep 17 00:00:00 2001 From: Tyler Dodge Date: Mon, 27 Jun 2022 12:00:51 -0700 Subject: [PATCH 2/3] Update tests/test_async_recorder.py Co-authored-by: Nathaniel Ruiz Nowell --- tests/test_async_recorder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_async_recorder.py b/tests/test_async_recorder.py index 7d4cd27c..f849448a 100644 --- a/tests/test_async_recorder.py +++ b/tests/test_async_recorder.py @@ 
-62,9 +62,9 @@ async def assert_task(): event.set() return subsegment.parent_id tasks = [assert_task() for task in range(total_tasks)] - results = await asyncio.gather(*tasks) - for result in results: - assert result == segment.id + subsegs_parent_ids = await asyncio.gather(*tasks) + for subseg_parent_id in subsegs_parent_ids: + assert subseg_parent_id == segment.id async def test_async_context_managers(loop): From 5d918ab9e2bb4561d37c08b56071ab6581acea2a Mon Sep 17 00:00:00 2001 From: Nathaniel Ruiz Nowell Date: Mon, 27 Jun 2022 12:44:11 -0700 Subject: [PATCH 3/3] Comments and naming help explain bugfix solution --- aws_xray_sdk/core/async_context.py | 9 +++++++-- tests/test_async_recorder.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/aws_xray_sdk/core/async_context.py b/aws_xray_sdk/core/async_context.py index 97624572..acba00e2 100644 --- a/aws_xray_sdk/core/async_context.py +++ b/aws_xray_sdk/core/async_context.py @@ -110,11 +110,16 @@ def task_factory(loop, coro): current_task = asyncio.Task.current_task(loop=loop) if current_task is not None and hasattr(current_task, 'context'): if current_task.context.get('entities'): - # Defensively copying because recorder modifies the list in place. + # NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides + # the parent by looking at its `_local.entities`, we must copy the entities + # for concurrent subsegments. Otherwise, the subsegments would be + # modifying the same `entities` list and subsegments would take other + # subsegments as parents instead of the original `segment`. + # + # See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101 new_context = copy.copy(current_task.context) new_context['entities'] = [item for item in current_task.context['entities']] else: - # no reason to copy if there's no entities list. 
new_context = current_task.context setattr(task, 'context', new_context) diff --git a/tests/test_async_recorder.py b/tests/test_async_recorder.py index f849448a..0367fb3c 100644 --- a/tests/test_async_recorder.py +++ b/tests/test_async_recorder.py @@ -50,16 +50,16 @@ async def test_concurrent_calls(loop): global counter counter = 0 total_tasks = 10 - event = asyncio.Event() + flag = asyncio.Event() async def assert_task(): async with xray_recorder.in_subsegment_async('segment') as subsegment: global counter counter += 1 - # Ensure that the task subsegments overlap + # Begin all subsegments before closing any to ensure they overlap if counter < total_tasks: - await event.wait() + await flag.wait() else: - event.set() + flag.set() return subsegment.parent_id tasks = [assert_task() for task in range(total_tasks)] subsegs_parent_ids = await asyncio.gather(*tasks)