diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 956c246ce..7c23b9ab8 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -16,7 +16,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
 
-from opentelemetry import trace
+from opentelemetry import trace as trace_api
 from pydantic import BaseModel
 
 from ..event_loop.event_loop import event_loop_cycle, run_tool
@@ -300,7 +300,7 @@ def __init__(
 
         # Initialize tracer instance (no-op if not configured)
         self.tracer = get_tracer()
-        self.trace_span: Optional[trace.Span] = None
+        self.trace_span: Optional[trace_api.Span] = None
 
         # Initialize agent state management
         if state is not None:
@@ -503,24 +503,24 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
         content: list[ContentBlock] = [{"text": prompt}] if isinstance(prompt, str) else prompt
         message: Message = {"role": "user", "content": content}
 
-        self._start_agent_trace_span(message)
+        self.trace_span = self._start_agent_trace_span(message)
+        with trace_api.use_span(self.trace_span):
+            try:
+                events = self._run_loop(message, invocation_state=kwargs)
+                async for event in events:
+                    if "callback" in event:
+                        callback_handler(**event["callback"])
+                        yield event["callback"]
 
-        try:
-            events = self._run_loop(message, invocation_state=kwargs)
-            async for event in events:
-                if "callback" in event:
-                    callback_handler(**event["callback"])
-                    yield event["callback"]
+                result = AgentResult(*event["stop"])
+                callback_handler(result=result)
+                yield {"result": result}
 
-            result = AgentResult(*event["stop"])
-            callback_handler(result=result)
-            yield {"result": result}
+                self._end_agent_trace_span(response=result)
 
-            self._end_agent_trace_span(response=result)
-
-        except Exception as e:
-            self._end_agent_trace_span(error=e)
-            raise
+            except Exception as e:
+                self._end_agent_trace_span(error=e)
+                raise
 
     async def _run_loop(
         self, message: Message, invocation_state: dict[str, Any]
@@ -652,15 +652,14 @@ def _record_tool_execution(
         self._append_message(tool_result_msg)
         self._append_message(assistant_msg)
 
-    def _start_agent_trace_span(self, message: Message) -> None:
+    def _start_agent_trace_span(self, message: Message) -> trace_api.Span:
         """Starts a trace span for the agent.
 
         Args:
             message: The user message.
         """
         model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
-
-        self.trace_span = self.tracer.start_agent_span(
+        return self.tracer.start_agent_span(
             message=message,
             agent_name=self.name,
             model_id=model_id,
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index b6ed6a975..ffcb6a5c9 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -13,6 +13,8 @@
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
 
+from opentelemetry import trace as trace_api
+
 from ..experimental.hooks import (
     AfterModelInvocationEvent,
     AfterToolInvocationEvent,
@@ -114,72 +116,75 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
             parent_span=cycle_span,
             model_id=model_id,
         )
-
-        tool_specs = agent.tool_registry.get_all_tool_specs()
-
-        agent.hooks.invoke_callbacks(
-            BeforeModelInvocationEvent(
-                agent=agent,
-            )
-        )
-
-        try:
-            # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
-            # before yielding to the callback handler. This will be revisited when migrating to strongly
-            # typed events.
-            async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
-                if "callback" in event:
-                    yield {
-                        "callback": {**event["callback"], **(invocation_state if "delta" in event["callback"] else {})}
-                    }
-
-            stop_reason, message, usage, metrics = event["stop"]
-            invocation_state.setdefault("request_state", {})
+        with trace_api.use_span(model_invoke_span):
+            tool_specs = agent.tool_registry.get_all_tool_specs()
 
             agent.hooks.invoke_callbacks(
-                AfterModelInvocationEvent(
+                BeforeModelInvocationEvent(
                     agent=agent,
-                    stop_response=AfterModelInvocationEvent.ModelStopResponse(
-                        stop_reason=stop_reason,
-                        message=message,
-                    ),
                 )
             )
 
-            if model_invoke_span:
-                tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
-            break  # Success! Break out of retry loop
-
-        except Exception as e:
-            if model_invoke_span:
-                tracer.end_span_with_error(model_invoke_span, str(e), e)
-
-            agent.hooks.invoke_callbacks(
-                AfterModelInvocationEvent(
-                    agent=agent,
-                    exception=e,
+            try:
+                # TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
+                # before yielding to the callback handler. This will be revisited when migrating to strongly
+                # typed events.
+                async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
+                    if "callback" in event:
+                        yield {
+                            "callback": {
+                                **event["callback"],
+                                **(invocation_state if "delta" in event["callback"] else {}),
+                            }
+                        }
+
+                stop_reason, message, usage, metrics = event["stop"]
+                invocation_state.setdefault("request_state", {})
+
+                agent.hooks.invoke_callbacks(
+                    AfterModelInvocationEvent(
+                        agent=agent,
+                        stop_response=AfterModelInvocationEvent.ModelStopResponse(
+                            stop_reason=stop_reason,
+                            message=message,
+                        ),
+                    )
                 )
-            )
-            if isinstance(e, ModelThrottledException):
-                if attempt + 1 == MAX_ATTEMPTS:
-                    yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
-                    raise e
+                if model_invoke_span:
+                    tracer.end_model_invoke_span(model_invoke_span, message, usage, stop_reason)
+                break  # Success! Break out of retry loop
 
-                logger.debug(
-                    "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
-                    "| throttling exception encountered "
-                    "| delaying before next retry",
-                    current_delay,
-                    MAX_ATTEMPTS,
-                    attempt + 1,
+            except Exception as e:
+                if model_invoke_span:
+                    tracer.end_span_with_error(model_invoke_span, str(e), e)
+
+                agent.hooks.invoke_callbacks(
+                    AfterModelInvocationEvent(
+                        agent=agent,
+                        exception=e,
+                    )
                 )
-                time.sleep(current_delay)
-                current_delay = min(current_delay * 2, MAX_DELAY)
-
-                yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
-            else:
-                raise e
+
+                if isinstance(e, ModelThrottledException):
+                    if attempt + 1 == MAX_ATTEMPTS:
+                        yield {"callback": {"force_stop": True, "force_stop_reason": str(e)}}
+                        raise e
+
+                    logger.debug(
+                        "retry_delay_seconds=<%s>, max_attempts=<%s>, current_attempt=<%s> "
+                        "| throttling exception encountered "
+                        "| delaying before next retry",
+                        current_delay,
+                        MAX_ATTEMPTS,
+                        attempt + 1,
+                    )
+                    time.sleep(current_delay)
+                    current_delay = min(current_delay * 2, MAX_DELAY)
+
+                    yield {"callback": {"event_loop_throttled_delay": current_delay, **invocation_state}}
+                else:
+                    raise e
 
     try:
         # Add message in trace and mark the end of the stream messages trace
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index f060c7f6e..eebffef29 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -213,7 +213,7 @@ def start_model_invoke_span(
         parent_span: Optional[Span] = None,
         model_id: Optional[str] = None,
         **kwargs: Any,
-    ) -> Optional[Span]:
+    ) -> Span:
         """Start a new span for a model invocation.
 
         Args:
@@ -414,7 +414,7 @@ def start_agent_span(
         tools: Optional[list] = None,
         custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
         **kwargs: Any,
-    ) -> Optional[Span]:
+    ) -> Span:
         """Start a new span for an agent invocation.
 
         Args:
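Note (not part of the diff, added for context): the core change is that the agent and model-invoke spans are now attached to the active OpenTelemetry context with `trace_api.use_span(...)` rather than being threaded through `parent_span=` arguments alone, so spans opened further down (tool calls, instrumented HTTP clients, and so on) can pick up the correct parent automatically. Below is a minimal sketch of that pattern using only the public `opentelemetry-api` surface; the span names are illustrative and not from this codebase.

```python
# Sketch of the context-propagation pattern the diff adopts (assumptions:
# illustrative span names, no SDK/exporter configured).
from opentelemetry import trace as trace_api

tracer = trace_api.get_tracer(__name__)

# start_span() creates a span but does NOT make it the current span.
agent_span = tracer.start_span("example_agent_span")

# use_span() makes the span current for the duration of the block.
# end_on_exit=False leaves ending the span to the caller, matching how the
# diff still closes spans via the Tracer helpers (_end_agent_trace_span,
# end_model_invoke_span, end_span_with_error).
with trace_api.use_span(agent_span, end_on_exit=False):
    assert trace_api.get_current_span() is agent_span

    # Spans started here are parented to agent_span automatically,
    # without passing parent_span= by hand.
    with tracer.start_as_current_span("example_model_span"):
        pass

agent_span.end()
```

The `Optional[Span]` to `Span` return-type changes in `tracer.py` fit the same pattern: a span handed to `use_span()` cannot be `None`, so the start helpers are now typed as always returning a span.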