Skip to content

Commit 14a7ad9

Browse files
Bug Fix: Defensively copy context entities (#340)
Co-authored-by: Nathaniel Ruiz Nowell <[email protected]>
1 parent 0f13101 commit 14a7ad9

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

aws_xray_sdk/core/async_context.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import sys
3+
import copy
34

45
from .context import Context as _Context
56

@@ -108,6 +109,18 @@ def task_factory(loop, coro):
108109
else:
109110
current_task = asyncio.Task.current_task(loop=loop)
110111
if current_task is not None and hasattr(current_task, 'context'):
111-
setattr(task, 'context', current_task.context)
112+
if current_task.context.get('entities'):
113+
# NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides
114+
# the parent by looking at its `_local.entities`, we must copy the entities
115+
# for concurrent subsegments. Otherwise, the subsegments would be
116+
# modifying the same `entities` list and sugsegments would take other
117+
# subsegments as parents instead of the original `segment`.
118+
#
119+
# See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101
120+
new_context = copy.copy(current_task.context)
121+
new_context['entities'] = [item for item in current_task.context['entities']]
122+
else:
123+
new_context = current_task.context
124+
setattr(task, 'context', new_context)
112125

113126
return task

tests/test_async_recorder.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .util import get_new_stubbed_recorder
44
from aws_xray_sdk.version import VERSION
55
from aws_xray_sdk.core.async_context import AsyncContext
6+
import asyncio
67

78

89
xray_recorder = get_new_stubbed_recorder()
@@ -43,6 +44,28 @@ async def test_capture(loop):
4344
assert platform.python_implementation() == service.get('runtime')
4445
assert platform.python_version() == service.get('runtime_version')
4546

47+
async def test_concurrent_calls(loop):
48+
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))
49+
async with xray_recorder.in_segment_async('segment') as segment:
50+
global counter
51+
counter = 0
52+
total_tasks = 10
53+
flag = asyncio.Event()
54+
async def assert_task():
55+
async with xray_recorder.in_subsegment_async('segment') as subsegment:
56+
global counter
57+
counter += 1
58+
# Begin all subsegments before closing any to ensure they overlap
59+
if counter < total_tasks:
60+
await flag.wait()
61+
else:
62+
flag.set()
63+
return subsegment.parent_id
64+
tasks = [assert_task() for task in range(total_tasks)]
65+
subsegs_parent_ids = await asyncio.gather(*tasks)
66+
for subseg_parent_id in subsegs_parent_ids:
67+
assert subseg_parent_id == segment.id
68+
4669

4770
async def test_async_context_managers(loop):
4871
xray_recorder.configure(service='test', sampling=False, context=AsyncContext(loop=loop))

0 commit comments

Comments
 (0)