1 change: 1 addition & 0 deletions pyproject.toml
@@ -224,6 +224,7 @@ select = [
"G", # logging format
"I", # isort
"LOG", # logging
"UP", # pyupgrade
]

[tool.ruff.lint.per-file-ignores]
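The new `UP` entry enables ruff's pyupgrade rules, which modernize syntax to the project's minimum Python version; the file-by-file diffs below are the autofixes those rules produce (applied via `ruff check --fix`). As a minimal sketch, assuming Python 3.10+ and using illustrative names rather than code from this PR, the core rewrite (rules like UP006 and UP007) looks like this:

```python
# Before pyupgrade:
#
#   from typing import Dict, List, Optional
#
#   def lookup(table: Dict[str, List[int]], key: str) -> Optional[List[int]]:
#       return table.get(key)

# After: builtin generics (PEP 585) and union syntax (PEP 604).
def lookup(table: dict[str, list[int]], key: str) -> list[int] | None:
    return table.get(key)

print(lookup({"a": [1, 2]}, "a"))  # [1, 2]
print(lookup({"a": [1, 2]}, "b"))  # None
```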
3 changes: 2 additions & 1 deletion src/strands/_async.py
@@ -2,8 +2,9 @@

import asyncio
import contextvars
from collections.abc import Awaitable, Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Awaitable, Callable, TypeVar
from typing import TypeVar

T = TypeVar("T")

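Since Python 3.9 the generic ABCs in `typing` (`Awaitable`, `Callable`, `Sequence`, and friends) are deprecated aliases for their `collections.abc` counterparts, so the import moves; `TypeVar` has no ABC equivalent and stays in `typing`. A small runnable sketch in the spirit of this module (`run` here is a hypothetical stand-in, not the helper defined in `_async.py`):

```python
import asyncio
from collections.abc import Awaitable, Callable  # preferred home since Python 3.9
from typing import TypeVar

T = TypeVar("T")

def run(async_func: Callable[[], Awaitable[T]]) -> T:
    """Run a factory's awaitable to completion on a fresh event loop."""
    async def runner() -> T:
        return await async_func()  # `await` accepts any Awaitable

    return asyncio.run(runner())

async def hello() -> str:
    return "hello"

print(run(hello))  # hello
```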
59 changes: 26 additions & 33 deletions src/strands/agent/agent.py
@@ -11,15 +11,10 @@

import logging
import warnings
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Mapping,
Optional,
Type,
TypeVar,
Union,
cast,
@@ -104,26 +99,24 @@ class Agent:

def __init__(
self,
model: Union[Model, str, None] = None,
messages: Optional[Messages] = None,
tools: Optional[list[Union[str, dict[str, str], "ToolProvider", Any]]] = None,
system_prompt: Optional[str | list[SystemContentBlock]] = None,
structured_output_model: Optional[Type[BaseModel]] = None,
callback_handler: Optional[
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
] = _DEFAULT_CALLBACK_HANDLER,
conversation_manager: Optional[ConversationManager] = None,
model: Model | str | None = None,
messages: Messages | None = None,
tools: list[Union[str, dict[str, str], "ToolProvider", Any]] | None = None,
system_prompt: str | list[SystemContentBlock] | None = None,
structured_output_model: type[BaseModel] | None = None,
callback_handler: Callable[..., Any] | _DefaultCallbackHandlerSentinel | None = _DEFAULT_CALLBACK_HANDLER,
conversation_manager: ConversationManager | None = None,
record_direct_tool_call: bool = True,
load_tools_from_directory: bool = False,
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
trace_attributes: Mapping[str, AttributeValue] | None = None,
*,
agent_id: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
state: Optional[Union[AgentState, dict]] = None,
hooks: Optional[list[HookProvider]] = None,
session_manager: Optional[SessionManager] = None,
tool_executor: Optional[ToolExecutor] = None,
agent_id: str | None = None,
name: str | None = None,
description: str | None = None,
state: AgentState | dict | None = None,
hooks: list[HookProvider] | None = None,
session_manager: SessionManager | None = None,
tool_executor: ToolExecutor | None = None,
):
"""Initialize the Agent with the specified configuration.

@@ -189,7 +182,7 @@ def __init__(
# If not provided, create a new PrintingCallbackHandler instance
# If explicitly set to None, use null_callback_handler
# Otherwise use the passed callback_handler
self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler]
self.callback_handler: Callable[..., Any] | PrintingCallbackHandler
if isinstance(callback_handler, _DefaultCallbackHandlerSentinel):
self.callback_handler = PrintingCallbackHandler()
elif callback_handler is None:
@@ -226,7 +219,7 @@ def __init__(

# Initialize tracer instance (no-op if not configured)
self.tracer = get_tracer()
self.trace_span: Optional[trace_api.Span] = None
self.trace_span: trace_api.Span | None = None

# Initialize agent state management
if state is not None:
@@ -316,7 +309,7 @@ def __call__(
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
structured_output_model: type[BaseModel] | None = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
@@ -357,7 +350,7 @@ async def invoke_async(
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
structured_output_model: type[BaseModel] | None = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
@@ -394,7 +387,7 @@ async def invoke_async(

return cast(AgentResult, event["result"])

def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
def structured_output(self, output_model: type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -425,7 +418,7 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) ->

return run_async(lambda: self.structured_output_async(output_model, prompt))

async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
async def structured_output_async(self, output_model: type[T], prompt: AgentInput = None) -> T:
"""This method allows you to get structured output from the agent.

If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -520,7 +513,7 @@ async def stream_async(
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
structured_output_model: type[BaseModel] | None = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
@@ -607,7 +600,7 @@ async def _run_loop(
self,
messages: Messages,
invocation_state: dict[str, Any],
structured_output_model: Type[BaseModel] | None = None,
structured_output_model: type[BaseModel] | None = None,
) -> AsyncGenerator[TypedEvent, None]:
"""Execute the agent's event loop with the given message and parameters.

@@ -769,8 +762,8 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span:

def _end_agent_trace_span(
self,
response: Optional[AgentResult] = None,
error: Optional[Exception] = None,
response: AgentResult | None = None,
error: Exception | None = None,
) -> None:
"""Ends a trace span for the agent.

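Most of the `agent.py` churn is the same mechanical rewrite applied to the constructor and the `structured_output*` signatures: `Optional[X]` becomes `X | None`, `Union[A, B]` becomes `A | B`, and `Type[T]` becomes the builtin `type[T]`. A minimal sketch with hypothetical names (not the real `Agent` API), assuming Python 3.10+:

```python
class Model: ...
class Bedrock(Model): ...

# Before: def resolve(model: Union[Model, str, None], cls: Optional[Type[Model]] = None) -> Model:
def resolve(model: Model | str | None, cls: type[Model] | None = None) -> Model:
    """Return `model` if it is already an instance, else instantiate `cls` (default Model)."""
    if isinstance(model, Model):
        return model
    return (cls or Model)()

print(type(resolve("bedrock", Bedrock)).__name__)  # Bedrock
print(type(resolve(None)).__name__)                # Model
```

Note that the `tools` parameter keeps `Union[str, dict[str, str], "ToolProvider", Any]` because one member is a quoted forward reference, which the `|` operator cannot handle at runtime (see the note after the summarizing-manager diff below).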
3 changes: 2 additions & 1 deletion src/strands/agent/agent_result.py
@@ -3,8 +3,9 @@
This module defines the AgentResult class which encapsulates the complete response from an agent's processing cycle.
"""

from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Sequence, cast
from typing import Any, cast

from pydantic import BaseModel

@@ -1,7 +1,7 @@
"""Abstract interface for conversation history management."""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

from ...types.content import Message

@@ -30,7 +30,7 @@ def __init__(self) -> None:
"""
self.removed_message_count = 0

def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None:
"""Restore the Conversation Manager's state from a session.

Args:
@@ -66,7 +66,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
pass

@abstractmethod
def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
"""Called when the model's context window is exceeded.

This method should implement the specific strategy for reducing the window size when a context overflow occurs.
@@ -1,6 +1,6 @@
"""Null implementation of conversation management."""

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from ...agent.agent import Agent
@@ -28,7 +28,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""
pass

def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
"""Does not reduce context and raises an exception.

Args:
@@ -1,7 +1,7 @@
"""Sliding window conversation history management."""

import logging
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from ...agent.agent import Agent
@@ -52,7 +52,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
return
self.reduce_context(agent)

def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
"""Trim the oldest messages to reduce the conversation context size.

The method handles special cases where trimming the messages leads to:
@@ -151,7 +151,7 @@ def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:

return changes_made

def _find_last_message_with_tool_results(self, messages: Messages) -> Optional[int]:
def _find_last_message_with_tool_results(self, messages: Messages) -> int | None:
"""Find the index of the last message containing tool results.

This is useful for identifying messages that might need to be truncated to reduce context size.
@@ -1,7 +1,7 @@
"""Summarizing conversation history management with configurable options."""

import logging
from typing import TYPE_CHECKING, Any, List, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from typing_extensions import override

@@ -62,7 +62,7 @@ def __init__(
summary_ratio: float = 0.3,
preserve_recent_messages: int = 10,
summarization_agent: Optional["Agent"] = None,
summarization_system_prompt: Optional[str] = None,
summarization_system_prompt: str | None = None,
):
"""Initialize the summarizing conversation manager.

@@ -87,10 +87,10 @@ def __init__(
self.preserve_recent_messages = preserve_recent_messages
self.summarization_agent = summarization_agent
self.summarization_system_prompt = summarization_system_prompt
self._summary_message: Optional[Message] = None
self._summary_message: Message | None = None

@override
def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None:
"""Restores the Summarizing Conversation manager from its previous state in a session.

Args:
@@ -121,7 +121,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
# No proactive management - summarization only happens on context overflow
pass

def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
"""Reduce context using summarization.

Args:
@@ -173,7 +173,7 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
logger.error("Summarization failed: %s", summarization_error)
raise summarization_error from e

def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message:
def _generate_summary(self, messages: list[Message], agent: "Agent") -> Message:
"""Generate a summary of the provided messages.

Args:
@@ -224,7 +224,7 @@ def _generate_summary(self, messages: List[Message], agent: "Agent") -> Message:
summarization_agent.messages = original_messages
summarization_agent.tool_registry = original_tool_registry

def _adjust_split_point_for_tool_pairs(self, messages: List[Message], split_point: int) -> int:
def _adjust_split_point_for_tool_pairs(self, messages: list[Message], split_point: int) -> int:
"""Adjust the split point to avoid breaking ToolUse/ToolResult pairs.

Uses the same logic as SlidingWindowConversationManager for consistency.
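One deliberate leftover in this file: `summarization_agent: Optional["Agent"]` keeps `Optional`, and the import survives. Presumably that is because a quoted forward reference is a plain `str` at runtime, so `"Agent" | None` would raise `TypeError` the moment the annotation is evaluated; ruff only rewrites unions it can prove safe. A runnable sketch of the distinction:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from strands.agent.agent import Agent  # visible to type checkers only

class Manager:
    # OK: typing.Optional wraps the quoted name as a ForwardRef.
    agent: Optional["Agent"] = None

# Not OK: `|` is evaluated eagerly, and a quoted name is just a str.
try:
    broken: "Agent" | None = None
except TypeError as err:
    print(err)  # unsupported operand type(s) for |: 'str' and 'NoneType'
```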
3 changes: 2 additions & 1 deletion src/strands/event_loop/event_loop.py
@@ -11,7 +11,8 @@
import asyncio
import logging
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any

from opentelemetry import trace as trace_api

9 changes: 5 additions & 4 deletions src/strands/event_loop/streaming.py
@@ -4,7 +4,8 @@
import logging
import time
import warnings
from typing import Any, AsyncGenerator, AsyncIterable, Optional
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Any

from ..models.model import Model
from ..tools import InvalidToolUseNameException
@@ -418,12 +419,12 @@ async def process_stream(

async def stream_messages(
model: Model,
system_prompt: Optional[str],
system_prompt: str | None,
messages: Messages,
tool_specs: list[ToolSpec],
*,
tool_choice: Optional[Any] = None,
system_prompt_content: Optional[list[SystemContentBlock]] = None,
tool_choice: Any | None = None,
system_prompt_content: list[SystemContentBlock] | None = None,
**kwargs: Any,
) -> AsyncGenerator[TypedEvent, None]:
"""Streams messages to the model and processes the response.
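`AsyncGenerator` and `AsyncIterable` likewise move from `typing` to `collections.abc` in the event-loop and streaming modules. A self-contained sketch with a hypothetical function (not the real `stream_messages`):

```python
import asyncio
from collections.abc import AsyncGenerator  # preferred over typing.AsyncGenerator

async def stream_words(text: str) -> AsyncGenerator[str, None]:
    """Yield one token at a time, mimicking a streamed model response."""
    for word in text.split():
        yield word

async def main() -> None:
    async for word in stream_words("streamed model output"):
        print(word)

asyncio.run(main())
```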
2 changes: 1 addition & 1 deletion src/strands/experimental/agent_config.py
@@ -98,7 +98,7 @@ def config_to_agent(config: str | dict[str, Any], **kwargs: dict[str, Any]) -> A
if not config_path.exists():
raise FileNotFoundError(f"Configuration file not found: {file_path}")

with open(config_path, "r") as f:
with open(config_path) as f:
config_dict = json.load(f)
elif isinstance(config, dict):
config_dict = config.copy()
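Beyond typing, pyupgrade also drops redundant arguments: `"r"` is `open()`'s default mode (ruff rule UP015). A tiny runnable demo (the temp file is purely illustrative):

```python
import json
import tempfile

# Write a throwaway JSON config to a temp path.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"name": "demo"}, f)
    path = f.name

# After the fix: the mode argument is omitted because "r" is the default.
with open(path) as f:  # equivalent to open(path, "r")
    print(json.load(f))  # {'name': 'demo'}
```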
@@ -58,7 +58,7 @@ def __init__(
self.prompt_mapper = prompt_mapper or DefaultPromptMapper()
self.model = model

async def steer(self, agent: "Agent", tool_use: ToolUse, **kwargs: Any) -> SteeringAction:
async def steer(self, agent: Agent, tool_use: ToolUse, **kwargs: Any) -> SteeringAction:
"""Provide contextual guidance for tool usage.

Args:
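Here the quotes come off `agent: "Agent"` entirely (ruff UP037, unnecessary quoted annotation). That is only safe when the name is importable at runtime or the module defers annotation evaluation; a minimal sketch of the latter, assuming PEP 563 semantics via the `__future__` import:

```python
from __future__ import annotations  # PEP 563: annotations stored as strings, never evaluated

import asyncio
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from strands.agent.agent import Agent  # visible to type checkers only

async def steer(agent: Agent, tool_use: dict[str, Any]) -> None:
    # With lazy annotations, the unquoted name needs no runtime import.
    print(type(agent).__name__, tool_use)

# Runs fine even though Agent is never imported at runtime.
asyncio.run(steer(object(), {"name": "demo_tool"}))  # prints: object {'name': 'demo_tool'}
```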
3 changes: 2 additions & 1 deletion src/strands/experimental/tools/tool_provider.py
@@ -1,7 +1,8 @@
"""Tool provider interface."""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from ...types.tools import AgentTool