From 0307a3d0c3e384ff30aef11ca415bc81d843770a Mon Sep 17 00:00:00 2001 From: Max Rabin Date: Mon, 15 Dec 2025 13:39:19 +0200 Subject: [PATCH] Update ruff configuration to apply pyupgrade and modernize Python syntax Enable ruff's pyupgrade ("UP") rules and apply the resulting autofixes across the codebase: Optional[X] becomes X | None, Union[A, B] becomes A | B, typing.Type/List/Dict become the builtin generics type/list/dict, abstract collection types (Callable, Sequence, Mapping, AsyncGenerator, ...) are imported from collections.abc instead of typing, redundant open(..., "r") modes are dropped, and %-formatting is converted to f-strings. No behavioral change is intended. --- pyproject.toml | 1 + src/strands/_async.py | 3 +- src/strands/agent/agent.py | 59 ++++++------- src/strands/agent/agent_result.py | 3 +- .../conversation_manager.py | 6 +- .../null_conversation_manager.py | 4 +- .../sliding_window_conversation_manager.py | 6 +- .../summarizing_conversation_manager.py | 14 +-- src/strands/event_loop/event_loop.py | 3 +- src/strands/event_loop/streaming.py | 9 +- src/strands/experimental/agent_config.py | 2 +- .../steering/handlers/llm/llm_handler.py | 2 +- .../experimental/tools/tool_provider.py | 3 +- src/strands/hooks/events.py | 12 +-- src/strands/hooks/registry.py | 7 +- src/strands/models/_validation.py | 5 +- src/strands/models/anthropic.py | 19 +++-- src/strands/models/bedrock.py | 75 ++++++++-------- src/strands/models/gemini.py | 27 +++--- src/strands/models/litellm.py | 29 ++++--- src/strands/models/llamaapi.py | 27 +++--- src/strands/models/llamacpp.py | 36 ++++---- src/strands/models/mistral.py | 29 ++++--- src/strands/models/model.py | 11 +-- src/strands/models/ollama.py | 33 +++---- src/strands/models/openai.py | 23 ++--- src/strands/models/sagemaker.py | 45 +++++----- src/strands/models/writer.py | 27 +++--- src/strands/multiagent/a2a/executor.py | 6 +- src/strands/multiagent/base.py | 5 +- src/strands/multiagent/graph.py | 35 ++++---- src/strands/multiagent/swarm.py | 11 +-- src/strands/session/file_session_manager.py | 16 ++-- .../session/repository_session_manager.py | 4 +- src/strands/session/s3_session_manager.py | 26 +++--- src/strands/session/session_repository.py | 12 +-- src/strands/telemetry/metrics.py | 53 ++++++------ src/strands/telemetry/tracer.py | 85 +++++++++---------- src/strands/tools/_caller.py | 3 +- src/strands/tools/decorator.py | 25 +++--- src/strands/tools/executors/_executor.py | 3 +- src/strands/tools/executors/concurrent.py | 3 +- src/strands/tools/executors/sequential.py | 3 +- src/strands/tools/loader.py | 12 +-- src/strands/tools/mcp/mcp_client.py | 18 ++-- src/strands/tools/mcp/mcp_instrumentation.py | 9 +- src/strands/tools/mcp/mcp_types.py | 4 +- src/strands/tools/registry.py | 29 ++++--- .../_structured_output_context.py | 10 +-- .../structured_output_tool.py | 10 +-- .../structured_output_utils.py | 28 +++--- src/strands/tools/watcher.py | 6 +- src/strands/types/_events.py | 3 +- src/strands/types/citations.py | 10 +-- src/strands/types/collections.py | 4 +- src/strands/types/content.py | 14 +-- src/strands/types/guardrails.py | 24 +++--- src/strands/types/media.py | 6 +- src/strands/types/session.py | 4 +- src/strands/types/streaming.py | 22 +++- src/strands/types/tools.py | 16 +--- src/strands/types/traces.py | 33 ++++--- tests/fixtures/mock_hook_provider.py | 7 +- .../fixtures/mock_multiagent_hook_provider.py | 7 +- tests/fixtures/mocked_model_provider.py | 19 +++-- .../strands/agent/hooks/test_hook_registry.py | 3 +- tests/strands/agent/test_agent_result.py | 4 +- .../agent/test_agent_structured_output.py | 3 +- tests/strands/models/test_sagemaker.py | 16 ++-- tests/strands/models/test_writer.py | 6 +- .../session/test_file_session_manager.py | 8 +- .../test_structured_output_context.py | 4 +- .../test_structured_output_tool.py | 5 +- tests/strands/tools/test_decorator.py | 23 ++--- tests/strands/tools/test_structured_output.py | 20 ++--- tests_integ/mcp/echo_server.py | 4 +-
tests_integ/mcp/test_mcp_client.py | 6 +- tests_integ/models/providers.py | 4 +- tests_integ/test_function_tools.py | 3 +- tests_integ/test_multiagent_graph.py | 3 +- .../test_structured_output_agent_loop.py | 12 ++- 81 files changed, 612 insertions(+), 617 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c2a6b260..489dfc0e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,7 @@ select = [ "G", # logging format "I", # isort "LOG", # logging + "UP", # pyupgrade ] [tool.ruff.lint.per-file-ignores] diff --git a/src/strands/_async.py b/src/strands/_async.py index 141ca71b7..0ceb038f3 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -2,8 +2,9 @@ import asyncio import contextvars +from collections.abc import Awaitable, Callable from concurrent.futures import ThreadPoolExecutor -from typing import Awaitable, Callable, TypeVar +from typing import TypeVar T = TypeVar("T") diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8fc5be6ca..69a9ebb2d 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -11,15 +11,10 @@ import logging import warnings +from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, - AsyncIterator, - Callable, - Mapping, - Optional, - Type, TypeVar, Union, cast, @@ -104,26 +99,24 @@ class Agent: def __init__( self, - model: Union[Model, str, None] = None, - messages: Optional[Messages] = None, - tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None, - system_prompt: Optional[str | list[SystemContentBlock]] = None, - structured_output_model: Optional[Type[BaseModel]] = None, - callback_handler: Optional[ - Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] - ] = _DEFAULT_CALLBACK_HANDLER, - conversation_manager: Optional[ConversationManager] = None, + model: Model | str | None = None, + messages: Messages | None = None, + tools: list[Union[str, dict[str, str], "ToolProvider", Any]] | None = None, + system_prompt: str | list[SystemContentBlock] | None = None, + structured_output_model: type[BaseModel] | None = None, + callback_handler: Callable[..., Any] | _DefaultCallbackHandlerSentinel | None = _DEFAULT_CALLBACK_HANDLER, + conversation_manager: ConversationManager | None = None, record_direct_tool_call: bool = True, load_tools_from_directory: bool = False, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, *, - agent_id: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - state: Optional[Union[AgentState, dict]] = None, - hooks: Optional[list[HookProvider]] = None, - session_manager: Optional[SessionManager] = None, - tool_executor: Optional[ToolExecutor] = None, + agent_id: str | None = None, + name: str | None = None, + description: str | None = None, + state: AgentState | dict | None = None, + hooks: list[HookProvider] | None = None, + session_manager: SessionManager | None = None, + tool_executor: ToolExecutor | None = None, ): """Initialize the Agent with the specified configuration. 
@@ -189,7 +182,7 @@ def __init__( # If not provided, create a new PrintingCallbackHandler instance # If explicitly set to None, use null_callback_handler # Otherwise use the passed callback_handler - self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + self.callback_handler: Callable[..., Any] | PrintingCallbackHandler if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): self.callback_handler = PrintingCallbackHandler() elif callback_handler is None: @@ -226,7 +219,7 @@ def __init__( # Initialize tracer instance (no-op if not configured) self.tracer = get_tracer() - self.trace_span: Optional[trace_api.Span] = None + self.trace_span: trace_api.Span | None = None # Initialize agent state management if state is not None: @@ -316,7 +309,7 @@ def __call__( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -357,7 +350,7 @@ async def invoke_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AgentResult: """Process a natural language prompt through the agent's event loop. @@ -394,7 +387,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: + def structured_output(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -425,7 +418,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> return run_async(lambda: self.structured_output_async(output_model, prompt)) - async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: + async def structured_output_async(self, output_model: type[T], prompt: AgentInput = None) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -520,7 +513,7 @@ async def stream_async( prompt: AgentInput = None, *, invocation_state: dict[str, Any] | None = None, - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, **kwargs: Any, ) -> AsyncIterator[Any]: """Process a natural language prompt and yield events as an async iterator. @@ -607,7 +600,7 @@ async def _run_loop( self, messages: Messages, invocation_state: dict[str, Any], - structured_output_model: Type[BaseModel] | None = None, + structured_output_model: type[BaseModel] | None = None, ) -> AsyncGenerator[TypedEvent, None]: """Execute the agent's event loop with the given message and parameters. @@ -769,8 +762,8 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """Ends a trace span for the agent. 
diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index ef8a11029..2ab95e5b5 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -3,8 +3,9 @@ This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle. """ +from collections.abc import Sequence from dataclasses import dataclass -from typing import Any, Sequence, cast +from typing import Any, cast from pydantic import BaseModel diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 2c1ee7847..ce460b1ef 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,7 +1,7 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ...types.content import Message @@ -30,7 +30,7 @@ def __init__(self) -> None: """ self.removed_message_count = 0 - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restore the Conversation Manager's state from a session. Args: @@ -66,7 +66,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: pass @abstractmethod - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Called when the model's context window is exceeded. This method should implement the specific strategy for reducing the window size when a context overflow occurs. diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 5ff6874e5..11632525d 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,6 +1,6 @@ """Null implementation of conversation management.""" -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -28,7 +28,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Does not reduce context and raises an exception. 
Args: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..629c76f5f 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,7 +1,7 @@ """Sliding window conversation history management.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...agent.agent import Agent @@ -52,7 +52,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: return self.reduce_context(agent) - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: @@ -151,7 +151,7 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool: return changes_made - def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]: + def _find_last_message_with_tool_results(self, messages: Messages) -> int | None: """Find the index of the last message containing tool results. This is useful for identifying messages that might need to be truncated to reduce context size. diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index 12185c286..cc71e4d88 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -1,7 +1,7 @@ """Summarizing conversation history management with configurable options.""" import logging -from typing import TYPE_CHECKING, Any, List, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from typing_extensions import override @@ -62,7 +62,7 @@ def __init__( summary_ratio: float = 0.3, preserve_recent_messages: int = 10, summarization_agent: Optional["Agent"] = None, - summarization_system_prompt: Optional[str] = None, + summarization_system_prompt: str | None = None, ): """Initialize the summarizing conversation manager. @@ -87,10 +87,10 @@ def __init__( self.preserve_recent_messages = preserve_recent_messages self.summarization_agent = summarization_agent self.summarization_system_prompt = summarization_system_prompt - self._summary_message: Optional[Message] = None + self._summary_message: Message | None = None @override - def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: + def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None: """Restores the Summarizing Conversation manager from its previous state in a session. Args: @@ -121,7 +121,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: # No proactive management - summarization only happens on context overflow pass - def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None: + def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. 
Args: @@ -173,7 +173,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs logger.error("Summarization failed: %s", summarization_error) raise summarization_error from e - def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: + def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message: """Generate a summary of the provided messages. Args: @@ -224,7 +224,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message: summarization_agent.messages = original_messages summarization_agent.tool_registry = original_tool_registry - def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int: + def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int: """Adjust the split point to avoid breaking ToolUse/ToolResult pairs. Uses the same logic as SlidingWindowConversationManager for consistency. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index f25057e4d..49d4f7b50 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,7 +11,8 @@ import asyncio import logging import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from opentelemetry import trace as trace_api diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 43836fe34..0a01fd6c2 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -4,7 +4,8 @@ import logging import time import warnings -from typing import Any, AsyncGenerator, AsyncIterable, Optional +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any from ..models.model import Model from ..tools import InvalidToolUseNameException @@ -418,12 +419,12 @@ async def process_stream( async def stream_messages( model: Model, - system_prompt: Optional[str], + system_prompt: str | None, messages: Messages, tool_specs: list[ToolSpec], *, - tool_choice: Optional[Any] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_choice: Any | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[TypedEvent, None]: """Streams messages to the model and processes the response. 
diff --git a/src/strands/experimental/agent_config.py b/src/strands/experimental/agent_config.py index f65afb57d..e6fb94118 100644 --- a/src/strands/experimental/agent_config.py +++ b/src/strands/experimental/agent_config.py @@ -98,7 +98,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A if not config_path.exists(): raise FileNotFoundError(f"Configuration file not found: {file_path}") - with open(config_path, "r") as f: + with open(config_path) as f: config_dict = json.load(f) elif isinstance(config, dict): config_dict = config.copy() diff --git a/src/strands/experimental/steering/handlers/llm/llm_handler.py b/src/strands/experimental/steering/handlers/llm/llm_handler.py index 9d9b34911..4d90f46c9 100644 --- a/src/strands/experimental/steering/handlers/llm/llm_handler.py +++ b/src/strands/experimental/steering/handlers/llm/llm_handler.py @@ -58,7 +58,7 @@ def __init__( self.prompt_mapper = prompt_mapper or DefaultPromptMapper() self.model = model - async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction: + async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction: """Provide contextual guidance for tool usage. Args: diff --git a/src/strands/experimental/tools/tool_provider.py b/src/strands/experimental/tools/tool_provider.py index 2c79ceafc..c40d1b572 100644 --- a/src/strands/experimental/tools/tool_provider.py +++ b/src/strands/experimental/tools/tool_provider.py @@ -1,7 +1,8 @@ """Tool provider interface.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ...types.tools import AgentTool diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index ebc508f24..cdd53f2f1 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -5,7 +5,7 @@ import uuid from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from typing_extensions import override @@ -116,7 +116,7 @@ class BeforeToolCallEvent(HookEvent, _Interruptible): the tool call and use a default cancel message. """ - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] cancel_tool: bool | str = False @@ -157,11 +157,11 @@ class AfterToolCallEvent(HookEvent): cancel_message: The cancellation message if the user cancelled the tool call. 
""" - selected_tool: Optional[AgentTool] + selected_tool: AgentTool | None tool_use: ToolUse invocation_state: dict[str, Any] result: ToolResult - exception: Optional[Exception] = None + exception: Exception | None = None cancel_message: str | None = None def _can_write(self, name: str) -> bool: @@ -217,8 +217,8 @@ class ModelStopResponse: message: Message stop_reason: StopReason - stop_response: Optional[ModelStopResponse] = None - exception: Optional[Exception] = None + stop_response: ModelStopResponse | None = None + exception: Exception | None = None @property def should_reverse_callbacks(self) -> bool: diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 1efc0bf5b..9839a0841 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,8 +9,9 @@ import inspect import logging +from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar from ..interrupt import Interrupt, InterruptException @@ -153,9 +154,9 @@ class HookRegistry: def __init__(self) -> None: """Initialize an empty hook registry.""" - self._registered_callbacks: dict[Type, list[HookCallback]] = {} + self._registered_callbacks: dict[type, list[HookCallback]] = {} - def add_callback(self, event_type: Type[TEvent], callback: HookCallback[TEvent]) -> None: + def add_callback(self, event_type: type[TEvent], callback: HookCallback[TEvent]) -> None: """Register a callback function for a specific event type. Args: diff --git a/src/strands/models/_validation.py b/src/strands/models/_validation.py index 9eabe28a1..1e82bca73 100644 --- a/src/strands/models/_validation.py +++ b/src/strands/models/_validation.py @@ -1,14 +1,15 @@ """Configuration validation utilities for model providers.""" import warnings -from typing import Any, Mapping, Type +from collections.abc import Mapping +from typing import Any from typing_extensions import get_type_hints from ..types.tools import ToolChoice -def validate_config_keys(config_dict: Mapping[str, Any], config_class: Type) -> None: +def validate_config_keys(config_dict: Mapping[str, Any], config_class: type) -> None: """Validate that config keys match the TypedDict fields. Args: diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 68b234729..535c820ee 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import anthropic from pydantic import BaseModel @@ -59,9 +60,9 @@ class AnthropicConfig(TypedDict, total=False): max_tokens: Required[int] model_id: Required[str] - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, *, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[AnthropicConfig]): + def __init__(self, *, client_args: dict[str, Any] | None = None, **model_config: Unpack[AnthropicConfig]): """Initialize provider instance. 
Args: @@ -198,8 +199,8 @@ def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format an Anthropic streaming request. @@ -369,8 +370,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -419,8 +420,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4a7c81672..e83f3da8b 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -8,7 +8,8 @@ import logging import os import warnings -from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union, ValuesView, cast +from collections.abc import AsyncGenerator, Callable, Iterable, ValuesView +from typing import Any, Literal, TypeVar, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -92,34 +93,34 @@ class BedrockConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature) """ - additional_args: Optional[dict[str, Any]] - additional_request_fields: Optional[dict[str, Any]] - additional_response_field_paths: Optional[list[str]] - cache_prompt: Optional[str] - cache_tools: Optional[str] - guardrail_id: Optional[str] - guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]] - guardrail_stream_processing_mode: Optional[Literal["sync", "async"]] - guardrail_version: Optional[str] - guardrail_redact_input: Optional[bool] - guardrail_redact_input_message: Optional[str] - guardrail_redact_output: Optional[bool] - guardrail_redact_output_message: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + additional_request_fields: dict[str, Any] | None + additional_response_field_paths: list[str] | None + cache_prompt: str | None + cache_tools: str | None + guardrail_id: str | None + guardrail_trace: Literal["enabled", "disabled", "enabled_full"] | None + guardrail_stream_processing_mode: Literal["sync", "async"] | None + guardrail_version: str | None + guardrail_redact_input: bool | None + guardrail_redact_input_message: str | None + guardrail_redact_output: bool | None + guardrail_redact_output_message: str | None + max_tokens: int | None model_id: str - include_tool_result_status: Optional[Literal["auto"] | bool] - stop_sequences: Optional[list[str]] - streaming: Optional[bool] - temperature: Optional[float] - top_p: Optional[float] + include_tool_result_status: Literal["auto"] | bool | None + stop_sequences: list[str] | None + streaming: bool | None + temperature: float | None + top_p: float | None def __init__( self, *, - boto_session: Optional[boto3.Session] = None, - 
boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, - endpoint_url: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, + endpoint_url: str | None = None, **model_config: Unpack[BedrockConfig], ): """Initialize provider instance. @@ -190,8 +191,8 @@ def get_config(self) -> BedrockConfig: def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> dict[str, Any]: """Format a Bedrock converse stream request. @@ -594,11 +595,11 @@ def _generate_redaction_events(self) -> list[StreamEvent]: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the Bedrock model. @@ -622,13 +623,13 @@ async def stream( ModelThrottledException: If the model service is throttling requests. """ - def callback(event: Optional[StreamEvent] = None) -> None: + def callback(event: StreamEvent | None = None) -> None: loop.call_soon_threadsafe(queue.put_nowait, event) if event is None: return loop = asyncio.get_event_loop() - queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue() + queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue() # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None if system_prompt and system_prompt_content is None: @@ -650,8 +651,8 @@ def _stream( self, callback: Callable[..., None], messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, tool_choice: ToolChoice | None = None, ) -> None: """Stream conversation with the Bedrock model. @@ -904,11 +905,11 @@ def _find_detected_and_blocked_policy(self, input: Any) -> bool: @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -953,7 +954,7 @@ async def structured_output( yield {"output": output_model(**output_response)} @staticmethod - def _get_default_model_with_warning(region_name: str, model_config: Optional[BedrockConfig] = None) -> str: + def _get_default_model_with_warning(region_name: str, model_config: BedrockConfig | None = None) -> str: """Get the default Bedrock modelId based on region. 
If the region is not **known** to support inference then we show a helpful warning diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c24d91a0d..1a7638a13 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -6,7 +6,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import pydantic from google import genai @@ -48,7 +49,7 @@ class GeminiConfig(TypedDict, total=False): def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[GeminiConfig], ) -> None: """Initialize provider instance. @@ -170,7 +171,7 @@ def _format_request_content(self, messages: Messages) -> list[genai.types.Conten for message in messages ] - def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[genai.types.Tool | Any]: + def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai.types.Tool | Any]: """Format tool specs into Gemini tools. - Docs: https://googleapis.github.io/python-genai/genai.html#genai.types.Tool @@ -196,9 +197,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge def _format_request_config( self, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> genai.types.GenerateContentConfig: """Format Gemini request config. @@ -221,9 +222,9 @@ def _format_request_config( def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]], - system_prompt: Optional[str], - params: Optional[dict[str, Any]], + tool_specs: list[ToolSpec] | None, + system_prompt: str | None, + params: dict[str, Any] | None, ) -> dict[str, Any]: """Format a Gemini streaming request. @@ -342,8 +343,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: @@ -427,8 +428,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model using Gemini's native structured output. 
- Docs: https://ai.google.dev/gemini-api/docs/structured-output diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1f1e999d2..7c2890d46 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import litellm from litellm.exceptions import ContextWindowExceededError @@ -42,9 +43,9 @@ class LiteLLMConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[LiteLLMConfig]) -> None: + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[LiteLLMConfig]) -> None: """Initialize provider instance. Args: @@ -137,9 +138,9 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for LiteLLM with cache point support. @@ -174,9 +175,9 @@ def _format_system_messages( def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format a LiteLLM compatible messages array with cache point support. @@ -243,11 +244,11 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -341,8 +342,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Some models do not support native structured output via response_format. 
@@ -368,7 +369,7 @@ async def structured_output( yield {"output": result} async def _structured_output_using_response_schema( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using native response_format support.""" response = await litellm.acompletion( @@ -396,7 +397,7 @@ async def _structured_output_using_response_schema( raise ValueError(f"Failed to parse or load content into model: {e}") from e async def _structured_output_using_tool( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None ) -> T: """Get structured output using tool calling fallback.""" tool_spec = convert_pydantic_to_tool_spec(output_model) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 013cd2c7d..ce0367bf5 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -8,7 +8,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -43,16 +44,16 @@ class LlamaConfig(TypedDict, total=False): """ model_id: str - repetition_penalty: Optional[float] - temperature: Optional[float] - top_p: Optional[float] - max_completion_tokens: Optional[int] - top_k: Optional[int] + repetition_penalty: float | None + temperature: float | None + top_p: float | None + max_completion_tokens: int | None + top_k: int | None def __init__( self, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[LlamaConfig], ) -> None: """Initialize provider instance. @@ -159,7 +160,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": [self._format_request_message_content(content) for content in contents], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a LlamaAPI compatible messages array. Args: @@ -206,7 +207,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Llama API chat streaming request. 
@@ -328,8 +329,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -416,8 +417,8 @@ async def stream( @override def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/llamacpp.py b/src/strands/models/llamacpp.py index 22a3a3873..ca838f3d7 100644 --- a/src/strands/models/llamacpp.py +++ b/src/strands/models/llamacpp.py @@ -14,15 +14,11 @@ import logging import mimetypes import time +from collections.abc import AsyncGenerator from typing import ( Any, - AsyncGenerator, - Dict, - Optional, - Type, TypedDict, TypeVar, - Union, cast, ) @@ -133,12 +129,12 @@ class LlamaCppConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None def __init__( self, base_url: str = "http://localhost:8080", - timeout: Optional[Union[float, tuple[float, float]]] = None, + timeout: float | tuple[float, float] | None = None, **model_config: Unpack[LlamaCppConfig], ) -> None: """Initialize llama.cpp provider instance. @@ -196,7 +192,7 @@ def get_config(self) -> LlamaCppConfig: """ return self.config # type: ignore[return-value] - def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) -> dict[str, Any]: + def _format_message_content(self, content: ContentBlock | dict[str, Any]) -> dict[str, Any]: """Format a content block for llama.cpp. Args: @@ -233,7 +229,7 @@ def _format_message_content(self, content: Union[ContentBlock, Dict[str, Any]]) # Handle audio content (not in standard ContentBlock but supported by llama.cpp) if "audio" in content: - audio_content = cast(Dict[str, Any], content) + audio_content = cast(dict[str, Any], content) audio_data = base64.b64encode(audio_content["audio"]["source"]["bytes"]).decode("utf-8") audio_format = audio_content["audio"].get("format", "wav") return { @@ -284,7 +280,7 @@ def _format_tool_message(self, tool_result: dict[str, Any]) -> dict[str, Any]: "content": [self._format_message_content(content) for content in contents], } - def _format_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format messages for llama.cpp. Args: @@ -343,8 +339,8 @@ def _format_messages(self, messages: Messages, system_prompt: Optional[str] = No def _format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, ) -> dict[str, Any]: """Format a request for the llama.cpp server. 
@@ -428,7 +424,7 @@ def _format_request( request[param] = value # Collect llama.cpp-specific parameters for extra_body - extra_body: Dict[str, Any] = {} + extra_body: dict[str, Any] = {} for param, value in params.items(): if param in llamacpp_specific_params: extra_body[param] = value @@ -511,8 +507,8 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -552,7 +548,7 @@ async def stream( yield self._format_chunk({"chunk_type": "message_start"}) yield self._format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls: Dict[int, list] = {} + tool_calls: dict[int, list] = {} usage_data = None finish_reason = None @@ -706,11 +702,11 @@ async def stream( @override async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output using llama.cpp's native JSON schema support. This implementation uses llama.cpp's json_schema parameter to constrain @@ -753,7 +749,7 @@ async def structured_output( if "text" in delta: response_text += delta["text"] # Forward events to caller - yield cast(Dict[str, Union[T, Any]], event) + yield cast(dict[str, T | Any], event) # Parse and validate the JSON response data = json.loads(response_text.strip()) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index b6459d63f..4ec77ccfe 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -6,7 +6,8 @@ import base64 import json import logging -from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable +from typing import Any, TypeVar import mistralai from pydantic import BaseModel @@ -47,16 +48,16 @@ class MistralConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - temperature: Optional[float] - top_p: Optional[float] - stream: Optional[bool] + max_tokens: int | None + temperature: float | None + top_p: float | None + stream: bool | None def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, *, - client_args: Optional[dict[str, Any]] = None, + client_args: dict[str, Any] | None = None, **model_config: Unpack[MistralConfig], ) -> None: """Initialize provider instance. @@ -115,7 +116,7 @@ def get_config(self) -> MistralConfig: """ return self.config - def _format_request_message_content(self, content: ContentBlock) -> Union[str, dict[str, Any]]: + def _format_request_message_content(self, content: ContentBlock) -> str | dict[str, Any]: """Format a Mistral content block. Args: @@ -187,7 +188,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "tool_call_id": tool_result["toolUseId"], } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Mistral compatible messages array. 
Args: @@ -236,7 +237,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return formatted_messages def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format a Mistral chat streaming request. @@ -395,8 +396,8 @@ def _handle_non_streaming_response(self, response: Any) -> Iterable[dict[str, An async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -502,8 +503,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/model.py b/src/strands/models/model.py index b2fa73802..3e0d82a12 100644 --- a/src/strands/models/model.py +++ b/src/strands/models/model.py @@ -2,7 +2,8 @@ import abc import logging -from typing import Any, AsyncGenerator, AsyncIterable, Optional, Type, TypeVar, Union +from collections.abc import AsyncGenerator, AsyncIterable +from typing import Any, TypeVar from pydantic import BaseModel @@ -45,8 +46,8 @@ def get_config(self) -> Any: @abc.abstractmethod # pragma: no cover def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: @@ -68,8 +69,8 @@ def structured_output( def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, system_prompt_content: list[SystemContentBlock] | None = None, diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 574b24200..8d72aa534 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,7 +5,8 @@ import json import logging -from typing import Any, AsyncGenerator, Optional, Type, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypeVar, cast import ollama from pydantic import BaseModel @@ -46,20 +47,20 @@ class OllamaConfig(TypedDict, total=False): top_p: Controls diversity via nucleus sampling (alternative to temperature). 
""" - additional_args: Optional[dict[str, Any]] - keep_alive: Optional[str] - max_tokens: Optional[int] + additional_args: dict[str, Any] | None + keep_alive: str | None + max_tokens: int | None model_id: str - options: Optional[dict[str, Any]] - stop_sequences: Optional[list[str]] - temperature: Optional[float] - top_p: Optional[float] + options: dict[str, Any] | None + stop_sequences: list[str] | None + temperature: float | None + top_p: float | None def __init__( self, - host: Optional[str], + host: str | None, *, - ollama_client_args: Optional[dict[str, Any]] = None, + ollama_client_args: dict[str, Any] | None = None, **model_config: Unpack[OllamaConfig], ) -> None: """Initialize provider instance. @@ -147,7 +148,7 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) -> raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format an Ollama compatible messages array. Args: @@ -167,7 +168,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s ] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> dict[str, Any]: """Format an Ollama chat streaming request. @@ -285,8 +286,8 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -339,8 +340,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 435c82cab..246e32a7d 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Optional, Protocol, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, Protocol, TypedDict, TypeVar, cast import openai from openai.types.chat.parsed_chat_completion import ParsedChatCompletion @@ -53,9 +54,9 @@ class OpenAIConfig(TypedDict, total=False): """ model_id: str - params: Optional[dict[str, Any]] + params: dict[str, Any] | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[OpenAIConfig]) -> None: """Initialize provider instance. 
Args: @@ -203,9 +204,9 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str @classmethod def _format_system_messages( cls, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format system messages for OpenAI-compatible providers. @@ -279,9 +280,9 @@ def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dic def format_request_messages( cls, messages: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, *, - system_prompt_content: Optional[list[SystemContentBlock]] = None, + system_prompt_content: list[SystemContentBlock] | None = None, **kwargs: Any, ) -> list[dict[str, Any]]: """Format an OpenAI compatible messages array. @@ -426,8 +427,8 @@ def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -556,8 +557,8 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 7f8b8ff51..7b633d85b 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -3,8 +3,9 @@ import json import logging import os +from collections.abc import AsyncGenerator from dataclasses import dataclass -from typing import Any, AsyncGenerator, Literal, Optional, Type, TypedDict, TypeVar, Union +from typing import Any, Literal, TypedDict, TypeVar import boto3 from botocore.config import Config as BotocoreConfig @@ -37,7 +38,7 @@ class UsageMetadata: total_tokens: int completion_tokens: int prompt_tokens: int - prompt_tokens_details: Optional[int] = 0 + prompt_tokens_details: int | None = 0 @dataclass @@ -49,8 +50,8 @@ class FunctionCall: arguments: Arguments to pass to the function """ - name: Union[str, dict[Any, Any]] - arguments: Union[str, dict[Any, Any]] + name: str | dict[Any, Any] + arguments: str | dict[Any, Any] def __init__(self, **kwargs: dict[str, str]): """Initialize function call. @@ -108,12 +109,12 @@ class SageMakerAIPayloadSchema(TypedDict, total=False): max_tokens: int stream: bool - temperature: Optional[float] - top_p: Optional[float] - top_k: Optional[int] - stop: Optional[list[str]] - tool_results_as_user_messages: Optional[bool] - additional_args: Optional[dict[str, Any]] + temperature: float | None + top_p: float | None + top_k: int | None + stop: list[str] | None + tool_results_as_user_messages: bool | None + additional_args: dict[str, Any] | None class SageMakerAIEndpointConfig(TypedDict, total=False): """Configuration options for SageMaker models. 
@@ -127,17 +128,17 @@ class SageMakerAIEndpointConfig(TypedDict, total=False): endpoint_name: str region_name: str - inference_component_name: Union[str, None] - target_model: Union[Optional[str], None] - target_variant: Union[Optional[str], None] - additional_args: Optional[dict[str, Any]] + inference_component_name: str | None + target_model: str | None + target_variant: str | None + additional_args: dict[str, Any] | None def __init__( self, endpoint_config: SageMakerAIEndpointConfig, payload_config: SageMakerAIPayloadSchema, - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, ): """Initialize provider instance. @@ -199,8 +200,8 @@ def get_config(self) -> "SageMakerAIModel.SageMakerAIEndpointConfig": # type: i def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -300,8 +301,8 @@ def format_request( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -572,8 +573,8 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index a54fc44c3..f306d649b 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -7,7 +7,8 @@ import json import logging import mimetypes -from typing import Any, AsyncGenerator, Dict, List, Optional, Type, TypedDict, TypeVar, Union, cast +from collections.abc import AsyncGenerator +from typing import Any, TypedDict, TypeVar, cast import writerai from pydantic import BaseModel @@ -41,13 +42,13 @@ class WriterConfig(TypedDict, total=False): """ model_id: str - max_tokens: Optional[int] - stop: Optional[Union[str, List[str]]] - stream_options: Dict[str, Any] - temperature: Optional[float] - top_p: Optional[float] + max_tokens: int | None + stop: str | list[str] | None + stream_options: dict[str, Any] + temperature: float | None + top_p: float | None - def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[WriterConfig]): + def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Unpack[WriterConfig]): """Initialize provider instance. Args: @@ -201,7 +202,7 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any "content": formatted_contents, } - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + def _format_request_messages(self, messages: Messages, system_prompt: str | None = None) -> list[dict[str, Any]]: """Format a Writer compatible messages array.
Args: @@ -245,7 +246,7 @@ def _format_request_messages(self, messages: Messages, system_prompt: Optional[s return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: """Format a streaming request to the underlying model. @@ -353,8 +354,8 @@ def format_chunk(self, event: Any) -> StreamEvent: async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, *, tool_choice: ToolChoice | None = None, **kwargs: Any, @@ -431,8 +432,8 @@ async def stream( @override async def structured_output( - self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any - ) -> AsyncGenerator[dict[str, Union[T, Any]], None]: + self, output_model: type[T], prompt: Messages, system_prompt: str | None = None, **kwargs: Any + ) -> AsyncGenerator[dict[str, T | Any], None]: """Get structured output from the model. Args: diff --git a/src/strands/multiagent/a2a/executor.py b/src/strands/multiagent/a2a/executor.py index 52b6d2ef1..f02b8c6cc 100644 --- a/src/strands/multiagent/a2a/executor.py +++ b/src/strands/multiagent/a2a/executor.py @@ -313,15 +313,13 @@ def _convert_a2a_parts_to_content_blocks(self, parts: list[Part]) -> list[Conten elif uri_data: # For URI files, create a text representation since Strands ContentBlocks expect bytes content_blocks.append( - ContentBlock( - text="[File: %s (%s)] - Referenced file at: %s" % (file_name, mime_type, uri_data) - ) + ContentBlock(text=f"[File: {file_name} ({mime_type})] - Referenced file at: {uri_data}") ) elif isinstance(part_root, DataPart): # Handle DataPart - convert structured data to JSON text try: data_text = json.dumps(part_root.data, indent=2) - content_blocks.append(ContentBlock(text="[Structured Data]\n%s" % data_text)) + content_blocks.append(ContentBlock(text=f"[Structured Data]\n{data_text}")) except Exception: logger.exception("Failed to serialize data part") except Exception: diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f163d05b5..dc3258f68 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -6,9 +6,10 @@ import logging import warnings from abc import ABC, abstractmethod +from collections.abc import AsyncIterator, Mapping from dataclasses import dataclass, field from enum import Enum -from typing import Any, AsyncIterator, Mapping, Union +from typing import Any, Union from .._async import run_async from ..agent import AgentResult @@ -95,7 +96,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": raise TypeError("NodeResult.from_dict: missing 'result'") raw = data["result"] - result: Union[AgentResult, "MultiAgentResult", Exception] + result: AgentResult | MultiAgentResult | Exception if isinstance(raw, dict) and raw.get("type") == "agent_result": result = AgentResult.from_dict(raw) elif isinstance(raw, dict) and raw.get("type") == "exception": diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 6156d332c..19504ad73 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -18,8 +18,9 @@ import copy import logging import time +from collections.abc import 
AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, cast from opentelemetry import trace as trace_api @@ -90,14 +91,14 @@ class GraphState: # Graph structure info total_nodes: int = 0 - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) def should_continue( self, - max_node_executions: Optional[int], - execution_timeout: Optional[float], - ) -> Tuple[bool, str]: + max_node_executions: int | None, + execution_timeout: float | None, + ) -> tuple[bool, str]: """Check if the graph should continue execution. Returns: (should_continue, reason) @@ -123,7 +124,7 @@ class GraphResult(MultiAgentResult): completed_nodes: int = 0 failed_nodes: int = 0 execution_order: list["GraphNode"] = field(default_factory=list) - edges: list[Tuple["GraphNode", "GraphNode"]] = field(default_factory=list) + edges: list[tuple["GraphNode", "GraphNode"]] = field(default_factory=list) entry_points: list["GraphNode"] = field(default_factory=list) @@ -233,13 +234,13 @@ def __init__(self) -> None: self.entry_points: set[GraphNode] = set() # Configuration options - self._max_node_executions: Optional[int] = None - self._execution_timeout: Optional[float] = None - self._node_timeout: Optional[float] = None + self._max_node_executions: int | None = None + self._execution_timeout: float | None = None + self._node_timeout: float | None = None self._reset_on_revisit: bool = False self._id: str = _DEFAULT_GRAPH_ID - self._session_manager: Optional[SessionManager] = None - self._hooks: Optional[list[HookProvider]] = None + self._session_manager: SessionManager | None = None + self._hooks: list[HookProvider] | None = None def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" @@ -408,14 +409,14 @@ def __init__( nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_points: set[GraphNode], - max_node_executions: Optional[int] = None, - execution_timeout: Optional[float] = None, - node_timeout: Optional[float] = None, + max_node_executions: int | None = None, + execution_timeout: float | None = None, + node_timeout: float | None = None, reset_on_revisit: bool = False, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_GRAPH_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: """Initialize Graph with execution limits and reset behavior. 
diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7eec49649..6c1149624 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -18,8 +18,9 @@ import json import logging import time +from collections.abc import AsyncIterator, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, AsyncIterator, Callable, Mapping, Optional, Tuple, cast +from typing import Any, Optional, cast from opentelemetry import trace as trace_api @@ -184,7 +185,7 @@ def should_continue( execution_timeout: float, repetitive_handoff_detection_window: int, repetitive_handoff_min_unique_agents: int, - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if the swarm should continue. Returns: (should_continue, reason) @@ -239,10 +240,10 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, - session_manager: Optional[SessionManager] = None, - hooks: Optional[list[HookProvider]] = None, + session_manager: SessionManager | None = None, + hooks: list[HookProvider] | None = None, id: str = _DEFAULT_SWARM_ID, - trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> None: """Initialize Swarm with agents and configuration. diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index fc80fc520..0b25d4b5d 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,7 +5,7 @@ import os import shutil import tempfile -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, cast from .. import _identifier from ..types.exceptions import SessionException @@ -44,7 +44,7 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): def __init__( self, session_id: str, - storage_dir: Optional[str] = None, + storage_dir: str | None = None, **kwargs: Any, ): """Initialize FileSession with filesystem storage. 
@@ -108,7 +108,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> def _read_file(self, path: str) -> dict[str, Any]: """Read JSON file.""" try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return cast(dict[str, Any], json.load(f)) except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in file {path}: {str(e)}") from e @@ -140,7 +140,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data.""" session_file = os.path.join(self._get_session_path(session_id), "session.json") if not os.path.exists(session_file): @@ -169,7 +169,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A session_data = session_agent.to_dict() self._write_file(agent_file, session_data) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data.""" agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json") if not os.path.exists(agent_file): @@ -199,7 +199,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio session_dict = session_message.to_dict() self._write_file(message_file, session_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data.""" message_path = self._get_message_path(session_id, agent_id, message_id) if not os.path.exists(message_path): @@ -220,7 +220,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_file(message_file, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List messages for an agent with pagination.""" messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages") @@ -269,7 +269,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_file(multi_agent_file, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from filesystem.""" multi_agent_file = os.path.join(self._get_multi_agent_path(session_id, multi_agent_id), "multi_agent.json") if not os.path.exists(multi_agent_file): diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index a8ac099d9..d23c4a94f 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -1,7 +1,7 @@ """Repository session manager implementation.""" import logging -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..agent.state import AgentState from ..tools._tool_helpers import 
generate_missing_tool_result_content @@ -57,7 +57,7 @@ def __init__( self.session = session # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + self._latest_agent_message: dict[str, SessionMessage | None] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 7d081cf09..e5713e5b7 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -47,9 +47,9 @@ def __init__( session_id: str, bucket: str, prefix: str = "", - boto_session: Optional[boto3.Session] = None, - boto_client_config: Optional[BotocoreConfig] = None, - region_name: Optional[str] = None, + boto_session: boto3.Session | None = None, + boto_client_config: BotocoreConfig | None = None, + region_name: str | None = None, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -130,7 +130,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> agent_path = self._get_agent_path(session_id, agent_id) return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json" - def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: + def _read_s3_object(self, key: str) -> dict[str, Any] | None: """Read JSON object from S3.""" try: response = self.client.get_object(Bucket=self.bucket, Key=key) @@ -144,7 +144,7 @@ def _read_s3_object(self, key: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError as e: raise SessionException(f"Invalid JSON in S3 object {key}: {e}") from e - def _write_s3_object(self, key: str, data: Dict[str, Any]) -> None: + def _write_s3_object(self, key: str, data: dict[str, Any]) -> None: """Write JSON object to S3.""" try: content = json.dumps(data, indent=2, ensure_ascii=False) @@ -171,7 +171,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: self._write_s3_object(session_key, session_dict) return session - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read session data from S3.""" session_key = f"{self._get_session_path(session_id)}session.json" session_data = self._read_s3_object(session_key) @@ -209,7 +209,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" self._write_s3_object(agent_key, agent_dict) - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read agent data from S3.""" agent_key = f"{self._get_agent_path(session_id, agent_id)}agent.json" agent_data = self._read_s3_object(agent_key) @@ -236,7 +236,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio message_key = self._get_message_path(session_id, agent_id, message_id) self._write_s3_object(message_key, message_dict) - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, 
session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read message data from S3.""" message_key = self._get_message_path(session_id, agent_id, message_id) message_data = self._read_s3_object(message_key) @@ -257,8 +257,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio self._write_s3_object(message_key, session_message.to_dict()) def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: Any - ) -> List[SessionMessage]: + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any + ) -> list[SessionMessage]: """List messages for an agent with pagination from S3.""" messages_prefix = f"{self._get_agent_path(session_id, agent_id)}messages/" try: @@ -288,7 +288,7 @@ def list_messages( message_keys = message_keys[offset:] # Load only the required message objects - messages: List[SessionMessage] = [] + messages: list[SessionMessage] = [] for key in message_keys: message_data = self._read_s3_object(key) if message_data: @@ -312,7 +312,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k session_data = multi_agent.serialize_state() self._write_s3_object(multi_agent_key, session_data) - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read multi-agent state from S3.""" multi_agent_key = f"{self._get_multi_agent_path(session_id, multi_agent_id)}multi_agent.json" return self._read_s3_object(multi_agent_key) diff --git a/src/strands/session/session_repository.py b/src/strands/session/session_repository.py index 3f5476bdf..0b6f2c705 100644 --- a/src/strands/session/session_repository.py +++ b/src/strands/session/session_repository.py @@ -1,7 +1,7 @@ """Session repository interface for agent session management.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..types.session import Session, SessionAgent, SessionMessage @@ -17,7 +17,7 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new Session.""" @abstractmethod - def read_session(self, session_id: str, **kwargs: Any) -> Optional[Session]: + def read_session(self, session_id: str, **kwargs: Any) -> Session | None: """Read a Session.""" @abstractmethod @@ -25,7 +25,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent, **kwargs: A """Create a new Agent in a Session.""" @abstractmethod - def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> Optional[SessionAgent]: + def read_agent(self, session_id: str, agent_id: str, **kwargs: Any) -> SessionAgent | None: """Read an Agent.""" @abstractmethod @@ -37,7 +37,7 @@ def create_message(self, session_id: str, agent_id: str, session_message: Sessio """Create a new Message for the Agent.""" @abstractmethod - def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> Optional[SessionMessage]: + def read_message(self, session_id: str, agent_id: str, message_id: int, **kwargs: Any) -> SessionMessage | None: """Read a Message.""" @abstractmethod @@ -49,7 +49,7 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio @abstractmethod def list_messages( - self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0, **kwargs: 
Any + self, session_id: str, agent_id: str, limit: int | None = None, offset: int = 0, **kwargs: Any ) -> list[SessionMessage]: """List Messages from an Agent with pagination.""" @@ -57,7 +57,7 @@ def create_multi_agent(self, session_id: str, multi_agent: "MultiAgentBase", **k """Create a new MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") - def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> Optional[dict[str, Any]]: + def read_multi_agent(self, session_id: str, multi_agent_id: str, **kwargs: Any) -> dict[str, Any] | None: """Read the MultiAgent state for the Session.""" raise NotImplementedError("MultiAgent is not implemented for this repository") diff --git a/src/strands/telemetry/metrics.py b/src/strands/telemetry/metrics.py index abfbbffae..7e0e09436 100644 --- a/src/strands/telemetry/metrics.py +++ b/src/strands/telemetry/metrics.py @@ -3,8 +3,9 @@ import logging import time import uuid +from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Optional import opentelemetry.metrics as metrics_api from opentelemetry.metrics import Counter, Histogram, Meter @@ -23,11 +24,11 @@ class Trace: def __init__( self, name: str, - parent_id: Optional[str] = None, - start_time: Optional[float] = None, - raw_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - message: Optional[Message] = None, + parent_id: str | None = None, + start_time: float | None = None, + raw_name: str | None = None, + metadata: dict[str, Any] | None = None, + message: Message | None = None, ) -> None: """Initialize a new trace. @@ -42,15 +43,15 @@ def __init__( """ self.id: str = str(uuid.uuid4()) self.name: str = name - self.raw_name: Optional[str] = raw_name - self.parent_id: Optional[str] = parent_id + self.raw_name: str | None = raw_name + self.parent_id: str | None = parent_id self.start_time: float = start_time if start_time is not None else time.time() - self.end_time: Optional[float] = None - self.children: List["Trace"] = [] - self.metadata: Dict[str, Any] = metadata or {} - self.message: Optional[Message] = message + self.end_time: float | None = None + self.children: list[Trace] = [] + self.metadata: dict[str, Any] = metadata or {} + self.message: Message | None = message - def end(self, end_time: Optional[float] = None) -> None: + def end(self, end_time: float | None = None) -> None: """Mark the trace as complete with the given or current timestamp. Args: @@ -67,7 +68,7 @@ def add_child(self, child: "Trace") -> None: """ self.children.append(child) - def duration(self) -> Optional[float]: + def duration(self) -> float | None: """Calculate the duration of this trace. Returns: @@ -83,7 +84,7 @@ def add_message(self, message: Message) -> None: """ self.message = message - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert the trace to a dictionary representation. Returns: @@ -127,7 +128,7 @@ def add_call( duration: float, success: bool, metrics_client: "MetricsClient", - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] | None = None, ) -> None: """Record a new tool call with its outcome. 
@@ -165,9 +166,9 @@ class EventLoopMetrics: """ cycle_count: int = 0 - tool_metrics: Dict[str, ToolMetrics] = field(default_factory=dict) - cycle_durations: List[float] = field(default_factory=list) - traces: List[Trace] = field(default_factory=list) + tool_metrics: dict[str, ToolMetrics] = field(default_factory=dict) + cycle_durations: list[float] = field(default_factory=list) + traces: list[Trace] = field(default_factory=list) accumulated_usage: Usage = field(default_factory=lambda: Usage(inputTokens=0, outputTokens=0, totalTokens=0)) accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) @@ -178,8 +179,8 @@ def _metrics_client(self) -> "MetricsClient": def start_cycle( self, - attributes: Optional[Dict[str, Any]] = None, - ) -> Tuple[float, Trace]: + attributes: dict[str, Any] | None = None, + ) -> tuple[float, Trace]: """Start a new event loop cycle and create a trace for it. Args: @@ -196,7 +197,7 @@ def start_cycle( self.traces.append(cycle_trace) return start_time, cycle_trace - def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: Optional[Dict[str, Any]] = None) -> None: + def end_cycle(self, start_time: float, cycle_trace: Trace, attributes: dict[str, Any] | None = None) -> None: """End the current event loop cycle and record its duration. Args: @@ -290,7 +291,7 @@ def update_metrics(self, metrics: Metrics) -> None: self._metrics_client.model_time_to_first_token.record(metrics["timeToFirstByteMs"]) self.accumulated_metrics["latencyMs"] += metrics["latencyMs"] - def get_summary(self) -> Dict[str, Any]: + def get_summary(self) -> dict[str, Any]: """Generate a comprehensive summary of all collected metrics. Returns: @@ -326,7 +327,7 @@ def get_summary(self) -> Dict[str, Any]: return summary -def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: Set[str]) -> Iterable[str]: +def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_names: set[str]) -> Iterable[str]: """Convert event loop metrics to a series of formatted text lines. Args: @@ -387,7 +388,7 @@ def _metrics_summary_to_lines(event_loop_metrics: EventLoopMetrics, allowed_name yield from _trace_to_lines(trace.to_dict(), allowed_names=allowed_names, indent=1) -def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterable[str]: +def _trace_to_lines(trace: dict, allowed_names: set[str], indent: int) -> Iterable[str]: """Convert a trace to a series of formatted text lines. Args: @@ -419,7 +420,7 @@ def _trace_to_lines(trace: Dict, allowed_names: Set[str], indent: int) -> Iterab yield from _trace_to_lines(child, allowed_names, indent + 1) -def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: Optional[Set[str]] = None) -> str: +def metrics_to_string(event_loop_metrics: EventLoopMetrics, allowed_names: set[str] | None = None) -> str: """Convert event loop metrics to a human-readable string representation. 
Args: diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 2f42d9988..68f501b97 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,8 +7,9 @@ import json import logging import os +from collections.abc import Mapping from datetime import date, datetime, timezone -from typing import Any, Dict, Mapping, Optional, cast +from typing import Any, cast import opentelemetry.trace as trace_api from opentelemetry.instrumentation.threading import ThreadingInstrumentor @@ -89,7 +90,7 @@ class Tracer: def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ - self.tracer_provider: Optional[trace_api.TracerProvider] = None + self.tracer_provider: trace_api.TracerProvider | None = None self.tracer_provider = trace_api.get_tracer_provider() self.tracer = self.tracer_provider.get_tracer(self.service_name) ThreadingInstrumentor().instrument() @@ -112,8 +113,8 @@ def _parse_semconv_opt_in(self) -> set[str]: def _start_span( self, span_name: str, - parent_span: Optional[Span] = None, - attributes: Optional[Dict[str, AttributeValue]] = None, + parent_span: Span | None = None, + attributes: dict[str, AttributeValue] | None = None, span_kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, ) -> Span: """Generic helper method to start a span with common attributes. @@ -145,7 +146,7 @@ def _start_span( return span - def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> None: + def _set_attributes(self, span: Span, attributes: dict[str, AttributeValue]) -> None: """Set attributes on a span, handling different value types appropriately. Args: @@ -159,7 +160,7 @@ def _set_attributes(self, span: Span, attributes: Dict[str, AttributeValue]) -> span.set_attribute(key, value) def _add_optional_usage_and_metrics_attributes( - self, attributes: Dict[str, AttributeValue], usage: Usage, metrics: Metrics + self, attributes: dict[str, AttributeValue], usage: Usage, metrics: Metrics ) -> None: """Add optional usage and metrics attributes if they have values. @@ -183,8 +184,8 @@ def _add_optional_usage_and_metrics_attributes( def _end_span( self, span: Span, - attributes: Optional[Dict[str, AttributeValue]] = None, - error: Optional[Exception] = None, + attributes: dict[str, AttributeValue] | None = None, + error: Exception | None = None, ) -> None: """Generic helper method to end a span. @@ -221,7 +222,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Optional[Exception] = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: """End a span with error status. Args: @@ -235,7 +236,7 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Optiona error = exception or Exception(error_message) self._end_span(span, error=error) - def _add_event(self, span: Optional[Span], event_name: str, event_attributes: Attributes) -> None: + def _add_event(self, span: Span | None, event_name: str, event_attributes: Attributes) -> None: """Add an event with attributes to a span. 
Args: @@ -275,9 +276,9 @@ def _get_event_name_for_message(self, message: Message) -> str: def start_model_invoke_span( self, messages: Messages, - parent_span: Optional[Span] = None, - model_id: Optional[str] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + model_id: str | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a model invocation. @@ -292,7 +293,7 @@ def start_model_invoke_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="chat") if custom_trace_attributes: attributes.update(custom_trace_attributes) @@ -315,7 +316,7 @@ def end_model_invoke_span( usage: Usage, metrics: Metrics, stop_reason: StopReason, - error: Optional[Exception] = None, + error: Exception | None = None, ) -> None: """End a model invocation span with results and metrics. @@ -327,7 +328,7 @@ def end_model_invoke_span( stop_reason (StopReason): The reason the model stopped generating. error: Optional exception if the model call failed. """ - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.input_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], @@ -366,8 +367,8 @@ def end_model_invoke_span( def start_tool_call_span( self, tool: ToolUse, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, ) -> Span: """Start a new span for a tool call. @@ -381,7 +382,7 @@ def start_tool_call_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="execute_tool") attributes.update( { "gen_ai.tool.name": tool["name"], @@ -432,9 +433,7 @@ def start_tool_call_span( return span - def end_tool_call_span( - self, span: Span, tool_result: Optional[ToolResult], error: Optional[Exception] = None - ) -> None: + def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: """End a tool call span with results. Args: @@ -442,7 +441,7 @@ def end_tool_call_span( tool_result: The result from the tool execution. error: Optional exception if the tool call failed. """ - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if tool_result is not None: status = tool_result.get("status") status_str = str(status) if status is not None else "" @@ -490,10 +489,10 @@ def start_event_loop_cycle_span( self, invocation_state: Any, messages: Messages, - parent_span: Optional[Span] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + parent_span: Span | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, **kwargs: Any, - ) -> Optional[Span]: + ) -> Span | None: """Start a new span for an event loop cycle. 
Args: @@ -509,7 +508,7 @@ def start_event_loop_cycle_span( event_loop_cycle_id = str(invocation_state.get("event_loop_cycle_id")) parent_span = parent_span if parent_span else invocation_state.get("event_loop_parent_span") - attributes: Dict[str, AttributeValue] = { + attributes: dict[str, AttributeValue] = { "event_loop.cycle_id": event_loop_cycle_id, } @@ -532,8 +531,8 @@ def end_event_loop_cycle_span( self, span: Span, message: Message, - tool_result_message: Optional[Message] = None, - error: Optional[Exception] = None, + tool_result_message: Message | None = None, + error: Exception | None = None, ) -> None: """End an event loop cycle span with results. @@ -543,8 +542,8 @@ def end_event_loop_cycle_span( tool_result_message: Optional tool result message if a tool was called. error: Optional exception if the cycle failed. """ - attributes: Dict[str, AttributeValue] = {} - event_attributes: Dict[str, AttributeValue] = {"message": serialize(message["content"])} + attributes: dict[str, AttributeValue] = {} + event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])} if tool_result_message: event_attributes["tool.result"] = serialize(tool_result_message["content"]) @@ -572,10 +571,10 @@ def start_agent_span( self, messages: Messages, agent_name: str, - model_id: Optional[str] = None, - tools: Optional[list] = None, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, - tools_config: Optional[dict] = None, + model_id: str | None = None, + tools: list | None = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, + tools_config: dict | None = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -592,7 +591,7 @@ def start_agent_span( Returns: The created span, or None if tracing is not enabled. """ - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation_name="invoke_agent") attributes.update( { "gen_ai.agent.name": agent_name, @@ -630,8 +629,8 @@ def start_agent_span( def end_agent_span( self, span: Span, - response: Optional[AgentResult] = None, - error: Optional[Exception] = None, + response: AgentResult | None = None, + error: Exception | None = None, ) -> None: """End an agent span with results and metrics. @@ -640,7 +639,7 @@ def end_agent_span( response: The response from the agent. error: Any error that occurred. 
""" - attributes: Dict[str, AttributeValue] = {} + attributes: dict[str, AttributeValue] = {} if response: if self.use_latest_genai_conventions: @@ -698,11 +697,11 @@ def start_multiagent_span( self, task: MultiAgentInput, instance: str, - custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + custom_trace_attributes: Mapping[str, AttributeValue] | None = None, ) -> Span: """Start a new span for swarm invocation.""" operation = f"invoke_{instance}" - attributes: Dict[str, AttributeValue] = self._get_common_attributes(operation) + attributes: dict[str, AttributeValue] = self._get_common_attributes(operation) attributes.update( { "gen_ai.agent.name": instance, @@ -737,7 +736,7 @@ def start_multiagent_span( def end_swarm_span( self, span: Span, - result: Optional[str] = None, + result: str | None = None, ) -> None: """End a swarm span with results.""" if result: @@ -766,7 +765,7 @@ def end_swarm_span( def _get_common_attributes( self, operation_name: str, - ) -> Dict[str, AttributeValue]: + ) -> dict[str, AttributeValue]: """Returns a dictionary of common attributes based on the convention version used. Args: diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 1e0ca2c8d..c104aa633 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -9,7 +9,8 @@ import json import random -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from .._async import run_async from ..tools.executors._executor import ToolExecutor diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 8dc933f51..f64c17ee9 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -44,16 +44,13 @@ def my_tool(param1: str, param2: int = 42) -> dict: import functools import inspect import logging +from collections.abc import Callable from typing import ( Annotated, Any, - Callable, Generic, - Optional, ParamSpec, - Type, TypeVar, - Union, cast, get_args, get_origin, @@ -183,7 +180,7 @@ def _validate_signature(self) -> None: # Found the parameter, no need to check further break - def _create_input_model(self) -> Type[BaseModel]: + def _create_input_model(self) -> type[BaseModel]: """Create a Pydantic model from function signature for input validation. This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can @@ -463,7 +460,7 @@ def __init__( functools.update_wrapper(wrapper=self, wrapped=self._tool_func) - def __get__(self, instance: Any, obj_type: Optional[Type] = None) -> "DecoratedFunctionTool[P, R]": + def __get__(self, instance: Any, obj_type: type | None = None) -> "DecoratedFunctionTool[P, R]": """Descriptor protocol implementation for proper method binding. This method enables the decorated function to work correctly when used as a class method. @@ -666,20 +663,20 @@ def tool(__func: Callable[P, R]) -> DecoratedFunctionTool[P, R]: ... # Handle @decorator() @overload def tool( - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... 
# Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system def tool( # type: ignore - func: Optional[Callable[P, R]] = None, - description: Optional[str] = None, - inputSchema: Optional[JSONSchema] = None, - name: Optional[str] = None, + func: Callable[P, R] | None = None, + description: str | None = None, + inputSchema: JSONSchema | None = None, + name: str | None = None, context: bool | str = False, -) -> Union[DecoratedFunctionTool[P, R], Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]]: +) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. This decorator seamlessly enables a function to be called both as a regular Python function and as a Strands tool. diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index 5d01c5d48..6d58c5c75 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -7,7 +7,8 @@ import abc import logging import time -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any, cast from opentelemetry import trace as trace_api diff --git a/src/strands/tools/executors/concurrent.py b/src/strands/tools/executors/concurrent.py index 216eee379..7fa34eff0 100644 --- a/src/strands/tools/executors/concurrent.py +++ b/src/strands/tools/executors/concurrent.py @@ -1,7 +1,8 @@ """Concurrent tool executor implementation.""" import asyncio -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/executors/sequential.py b/src/strands/tools/executors/sequential.py index f78e60872..dc5b9a5d9 100644 --- a/src/strands/tools/executors/sequential.py +++ b/src/strands/tools/executors/sequential.py @@ -1,6 +1,7 @@ """Sequential tool executor implementation.""" -from typing import TYPE_CHECKING, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Any from typing_extensions import override diff --git a/src/strands/tools/loader.py b/src/strands/tools/loader.py index 6f745b728..2115cdee8 100644 --- a/src/strands/tools/loader.py +++ b/src/strands/tools/loader.py @@ -9,7 +9,7 @@ from pathlib import Path from posixpath import expanduser from types import ModuleType -from typing import List, cast +from typing import cast from ..types.tools import AgentTool from .decorator import DecoratedFunctionTool @@ -20,7 +20,7 @@ _TOOL_MODULE_PREFIX = "_strands_tool_" -def load_tool_from_string(tool_string: str) -> List[AgentTool]: +def load_tool_from_string(tool_string: str) -> list[AgentTool]: """Load tools follows strands supported input string formats. This function can load a tool based on a string in the following ways: @@ -42,7 +42,7 @@ def load_tool_from_string(tool_string: str) -> List[AgentTool]: return load_tools_from_module_path(tool_string) -def load_tools_from_file_path(tool_path: str) -> List[AgentTool]: +def load_tools_from_file_path(tool_path: str) -> list[AgentTool]: """Load module from specified path, and then load tools from that module. 
This function attempts to load the passed in path as a python module, and if it succeeds, @@ -116,7 +116,7 @@ def load_tools_from_module(module: ModuleType, module_name: str) -> list[AgentTo # Try and see if any of the attributes in the module are function-based tools decorated with @tool # This means that there may be more than one tool available in this module, so we load them all - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] # Function tools will appear as attributes in the module for attr_name in dir(module): attr = getattr(module, attr_name) @@ -153,7 +153,7 @@ class ToolLoader: """Handles loading of tools from different sources.""" @staticmethod - def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: + def load_python_tools(tool_path: str, tool_name: str) -> list[AgentTool]: """DEPRECATED: Load a Python tool module and return all discovered function-based tools as a list. This method always returns a list of AgentTool (possibly length 1). It is the @@ -206,7 +206,7 @@ def load_python_tools(tool_path: str, tool_name: str) -> List[AgentTool]: spec.loader.exec_module(module) # Collect function-based tools decorated with @tool - function_tools: List[AgentTool] = [] + function_tools: list[AgentTool] = [] for attr_name in dir(module): attr = getattr(module, attr_name) if isinstance(attr, DecoratedFunctionTool): diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index bb5dca19c..a80dd648c 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -13,10 +13,12 @@ import threading import uuid from asyncio import AbstractEventLoop +from collections.abc import Callable, Coroutine, Sequence from concurrent import futures from datetime import timedelta +from re import Pattern from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, Optional, Pattern, Sequence, TypeVar, Union, cast +from typing import Any, TypeVar, cast import anyio from mcp import ClientSession, ListToolsResult @@ -61,7 +63,7 @@ class ToolFilters(TypedDict, total=False): rejected: list[_ToolMatcher] -MIME_TO_FORMAT: Dict[str, ImageFormat] = { +MIME_TO_FORMAT: dict[str, ImageFormat] = { "image/jpeg": "jpeg", "image/jpg": "jpeg", "image/png": "png", @@ -107,7 +109,7 @@ def __init__( startup_timeout: int = 30, tool_filters: ToolFilters | None = None, prefix: str | None = None, - elicitation_callback: Optional[ElicitationFnT] = None, + elicitation_callback: ElicitationFnT | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -286,9 +288,7 @@ def remove_consumer(self, consumer_id: Any, **kwargs: Any) -> None: # MCP-specific methods - def stop( - self, exc_type: Optional[BaseException], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] - ) -> None: + def stop(self, exc_type: BaseException | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: """Signals the background thread to stop and waits for it to complete, ensuring proper cleanup of all resources. 
This method is defensive and can handle partial initialization states that may occur @@ -398,7 +398,7 @@ async def _list_tools_async() -> ListToolsResult: self._log_debug_with_thread("successfully adapted %d MCP tools", len(mcp_tools)) return PaginatedList[MCPAgentTool](mcp_tools, token=list_tools_response.nextCursor) - def list_prompts_sync(self, pagination_token: Optional[str] = None) -> ListPromptsResult: + def list_prompts_sync(self, pagination_token: str | None = None) -> ListPromptsResult: """Synchronously retrieves the list of available prompts from the MCP server. This method calls the asynchronous list_prompts method on the MCP session @@ -644,7 +644,7 @@ def _background_task(self) -> None: def _map_mcp_content_to_tool_result_content( self, content: MCPTextContent | MCPImageContent | MCPEmbeddedResource | Any, - ) -> Union[ToolResultContent, None]: + ) -> ToolResultContent | None: """Maps MCP content types to tool result content types. This method converts MCP-specific content types to the generic @@ -764,7 +764,7 @@ def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" return self._should_include_tool_with_filters(tool, self._tool_filters) - def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: Optional[ToolFilters]) -> bool: + def _should_include_tool_with_filters(self, tool: MCPAgentTool, filters: ToolFilters | None) -> bool: """Check if a tool should be included based on provided filters.""" if not filters: return True diff --git a/src/strands/tools/mcp/mcp_instrumentation.py b/src/strands/tools/mcp/mcp_instrumentation.py index f8ab3bc80..d1750daa3 100644 --- a/src/strands/tools/mcp/mcp_instrumentation.py +++ b/src/strands/tools/mcp/mcp_instrumentation.py @@ -9,9 +9,10 @@ Related issue: https://github.com/modelcontextprotocol/modelcontextprotocol/issues/246 """ +from collections.abc import AsyncGenerator, Callable from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Tuple +from typing import Any from mcp.shared.message import SessionMessage from mcp.types import JSONRPCMessage, JSONRPCRequest @@ -129,7 +130,7 @@ def transport_wrapper() -> Callable[ @asynccontextmanager async def traced_method( wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any - ) -> AsyncGenerator[Tuple[Any, Any], None]: + ) -> AsyncGenerator[tuple[Any, Any], None]: async with wrapped(*args, **kwargs) as result: try: read_stream, write_stream = result @@ -139,7 +140,7 @@ async def traced_method( return traced_method - def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any]], None]: + def session_init_wrapper() -> Callable[[Any, Any, tuple[Any, ...], dict[str, Any]], None]: """Create a wrapper for MCP session initialization. Wraps session message streams to enable bidirectional context flow. 
@@ -151,7 +152,7 @@ def session_init_wrapper() -> Callable[[Any, Any, Tuple[Any, ...], dict[str, Any """ def traced_method( - wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: dict[str, Any] + wrapped: Callable[..., Any], instance: Any, args: tuple[Any, ...], kwargs: dict[str, Any] ) -> None: wrapped(*args, **kwargs) reader = getattr(instance, "_incoming_message_stream_reader", None) diff --git a/src/strands/tools/mcp/mcp_types.py b/src/strands/tools/mcp/mcp_types.py index 66eda08ae..bc926c51a 100644 --- a/src/strands/tools/mcp/mcp_types.py +++ b/src/strands/tools/mcp/mcp_types.py @@ -1,7 +1,7 @@ """Type definitions for MCP integration.""" from contextlib import AbstractAsyncContextManager -from typing import Any, Dict +from typing import Any from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.client.streamable_http import GetSessionIdCallback @@ -60,4 +60,4 @@ class MCPToolResult(ToolResult): that can be processed programmatically by agents or other tools. """ - structuredContent: NotRequired[Dict[str, Any]] + structuredContent: NotRequired[dict[str, Any]] diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 15150847d..c9be386a6 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -10,12 +10,13 @@ import sys import uuid import warnings +from collections.abc import Iterable, Sequence from importlib import import_module, util from os.path import expanduser from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, cast -from typing_extensions import TypedDict, cast +from typing_extensions import TypedDict from .._async import run_async from ..experimental.tools import ToolProvider @@ -35,13 +36,13 @@ class ToolRegistry: def __init__(self) -> None: """Initialize the tool registry.""" - self.registry: Dict[str, AgentTool] = {} - self.dynamic_tools: Dict[str, AgentTool] = {} - self.tool_config: Optional[Dict[str, Any]] = None - self._tool_providers: List[ToolProvider] = [] + self.registry: dict[str, AgentTool] = {} + self.dynamic_tools: dict[str, AgentTool] = {} + self.tool_config: dict[str, Any] | None = None + self._tool_providers: list[ToolProvider] = [] self._registry_id = str(uuid.uuid4()) - def process_tools(self, tools: List[Any]) -> List[str]: + def process_tools(self, tools: list[Any]) -> list[str]: """Process tools list. Process list of tools that can contain local file path string, module import path string, @@ -186,7 +187,7 @@ def load_tool_from_filepath(self, tool_name: str, tool_path: str) -> None: logger.exception("tool_name=<%s> | failed to load tool", tool_name) raise ValueError(f"Failed to load tool {tool_name}: {exception_str}") from e - def get_all_tools_config(self) -> Dict[str, Any]: + def get_all_tools_config(self) -> dict[str, Any]: """Dynamically generate tool configuration by combining built-in and dynamic tools. Returns: @@ -279,7 +280,7 @@ def register_tool(self, tool: AgentTool) -> None: list(self.dynamic_tools.keys()), ) - def get_tools_dirs(self) -> List[Path]: + def get_tools_dirs(self) -> list[Path]: """Get all tool directory paths. Returns: @@ -299,7 +300,7 @@ def get_tools_dirs(self) -> List[Path]: return tool_dirs - def discover_tool_modules(self) -> Dict[str, Path]: + def discover_tool_modules(self) -> dict[str, Path]: """Discover available tool modules in all tools directories. Returns: @@ -542,7 +543,7 @@ def get_all_tool_specs(self) -> list[ToolSpec]: A list of ToolSpecs. 
""" all_tools = self.get_all_tools_config() - tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] + tools: list[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] return tools def register_dynamic_tool(self, tool: AgentTool) -> None: @@ -619,7 +620,7 @@ class NewToolDict(TypedDict): spec: ToolSpec - def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict) -> None: + def _update_tool_config(self, tool_config: dict[str, Any], new_tool: NewToolDict) -> None: """Update tool configuration with a new tool. Args: @@ -656,7 +657,7 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict tool_config["tools"].append(new_tool_entry) logger.debug("tool_name=<%s> | added new tool", new_tool_name) - def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: + def _scan_module_for_tools(self, module: Any) -> list[AgentTool]: """Scan a module for function-based tools. Args: @@ -665,7 +666,7 @@ def _scan_module_for_tools(self, module: Any) -> List[AgentTool]: Returns: List of FunctionTool instances found in the module. """ - tools: List[AgentTool] = [] + tools: list[AgentTool] = [] for name, obj in inspect.getmembers(module): if isinstance(obj, DecoratedFunctionTool): diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py index f33a06915..2f8dd8ca0 100644 --- a/src/strands/tools/structured_output/_structured_output_context.py +++ b/src/strands/tools/structured_output/_structured_output_context.py @@ -1,7 +1,7 @@ """Context management for structured output in the event loop.""" import logging -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -17,20 +17,20 @@ class StructuredOutputContext: """Per-invocation context for structured output execution.""" - def __init__(self, structured_output_model: Type[BaseModel] | None = None): + def __init__(self, structured_output_model: type[BaseModel] | None = None): """Initialize a new structured output context. Args: structured_output_model: Optional Pydantic model type for structured output. """ self.results: dict[str, BaseModel] = {} - self.structured_output_model: Type[BaseModel] | None = structured_output_model + self.structured_output_model: type[BaseModel] | None = structured_output_model self.structured_output_tool: StructuredOutputTool | None = None self.forced_mode: bool = False self.force_attempted: bool = False self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False - self.expected_tool_name: Optional[str] = None + self.expected_tool_name: str | None = None if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) @@ -91,7 +91,7 @@ def has_structured_output_tool(self, tool_uses: list[ToolUse]) -> bool: return False return any(tool_use.get("name") == self.expected_tool_name for tool_use in tool_uses) - def get_tool_spec(self) -> Optional[ToolSpec]: + def get_tool_spec(self) -> ToolSpec | None: """Get the tool specification for structured output. 
Returns: diff --git a/src/strands/tools/structured_output/structured_output_tool.py b/src/strands/tools/structured_output/structured_output_tool.py index 25173d048..fa20f526c 100644 --- a/src/strands/tools/structured_output/structured_output_tool.py +++ b/src/strands/tools/structured_output/structured_output_tool.py @@ -6,7 +6,7 @@ import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ValidationError from typing_extensions import override @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -_TOOL_SPEC_CACHE: dict[Type[BaseModel], ToolSpec] = {} +_TOOL_SPEC_CACHE: dict[type[BaseModel], ToolSpec] = {} if TYPE_CHECKING: from ._structured_output_context import StructuredOutputContext @@ -26,7 +26,7 @@ class StructuredOutputTool(AgentTool): """Tool implementation for structured output validation.""" - def __init__(self, structured_output_model: Type[BaseModel]) -> None: + def __init__(self, structured_output_model: type[BaseModel]) -> None: """Initialize a structured output tool. Args: @@ -43,7 +43,7 @@ def __init__(self, structured_output_model: Type[BaseModel]) -> None: self._tool_name = self._tool_spec.get("name", "StructuredOutputTool") @classmethod - def _get_tool_spec(cls, structured_output_model: Type[BaseModel]) -> ToolSpec: + def _get_tool_spec(cls, structured_output_model: type[BaseModel]) -> ToolSpec: """Get a cached tool spec for the given output type. Args: @@ -84,7 +84,7 @@ def tool_type(self) -> str: return "structured_output" @property - def structured_output_model(self) -> Type[BaseModel]: + def structured_output_model(self) -> type[BaseModel]: """Get the Pydantic model type for this tool. Returns: diff --git a/src/strands/tools/structured_output/structured_output_utils.py b/src/strands/tools/structured_output/structured_output_utils.py index 093d67f7c..a78ec6195 100644 --- a/src/strands/tools/structured_output/structured_output_utils.py +++ b/src/strands/tools/structured_output/structured_output_utils.py @@ -1,13 +1,13 @@ """Tools for converting Pydantic models to Bedrock tools.""" -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Union from pydantic import BaseModel from ...types.tools import ToolSpec -def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: +def _flatten_schema(schema: dict[str, Any]) -> dict[str, Any]: """Flattens a JSON schema by removing $defs and resolving $ref references. Handles required vs optional fields properly. @@ -80,11 +80,11 @@ def _flatten_schema(schema: Dict[str, Any]) -> Dict[str, Any]: def _process_property( - prop: Dict[str, Any], - defs: Dict[str, Any], + prop: dict[str, Any], + defs: dict[str, Any], is_required: bool = False, fully_expand: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Process a property in a schema, resolving any references. Args: @@ -174,8 +174,8 @@ def _process_property( def _process_schema_object( - schema_obj: Dict[str, Any], defs: Dict[str, Any], fully_expand: bool = True -) -> Dict[str, Any]: + schema_obj: dict[str, Any], defs: dict[str, Any], fully_expand: bool = True +) -> dict[str, Any]: """Process a schema object, typically from $defs, to resolve all nested properties. 
Args: @@ -218,7 +218,7 @@ def _process_schema_object( return result -def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, Any]: +def _process_nested_dict(d: dict[str, Any], defs: dict[str, Any]) -> dict[str, Any]: """Recursively processes nested dictionaries and resolves $ref references. Args: @@ -228,7 +228,7 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A Returns: Processed dictionary """ - result: Dict[str, Any] = {} + result: dict[str, Any] = {} # Handle direct reference if "$ref" in d: @@ -258,8 +258,8 @@ def _process_nested_dict(d: Dict[str, Any], defs: Dict[str, Any]) -> Dict[str, A def convert_pydantic_to_tool_spec( - model: Type[BaseModel], - description: Optional[str] = None, + model: type[BaseModel], + description: str | None = None, ) -> ToolSpec: """Converts a Pydantic model to a tool description for the Amazon Bedrock Converse API. @@ -302,7 +302,7 @@ def convert_pydantic_to_tool_spec( ) -def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _expand_nested_properties(schema: dict[str, Any], model: type[BaseModel]) -> None: """Expand the properties of nested models in the schema to include their full structure. This updates the schema in place. @@ -348,7 +348,7 @@ def _expand_nested_properties(schema: Dict[str, Any], model: Type[BaseModel]) -> schema["properties"][prop_name] = expanded_object -def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_referenced_models(schema: dict[str, Any], model: type[BaseModel]) -> None: """Process referenced models to ensure their docstrings are included. This updates the schema in place. @@ -388,7 +388,7 @@ def _process_referenced_models(schema: Dict[str, Any], model: Type[BaseModel]) - _process_properties(ref_def, field_type) -def _process_properties(schema_def: Dict[str, Any], model: Type[BaseModel]) -> None: +def _process_properties(schema_def: dict[str, Any], model: type[BaseModel]) -> None: """Process properties in a schema definition to add descriptions from field metadata. Args: diff --git a/src/strands/tools/watcher.py b/src/strands/tools/watcher.py index 44f2ed512..c7f50fccd 100644 --- a/src/strands/tools/watcher.py +++ b/src/strands/tools/watcher.py @@ -6,7 +6,7 @@ import logging from pathlib import Path -from typing import Any, Dict, Set +from typing import Any from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -25,9 +25,9 @@ class ToolWatcher: # design pattern avoids conflicts when multiple tool registries are watching the same directories. _shared_observer = None - _watched_dirs: Set[str] = set() + _watched_dirs: set[str] = set() _observer_started = False - _registry_handlers: Dict[str, Dict[int, "ToolWatcher.ToolChangeHandler"]] = {} + _registry_handlers: dict[str, dict[int, "ToolWatcher.ToolChangeHandler"]] = {} def __init__(self, tool_registry: ToolRegistry) -> None: """Initialize a tool watcher for the given tool registry. diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index c3890f428..a1f45229b 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -5,7 +5,8 @@ agent lifecycle. 
""" -from typing import TYPE_CHECKING, Any, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, cast from pydantic import BaseModel from typing_extensions import override diff --git a/src/strands/types/citations.py b/src/strands/types/citations.py index b0e28f655..23959f972 100644 --- a/src/strands/types/citations.py +++ b/src/strands/types/citations.py @@ -3,8 +3,6 @@ These types are modeled after the Bedrock API. """ -from typing import List, Union - from typing_extensions import TypedDict @@ -78,7 +76,7 @@ class DocumentPageLocation(TypedDict, total=False): # Union type for citation locations -CitationLocation = Union[DocumentCharLocation, DocumentChunkLocation, DocumentPageLocation] +CitationLocation = DocumentCharLocation | DocumentChunkLocation | DocumentPageLocation class CitationSourceContent(TypedDict, total=False): @@ -130,7 +128,7 @@ class Citation(TypedDict, total=False): """ location: CitationLocation - sourceContent: List[CitationSourceContent] + sourceContent: list[CitationSourceContent] title: str @@ -148,5 +146,5 @@ class CitationsContentBlock(TypedDict, total=False): citations. """ - citations: List[Citation] - content: List[CitationGeneratedContent] + citations: list[Citation] + content: list[CitationGeneratedContent] diff --git a/src/strands/types/collections.py b/src/strands/types/collections.py index df857ace0..28b4a1891 100644 --- a/src/strands/types/collections.py +++ b/src/strands/types/collections.py @@ -1,6 +1,6 @@ """Generic collection types for the Strands SDK.""" -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -12,7 +12,7 @@ class PaginatedList(list, Generic[T]): so existing code that expects List[T] will continue to work. """ - def __init__(self, data: List[T], token: Optional[str] = None): + def __init__(self, data: list[T], token: str | None = None): """Initialize a PaginatedList with data and an optional pagination token. Args: diff --git a/src/strands/types/content.py b/src/strands/types/content.py index 4d0bbe412..d75dbb87f 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -6,7 +6,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -23,7 +23,7 @@ class GuardContentText(TypedDict): text: The input text details to be evaluated by the guardrail. """ - qualifiers: List[Literal["grounding_source", "query", "guard_content"]] + qualifiers: list[Literal["grounding_source", "query", "guard_content"]] text: str @@ -45,7 +45,7 @@ class ReasoningTextBlock(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - signature: Optional[str] + signature: str | None text: str @@ -120,7 +120,7 @@ class DeltaContent(TypedDict, total=False): """ text: str - toolUse: Dict[Literal["input"], str] + toolUse: dict[Literal["input"], str] class ContentBlockStartToolUse(TypedDict): @@ -142,7 +142,7 @@ class ContentBlockStart(TypedDict, total=False): toolUse: Information about a tool that the model is requesting to use. """ - toolUse: Optional[ContentBlockStartToolUse] + toolUse: ContentBlockStartToolUse | None class ContentBlockDelta(TypedDict): @@ -183,9 +183,9 @@ class Message(TypedDict): role: The role of the message sender. 
""" - content: List[ContentBlock] + content: list[ContentBlock] role: Role -Messages = List[Message] +Messages = list[Message] """A list of messages representing a conversation.""" diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index c15ba1bea..70a7aedd5 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Dict, List, Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -22,7 +22,7 @@ class GuardrailConfig(TypedDict, total=False): guardrailIdentifier: str guardrailVersion: str - streamProcessingMode: Optional[Literal["sync", "async"]] + streamProcessingMode: Literal["sync", "async"] | None trace: Literal["enabled", "disabled"] @@ -47,7 +47,7 @@ class TopicPolicy(TypedDict): topics: The topics in the assessment. """ - topics: List[Topic] + topics: list[Topic] class ContentFilter(TypedDict): @@ -71,7 +71,7 @@ class ContentPolicy(TypedDict): filters: List of content filters to apply. """ - filters: List[ContentFilter] + filters: list[ContentFilter] class CustomWord(TypedDict): @@ -108,8 +108,8 @@ class WordPolicy(TypedDict): managedWordLists: List of managed word lists to filter. """ - customWords: List[CustomWord] - managedWordLists: List[ManagedWord] + customWords: list[CustomWord] + managedWordLists: list[ManagedWord] class PIIEntity(TypedDict): @@ -182,8 +182,8 @@ class SensitiveInformationPolicy(TypedDict): regexes: The regex queries in the assessment. """ - piiEntities: List[PIIEntity] - regexes: List[Regex] + piiEntities: list[PIIEntity] + regexes: list[Regex] class ContextualGroundingFilter(TypedDict): @@ -209,7 +209,7 @@ class ContextualGroundingPolicy(TypedDict): filters: The filter details for the guardrails contextual grounding filter. """ - filters: List[ContextualGroundingFilter] + filters: list[ContextualGroundingFilter] class GuardrailAssessment(TypedDict): @@ -239,9 +239,9 @@ class GuardrailTrace(TypedDict): outputAssessments: Assessments of output content against guardrail policies, keyed by output identifier. 
""" - inputAssessment: Dict[str, GuardrailAssessment] - modelOutput: List[str] - outputAssessments: Dict[str, List[GuardrailAssessment]] + inputAssessment: dict[str, GuardrailAssessment] + modelOutput: list[str] + outputAssessments: dict[str, list[GuardrailAssessment]] class Trace(TypedDict): diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 69cd60cf3..462d8af34 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -5,7 +5,7 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Literal, Optional +from typing import Literal from typing_extensions import TypedDict @@ -37,8 +37,8 @@ class DocumentContent(TypedDict, total=False): format: Literal["pdf", "csv", "doc", "docx", "xls", "xlsx", "html", "txt", "md"] name: str source: DocumentSource - citations: Optional[CitationsConfig] - context: Optional[str] + citations: CitationsConfig | None + context: str | None ImageFormat = Literal["png", "jpeg", "gif", "webp"] diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 5da3dcde8..29453f4b7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -5,7 +5,7 @@ from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from ..interrupt import _InterruptState from .content import Message @@ -69,7 +69,7 @@ class SessionMessage: message: Message message_id: int - redact_message: Optional[Message] = None + redact_message: Message | None = None created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index dcfd541a8..8ec2e8d7b 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -5,8 +5,6 @@ - Bedrock docs: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Types_Amazon_Bedrock_Runtime.html """ -from typing import Optional, Union - from typing_extensions import TypedDict from .citations import CitationLocation @@ -34,7 +32,7 @@ class ContentBlockStartEvent(TypedDict, total=False): start: Information about the content block being started. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None start: ContentBlockStart @@ -102,9 +100,9 @@ class ReasoningContentBlockDelta(TypedDict, total=False): text: The reasoning that the model used to return the output. """ - redactedContent: Optional[bytes] - signature: Optional[str] - text: Optional[str] + redactedContent: bytes | None + signature: str | None + text: str | None class ContentBlockDelta(TypedDict, total=False): @@ -131,7 +129,7 @@ class ContentBlockDeltaEvent(TypedDict, total=False): delta: The incremental content update for the content block. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None delta: ContentBlockDelta @@ -143,7 +141,7 @@ class ContentBlockStopEvent(TypedDict, total=False): This is optional to accommodate different model providers. """ - contentBlockIndex: Optional[int] + contentBlockIndex: int | None class MessageStopEvent(TypedDict, total=False): @@ -154,7 +152,7 @@ class MessageStopEvent(TypedDict, total=False): stopReason: The reason why the model stopped generating content. 
""" - additionalModelResponseFields: Optional[Union[dict, list, int, float, str, bool, None]] + additionalModelResponseFields: dict | list | int | float | str | bool | None | None stopReason: StopReason @@ -168,7 +166,7 @@ class MetadataEvent(TypedDict, total=False): """ metrics: Metrics - trace: Optional[Trace] + trace: Trace | None usage: Usage @@ -203,8 +201,8 @@ class RedactContentEvent(TypedDict, total=False): """ - redactUserContentMessage: Optional[str] - redactAssistantContentMessage: Optional[str] + redactUserContentMessage: str | None + redactAssistantContentMessage: str | None class StreamEvent(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 8f4dba6b1..6fc0d703c 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -7,8 +7,9 @@ import uuid from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass -from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Protocol, Union +from typing import Any, Literal, Protocol from typing_extensions import NotRequired, TypedDict @@ -164,11 +165,7 @@ def _interrupt_id(self, name: str) -> str: ToolChoiceAnyDict = dict[Literal["any"], ToolChoiceAny] ToolChoiceToolDict = dict[Literal["tool"], ToolChoiceTool] -ToolChoice = Union[ - ToolChoiceAutoDict, - ToolChoiceAnyDict, - ToolChoiceToolDict, -] +ToolChoice = ToolChoiceAutoDict | ToolChoiceAnyDict | ToolChoiceToolDict """ Configuration for how the model should choose tools. @@ -201,12 +198,7 @@ class ToolFunc(Protocol): __name__: str - def __call__( - self, *args: Any, **kwargs: Any - ) -> Union[ - ToolResult, - Awaitable[ToolResult], - ]: + def __call__(self, *args: Any, **kwargs: Any) -> ToolResult | Awaitable[ToolResult]: """Function signature for Python decorated and module based tools. 
Returns: diff --git a/src/strands/types/traces.py b/src/strands/types/traces.py index af6188adb..fceec27a9 100644 --- a/src/strands/types/traces.py +++ b/src/strands/types/traces.py @@ -1,20 +1,19 @@ """Tracing type definitions for the SDK.""" -from typing import List, Mapping, Optional, Sequence, Union +from collections.abc import Mapping, Sequence -AttributeValue = Union[ - str, - bool, - float, - int, - List[str], - List[bool], - List[float], - List[int], - Sequence[str], - Sequence[bool], - Sequence[int], - Sequence[float], -] - -Attributes = Optional[Mapping[str, AttributeValue]] +AttributeValue = ( + str + | bool + | float + | int + | list[str] + | list[bool] + | list[float] + | list[int] + | Sequence[str] + | Sequence[bool] + | Sequence[int] + | Sequence[float] +) +Attributes = Mapping[str, AttributeValue] | None diff --git a/tests/fixtures/mock_hook_provider.py b/tests/fixtures/mock_hook_provider.py index 091f44d06..cf17bb470 100644 --- a/tests/fixtures/mock_hook_provider.py +++ b/tests/fixtures/mock_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands import Agent from strands.hooks import ( @@ -17,7 +18,7 @@ class MockHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ AgentInitializedEvent, @@ -37,7 +38,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 727d28a48..4d18297a2 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,4 +1,5 @@ -from typing import Iterator, Literal, Tuple, Type +from collections.abc import Iterator +from typing import Literal from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, @@ -14,7 +15,7 @@ class MockMultiAgentHookProvider(HookProvider): - def __init__(self, event_types: list[Type] | Literal["all"]): + def __init__(self, event_types: list[type] | Literal["all"]): if event_types == "all": event_types = [ MultiAgentInitializedEvent, @@ -30,7 +31,7 @@ def __init__(self, event_types: list[Type] | Literal["all"]): def event_types_received(self): return [type(event) for event in self.events_received] - def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + def get_events(self) -> tuple[int, Iterator[HookEvent]]: return len(self.events_received), iter(self.events_received) def register_hooks(self, registry: HookRegistry) -> None: diff --git a/tests/fixtures/mocked_model_provider.py b/tests/fixtures/mocked_model_provider.py index 24de958bc..f1c5cae77 100644 --- a/tests/fixtures/mocked_model_provider.py +++ b/tests/fixtures/mocked_model_provider.py @@ -1,5 +1,6 @@ import json -from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union +from collections.abc import AsyncGenerator, Iterable, Sequence +from typing import Any, TypedDict, TypeVar from pydantic import BaseModel @@ -25,7 +26,7 @@ class 
MockedModelProvider(Model): to stream mock responses as events. """ - def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]): + def __init__(self, agent_responses: Sequence[Message | RedactionMessage]): self.agent_responses = [*agent_responses] self.index = 0 @@ -33,7 +34,7 @@ def format_chunk(self, event: Any) -> StreamEvent: return event def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + self, messages: Messages, tool_specs: list[ToolSpec] | None = None, system_prompt: str | None = None ) -> Any: return None @@ -45,9 +46,9 @@ def update_config(self, **model_config: Any) -> None: async def structured_output( self, - output_model: Type[T], + output_model: type[T], prompt: Messages, - system_prompt: Optional[str] = None, + system_prompt: str | None = None, **kwargs: Any, ) -> AsyncGenerator[Any, None]: pass @@ -55,9 +56,9 @@ async def structured_output( async def stream( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, - tool_choice: Optional[Any] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + tool_choice: Any | None = None, *, system_prompt_content=None, **kwargs: Any, @@ -68,7 +69,7 @@ async def stream( self.index += 1 - def map_agent_message_to_events(self, agent_message: Union[Message, RedactionMessage]) -> Iterable[dict[str, Any]]: + def map_agent_message_to_events(self, agent_message: Message | RedactionMessage) -> Iterable[dict[str, Any]]: stop_reason: StopReason = "end_turn" yield {"messageStart": {"role": "assistant"}} if agent_message.get("redactedAssistantContent"): diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index ad1415f22..12b5af42c 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -1,6 +1,5 @@ import unittest.mock from dataclasses import dataclass -from typing import List from unittest.mock import MagicMock, Mock import pytest @@ -139,7 +138,7 @@ async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, nor @pytest.mark.asyncio async def test_invoke_callbacks_async_after_event(hook_registry, after_event): """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" - call_order: List[str] = [] + call_order: list[str] = [] def callback1(_event): call_order.append("callback1") diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 5d1f02089..1ec0a8407 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Optional, cast +from typing import cast import pytest from pydantic import BaseModel @@ -150,7 +150,7 @@ class StructuredOutputModel(BaseModel): name: str value: int - optional_field: Optional[str] = None + optional_field: str | None = None def test__init__with_structured_output(mock_metrics, simple_message: Message): diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index b679faed0..7341c714e 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -1,6 +1,5 @@ """Tests for Agent structured output functionality.""" -from typing import Optional from unittest import mock from unittest.mock import Mock, patch 
@@ -28,7 +27,7 @@ class ProductModel(BaseModel): title: str price: float - description: Optional[str] = None + description: str | None = None @pytest.fixture diff --git a/tests/strands/models/test_sagemaker.py b/tests/strands/models/test_sagemaker.py index 72ebf01c6..5d6d6869a 100644 --- a/tests/strands/models/test_sagemaker.py +++ b/tests/strands/models/test_sagemaker.py @@ -2,7 +2,7 @@ import json import unittest.mock -from typing import Any, Dict, List +from typing import Any import boto3 import pytest @@ -32,7 +32,7 @@ def sagemaker_client(boto_session): @pytest.fixture -def endpoint_config() -> Dict[str, Any]: +def endpoint_config() -> dict[str, Any]: """Default endpoint configuration for tests.""" return { "endpoint_name": "test-endpoint", @@ -42,7 +42,7 @@ def endpoint_config() -> Dict[str, Any]: @pytest.fixture -def payload_config() -> Dict[str, Any]: +def payload_config() -> dict[str, Any]: """Default payload configuration for tests.""" return { "max_tokens": 1024, @@ -64,7 +64,7 @@ def messages() -> Messages: @pytest.fixture -def tool_specs() -> List[ToolSpec]: +def tool_specs() -> list[ToolSpec]: """Sample tool specifications for testing.""" return [ { @@ -405,8 +405,8 @@ async def test_stream_with_partial_json(self, sagemaker_client, model, messages, # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response @@ -444,8 +444,8 @@ async def test_tool_choice_not_supported_warns(self, sagemaker_client, model, me # Mock the response from SageMaker with split JSON mock_response = { "Body": [ - {"PayloadPart": {"Bytes": '{"choices": [{"delta": {"content": "Paris is'.encode("utf-8")}}, - {"PayloadPart": {"Bytes": ' the capital of France."}, "finish_reason": "stop"}]}'.encode("utf-8")}}, + {"PayloadPart": {"Bytes": b'{"choices": [{"delta": {"content": "Paris is'}}, + {"PayloadPart": {"Bytes": b' the capital of France."}, "finish_reason": "stop"}]}'}}, ] } sagemaker_client.invoke_endpoint_with_response_stream.return_value = mock_response diff --git a/tests/strands/models/test_writer.py b/tests/strands/models/test_writer.py index 8cf64a39a..963904002 100644 --- a/tests/strands/models/test_writer.py +++ b/tests/strands/models/test_writer.py @@ -1,5 +1,5 @@ import unittest.mock -from typing import Any, List +from typing import Any import pytest @@ -266,7 +266,7 @@ def test_format_request_with_unsupported_type(model, content, content_type): class AsyncStreamWrapper: - def __init__(self, items: List[Any]): + def __init__(self, items: list[Any]): self.items = items def __aiter__(self): @@ -277,7 +277,7 @@ async def _generator(self): yield item -async def mock_streaming_response(items: List[Any]): +async def mock_streaming_response(items: list[Any]): return AsyncStreamWrapper(items) diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 7e28be998..8e14c9adc 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -82,7 +82,7 @@ def test_create_session(file_manager, sample_session): assert 
os.path.exists(session_file) # Verify content - with open(session_file, "r") as f: + with open(session_file) as f: data = json.load(f) assert data["session_id"] == sample_session.session_id assert data["session_type"] == sample_session.session_type @@ -144,7 +144,7 @@ def test_create_agent(file_manager, sample_session, sample_agent): assert os.path.exists(agent_file) # Verify content - with open(agent_file, "r") as f: + with open(agent_file) as f: data = json.load(f) assert data["agent_id"] == sample_agent.agent_id assert data["state"] == sample_agent.state @@ -210,7 +210,7 @@ def test_create_message(file_manager, sample_session, sample_agent, sample_messa assert os.path.exists(message_path) # Verify content - with open(message_path, "r") as f: + with open(message_path) as f: data = json.load(f) assert data["message_id"] == sample_message.message_id @@ -439,7 +439,7 @@ def test_create_multi_agent(multi_agent_manager, sample_session, mock_multi_agen assert os.path.exists(multi_agent_file) # Verify content - with open(multi_agent_file, "r") as f: + with open(multi_agent_file) as f: data = json.load(f) assert data["id"] == mock_multi_agent.id assert data["state"] == mock_multi_agent.state diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py index a7eb27ca5..0f1c7ffff 100644 --- a/tests/strands/tools/structured_output/test_structured_output_context.py +++ b/tests/strands/tools/structured_output/test_structured_output_context.py @@ -1,7 +1,5 @@ """Tests for StructuredOutputContext class.""" -from typing import Optional - from pydantic import BaseModel, Field from strands.tools.structured_output._structured_output_context import StructuredOutputContext @@ -13,7 +11,7 @@ class SampleModel(BaseModel): name: str = Field(..., description="Name field") age: int = Field(..., description="Age field", ge=0) - email: Optional[str] = Field(None, description="Optional email field") + email: str | None = Field(None, description="Optional email field") class AnotherSampleModel(BaseModel): diff --git a/tests/strands/tools/structured_output/test_structured_output_tool.py b/tests/strands/tools/structured_output/test_structured_output_tool.py index 66f1d465d..784a508bd 100644 --- a/tests/strands/tools/structured_output/test_structured_output_tool.py +++ b/tests/strands/tools/structured_output/test_structured_output_tool.py @@ -1,6 +1,5 @@ """Tests for StructuredOutputTool class.""" -from typing import List, Optional from unittest.mock import MagicMock import pytest @@ -23,8 +22,8 @@ class ComplexModel(BaseModel): title: str = Field(..., description="Title field") count: int = Field(..., ge=0, le=100, description="Count between 0 and 100") - tags: List[str] = Field(default_factory=list, description="List of tags") - metadata: Optional[dict] = Field(None, description="Optional metadata") + tags: list[str] = Field(default_factory=list, description="List of tags") + metadata: dict | None = Field(None, description="Optional metadata") class ValidationTestModel(BaseModel): diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a2a4c6213..4757e5587 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,7 +3,8 @@ """ from asyncio import Queue -from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator +from typing import Annotated, Any from unittest.mock import 
MagicMock import pytest @@ -267,7 +268,7 @@ async def test_tool_with_optional_params(alist): """Test tool decorator with optional parameters.""" @strands.tool - def test_tool(required: str, optional: Optional[int] = None) -> str: + def test_tool(required: str, optional: int | None = None) -> str: """Test with optional param. Args: @@ -864,7 +865,7 @@ def int_return_tool(param: str) -> int: # Define tool with Union return type @strands.tool - def union_return_tool(param: str) -> Union[Dict[str, Any], str, None]: + def union_return_tool(param: str) -> dict[str, Any] | str | None: """Tool with Union return type. Args: @@ -936,7 +937,7 @@ async def test_complex_parameter_types(alist): """Test handling of complex parameter types like nested dictionaries.""" @strands.tool - def complex_type_tool(config: Dict[str, Any]) -> str: + def complex_type_tool(config: dict[str, Any]) -> str: """Tool with complex parameter type. Args: @@ -965,7 +966,7 @@ async def test_custom_tool_result_handling(alist): """Test that a function returning a properly formatted tool result dictionary is handled correctly.""" @strands.tool - def custom_result_tool(param: str) -> Dict[str, Any]: + def custom_result_tool(param: str) -> dict[str, Any]: """Tool that returns a custom tool result dictionary. Args: @@ -1079,11 +1080,11 @@ def validation_tool(str_param: str, int_param: int, bool_param: bool) -> str: @pytest.mark.asyncio async def test_tool_complex_validation_edge_cases(alist): """Test validation of complex schema edge cases.""" - from typing import Any, Dict, Union + from typing import Any # Define a tool with a complex anyOf type that could trigger edge case handling @strands.tool - def edge_case_tool(param: Union[Dict[str, Any], None]) -> str: + def edge_case_tool(param: dict[str, Any] | None) -> str: """Tool with complex anyOf structure. Args: @@ -1236,10 +1237,10 @@ def failing_tool(param: str) -> str: @pytest.mark.asyncio async def test_tool_with_complex_anyof_schema(alist): """Test handling of complex anyOf structures in the schema.""" - from typing import Any, Dict, List, Union + from typing import Any @strands.tool - def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]) -> str: + def complex_schema_tool(union_param: list[int] | dict[str, Any] | str | None) -> str: """Tool with a complex Union type that creates anyOf in schema. 
Args: @@ -1680,7 +1681,7 @@ def test_tool_decorator_annotated_optional_type(): @strands.tool def optional_annotated_tool( - required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + required: Annotated[str, "Required parameter"], optional: Annotated[str | None, "Optional parameter"] = None ) -> str: """Tool with optional annotated parameter.""" return f"{required}, {optional}" @@ -1702,7 +1703,7 @@ def test_tool_decorator_annotated_complex_types(): @strands.tool def complex_annotated_tool( - tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + tags: Annotated[list[str], "List of tag strings"], config: Annotated[dict[str, Any], "Configuration dictionary"] ) -> str: """Tool with complex annotated types.""" return f"Tags: {len(tags)}, Config: {len(config)}" diff --git a/tests/strands/tools/test_structured_output.py b/tests/strands/tools/test_structured_output.py index fe9b55334..72a53bfe6 100644 --- a/tests/strands/tools/test_structured_output.py +++ b/tests/strands/tools/test_structured_output.py @@ -1,4 +1,4 @@ -from typing import List, Literal, Optional +from typing import Literal, Optional import pytest from pydantic import BaseModel, Field @@ -27,7 +27,7 @@ class TwoUsersWithPlanet(BaseModel): """Two users model with planet.""" user1: UserWithPlanet = Field(description="The first user") - user2: Optional[UserWithPlanet] = Field(description="The second user", default=None) + user2: UserWithPlanet | None = Field(description="The second user", default=None) # Test model with list of same type fields @@ -250,8 +250,8 @@ class NodeWithCircularRef(BaseModel): def test_conversion_works_with_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) expected_output = { @@ -281,8 +281,8 @@ class Family(BaseModel): def test_marks_fields_as_optional_for_model_w_fields_that_are_not_marked_as_optional_but_have_a_default_value_which_makes_them_optional(): # noqa E501 class Family(BaseModel): - ages: List[str] = Field(default_factory=list) - names: List[str] = Field(default_factory=list) + ages: list[str] = Field(default_factory=list) + names: list[str] = Field(default_factory=list) converted_output = convert_pydantic_to_tool_spec(Family) assert "null" in converted_output["inputSchema"]["json"]["properties"]["ages"]["type"] @@ -312,14 +312,14 @@ def test_convert_pydantic_with_items_refs(): """Test that no $refs exist after lists of different components.""" class Address(BaseModel): - postal_code: Optional[str] = None + postal_code: str | None = None class Person(BaseModel): """Complete person information.""" list_of_items: list[Address] - list_of_items_nullable: Optional[list[Address]] - list_of_item_or_nullable: list[Optional[Address]] + list_of_items_nullable: list[Address] | None + list_of_item_or_nullable: list[Address | None] tool_spec = convert_pydantic_to_tool_spec(Person) @@ -378,7 +378,7 @@ class Address(BaseModel): street: str city: str country: str - postal_code: Optional[str] = None + postal_code: str | None = None class Contact(BaseModel): address: Address diff --git a/tests_integ/mcp/echo_server.py b/tests_integ/mcp/echo_server.py index e15065a4a..37557cf11 
100644 --- a/tests_integ/mcp/echo_server.py +++ b/tests_integ/mcp/echo_server.py @@ -71,9 +71,7 @@ def get_weather(location: Literal["New York", "London", "Tokyo"] = "New York"): resource=BlobResourceContents( uri="https://weather.api/data/london.json", mimeType="application/json", - blob=base64.b64encode( - '{"temperature": 18, "condition": "rainy", "humidity": 85}'.encode() - ).decode(), + blob=base64.b64encode(b'{"temperature": 18, "condition": "rainy", "humidity": 85}').decode(), ), ) ] diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 5c3baeba8..298272df5 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -3,7 +3,7 @@ import os import threading import time -from typing import List, Literal +from typing import Literal import pytest from mcp import StdioServerParameters, stdio_client @@ -47,7 +47,7 @@ def generate_custom_image() -> MCPImageContent: encoded_image = base64.b64encode(image_file.read()) return MCPImageContent(type="image", data=encoded_image, mimeType="image/png") except Exception as e: - print("Error while generating custom image: {}".format(e)) + print(f"Error while generating custom image: {e}") # Prompts @mcp.prompt(description="A greeting prompt template") @@ -366,7 +366,7 @@ def test_mcp_client_embedded_resources_with_agent(): assert any(["72" in response_text, "partly cloudy" in response_text, "weather" in response_text]) -def _messages_to_content_blocks(messages: List[Message]) -> List[ToolUse]: +def _messages_to_content_blocks(messages: list[Message]) -> list[ToolUse]: return [block["toolUse"] for message in messages for block in message["content"] if "toolUse" in block] diff --git a/tests_integ/models/providers.py b/tests_integ/models/providers.py index 75cc58f74..57614b97f 100644 --- a/tests_integ/models/providers.py +++ b/tests_integ/models/providers.py @@ -3,7 +3,7 @@ """ import os -from typing import Callable, Optional +from collections.abc import Callable import requests from pytest import mark @@ -26,7 +26,7 @@ def __init__( self, id: str, factory: Callable[[], Model], - environment_variable: Optional[str] = None, + environment_variable: str | None = None, ) -> None: self.id = id self.model_factory = factory diff --git a/tests_integ/test_function_tools.py b/tests_integ/test_function_tools.py index 835dccf5d..6c72bdddb 100644 --- a/tests_integ/test_function_tools.py +++ b/tests_integ/test_function_tools.py @@ -4,7 +4,6 @@ """ import logging -from typing import Optional from strands import Agent, tool @@ -25,7 +24,7 @@ def word_counter(text: str) -> str: @tool(name="count_chars", description="Count characters in text") -def count_chars(text: str, include_spaces: Optional[bool] = True) -> str: +def count_chars(text: str, include_spaces: bool | None = True) -> str: """ Count characters in text. 
diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 08343a554..b80a0f82d 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,4 +1,5 @@ -from typing import Any, AsyncIterator +from collections.abc import AsyncIterator +from typing import Any from unittest.mock import patch from uuid import uuid4 diff --git a/tests_integ/test_structured_output_agent_loop.py b/tests_integ/test_structured_output_agent_loop.py index 188f57777..3ddc6db95 100644 --- a/tests_integ/test_structured_output_agent_loop.py +++ b/tests_integ/test_structured_output_agent_loop.py @@ -2,8 +2,6 @@ Comprehensive integration tests for structured output passed into the agent functionality. """ -from typing import List, Optional - import pytest from pydantic import BaseModel, Field, field_validator @@ -42,7 +40,7 @@ class Contact(BaseModel): """Contact information.""" email: str - phone: Optional[str] = None + phone: str | None = None preferred_method: str = "email" @@ -54,7 +52,7 @@ class Employee(BaseModel): department: str address: Address contact: Contact - skills: List[str] + skills: list[str] hire_date: str salary_range: str @@ -65,7 +63,7 @@ class ProductReview(BaseModel): product_name: str rating: int = Field(ge=1, le=5, description="Rating from 1-5 stars") sentiment: str = Field(pattern="^(positive|negative|neutral)$") - key_points: List[str] + key_points: list[str] would_recommend: bool @@ -84,7 +82,7 @@ class TaskList(BaseModel): """Task management structure.""" project_name: str - tasks: List[str] + tasks: list[str] priority: str = Field(pattern="^(high|medium|low)$") due_date: str estimated_hours: int @@ -102,7 +100,7 @@ class Company(BaseModel): name: str = Field(description="Company name") address: Address = Field(description="Company address") - employees: List[Person] = Field(description="list of persons") + employees: list[Person] = Field(description="list of persons") class Task(BaseModel):
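
Note on the rewrite pattern: the hunks in this patch are instances of the mechanical
transformations that ruff applies once the UP (pyupgrade) selector is enabled, most
prominently PEP 585 builtin generics in place of typing.List/typing.Dict (UP006), PEP 604
unions such as "str | None" in place of Optional and Union (UP007), and abstract collection
types imported from collections.abc rather than typing (UP035). Below is a minimal
before/after sketch; the names tally and weigh are hypothetical, chosen only for
illustration and not taken from this codebase.

    # Before: typing-module generics and Optional/Union spellings.
    #
    #     from typing import Callable, Dict, List, Optional
    #
    #     def tally(items: List[str], weigh: Optional[Callable[[str], int]] = None) -> Dict[str, int]:
    #         ...
    #
    # After: the form that "ruff check --fix --select UP" rewrites the above to.
    from collections.abc import Callable  # UP035: ABCs come from collections.abc, not typing

    def tally(items: list[str], weigh: Callable[[str], int] | None = None) -> dict[str, int]:
        """Count items, optionally weighting each occurrence (illustrative only)."""
        weigh = weigh or (lambda _item: 1)  # unweighted by default
        counts: dict[str, int] = {}
        for item in items:
            counts[item] = counts.get(item, 0) + weigh(item)
        return counts

For example, tally(["a", "b", "a"]) returns {"a": 2, "b": 1}. Runtime behavior is unchanged
throughout: the UP rewrites applied here (annotations, import sites, redundant open() modes,
encode() calls on string literals, str.format to f-strings) are all semantics-preserving.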