Skip to content

Commit eeba30b

Browse files
committed
fix(telemetry): group traces when using agent as tool in an agent
1 parent 6638fb0 commit eeba30b

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

src/strands/agent/agent.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
1818

19-
from opentelemetry import trace
19+
from opentelemetry import trace as trace_api
2020
from pydantic import BaseModel
2121

2222
from ..event_loop.event_loop import event_loop_cycle, run_tool
@@ -300,7 +300,7 @@ def __init__(
300300

301301
# Initialize tracer instance (no-op if not configured)
302302
self.tracer = get_tracer()
303-
self.trace_span: Optional[trace.Span] = None
303+
self.trace_span: Optional[trace_api.Span] = None
304304

305305
# Initialize agent state management
306306
if state is not None:
@@ -504,17 +504,17 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
504504
message: Message = {"role": "user", "content": content}
505505

506506
self._start_agent_trace_span(message)
507-
508507
try:
509-
events = self._run_loop(message, invocation_state=kwargs)
510-
async for event in events:
511-
if "callback" in event:
512-
callback_handler(**event["callback"])
513-
yield event["callback"]
508+
with trace_api.use_span(self.trace_span):
509+
events = self._run_loop(message, invocation_state=kwargs)
510+
async for event in events:
511+
if "callback" in event:
512+
callback_handler(**event["callback"])
513+
yield event["callback"]
514514

515-
result = AgentResult(*event["stop"])
516-
callback_handler(result=result)
517-
yield {"result": result}
515+
result = AgentResult(*event["stop"])
516+
callback_handler(result=result)
517+
yield {"result": result}
518518

519519
self._end_agent_trace_span(response=result)
520520

@@ -659,7 +659,6 @@ def _start_agent_trace_span(self, message: Message) -> None:
659659
message: The user message.
660660
"""
661661
model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
662-
663662
self.trace_span = self.tracer.start_agent_span(
664663
message=message,
665664
agent_name=self.name,

0 commit comments

Comments
 (0)