Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,11 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
self.conversation_manager.reduce_context(self, e=e)

# Sync agent after reduce_context to keep conversation_manager_state up to date in the session
if self._session_manager:
self._session_manager.sync_agent(self)

events = self._execute_event_loop_cycle(invocation_state)
async for event in events:
yield event
Expand Down
34 changes: 32 additions & 2 deletions src/strands/agent/conversation_manager/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from ...types.content import Message

if TYPE_CHECKING:
from ...agent.agent import Agent

Expand All @@ -18,8 +20,37 @@ class ConversationManager(ABC):
- Maintain relevant conversation state
"""

def __init__(self) -> None:
"""Initialize the ConversationManager.

Attributes:
removed_message_count: The messages that have been removed from the agents messages array.
These represent messages provided by the user or LLM that have been removed, not messages
included by the conversation manager through something like summarization.
"""
self.removed_message_count = 0

def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
"""Restore the Conversation Manager's state from a session.

Args:
state: Previous state of the conversation manager
Returns:
Optional list of messages to prepend to the agents messages. By defualt returns None.
"""
if state.get("__name__") != self.__class__.__name__:
raise ValueError("Invalid conversation manager state.")
self.removed_message_count = state["removed_message_count"]
return None

def get_state(self) -> dict[str, Any]:
"""Get the current state of a Conversation Manager as a Json serializable dictionary."""
return {
"__name__": self.__class__.__name__,
"removed_message_count": self.removed_message_count,
}

@abstractmethod
# pragma: no cover
def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""Applies management strategy to the provided agent.

Expand All @@ -35,7 +66,6 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
pass

@abstractmethod
# pragma: no cover
def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
"""Called when the model's context window is exceeded.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,13 @@
if TYPE_CHECKING:
from ...agent.agent import Agent

from ...types.content import Message, Messages
from ...types.content import Messages
from ...types.exceptions import ContextWindowOverflowException
from .conversation_manager import ConversationManager

logger = logging.getLogger(__name__)


def is_user_message(message: Message) -> bool:
"""Check if a message is from a user.

Args:
message: The message object to check.

Returns:
True if the message has the user role, False otherwise.
"""
return message["role"] == "user"


def is_assistant_message(message: Message) -> bool:
"""Check if a message is from an assistant.

Args:
message: The message object to check.

Returns:
True if the message has the assistant role, False otherwise.
"""
return message["role"] == "assistant"


class SlidingWindowConversationManager(ConversationManager):
"""Implements a sliding window strategy for managing conversation history.

Expand All @@ -52,6 +28,7 @@ def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
Defaults to 40 messages.
should_truncate_results: Truncate tool results when a message is too large for the model's context window
"""
super().__init__()
self.window_size = window_size
self.should_truncate_results = should_truncate_results

Expand Down Expand Up @@ -129,6 +106,9 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
# If we didn't find a valid trim_index, then we throw
raise ContextWindowOverflowException("Unable to trim conversation context!") from e

# trim_index represents the number of messages being removed from the agents messages array
self.removed_message_count += trim_index

# Overwrite message history
messages[:] = messages[trim_index:]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from typing import TYPE_CHECKING, Any, List, Optional

from typing_extensions import override

from ...types.content import Message
from ...types.exceptions import ContextWindowOverflowException
from .conversation_manager import ConversationManager
Expand Down Expand Up @@ -67,6 +69,7 @@ def __init__(
summarization_system_prompt: Optional system prompt override for summarization.
If None, uses the default summarization prompt.
"""
super().__init__()
if summarization_agent is not None and summarization_system_prompt is not None:
raise ValueError(
"Cannot provide both summarization_agent and summarization_system_prompt. "
Expand All @@ -77,6 +80,25 @@ def __init__(
self.preserve_recent_messages = preserve_recent_messages
self.summarization_agent = summarization_agent
self.summarization_system_prompt = summarization_system_prompt
self._summary_message: Optional[Message] = None

@override
def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
"""Restores the Summarizing Conversation manager from its previous state in a session.

Args:
state: The previous state of the Summarizing Conversation Manager.

Returns:
Optionally returns the previous conversation summary if it exists.
"""
super().restore_from_session(state)
self._summary_message = state.get("summary_message")
return [self._summary_message] if self._summary_message else None

def get_state(self) -> dict[str, Any]:
"""Returns a dictionary representation of the state for the Summarizing Conversation Manager."""
return {"summary_message": self._summary_message, **super().get_state()}

def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""Apply management strategy to conversation history.
Expand Down Expand Up @@ -128,11 +150,17 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
messages_to_summarize = agent.messages[:messages_to_summarize_count]
remaining_messages = agent.messages[messages_to_summarize_count:]

# Keep track of the number of messages that have been summarized thus far.
self.removed_message_count += len(messages_to_summarize)
# If there is a summary message, don't count it in the removed_message_count.
if self._summary_message:
self.removed_message_count -= 1

# Generate summary
summary_message = self._generate_summary(messages_to_summarize, agent)
self._summary_message = self._generate_summary(messages_to_summarize, agent)

# Replace the summarized messages with the summary
agent.messages[:] = [summary_message] + remaining_messages
agent.messages[:] = [self._summary_message] + remaining_messages

except Exception as summarization_error:
logger.error("Summarization failed: %s", summarization_error)
Expand Down
4 changes: 2 additions & 2 deletions src/strands/session/file_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
└── agent_<agent_id>/
├── agent.json # Agent metadata
└── messages/
├── message_<created_timestamp>_<id1>.json
└── message_<created_timestamp>_<id2>.json
├── message_<id1>.json
└── message_<id2>.json

"""

Expand Down
22 changes: 19 additions & 3 deletions src/strands/session/repository_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,25 @@ def initialize(self, agent: Agent) -> None:
agent.agent_id,
self.session_id,
)
session_messages = self.session_repository.list_messages(self.session_id, agent.agent_id)
agent.state = AgentState(session_agent.state)

# Restore the conversation manager to its previous state, and get the optional prepend messages
prepend_messsages = agent.conversation_manager.restore_from_session(
session_agent.conversation_manager_state
)

if prepend_messsages is None:
prepend_messsages = []

# List the messages currently in the session, using an offset of the messages previously removed
# by the converstaion manager.
session_messages = self.session_repository.list_messages(
session_id=self.session_id,
agent_id=agent.agent_id,
offset=agent.conversation_manager.removed_message_count,
)
if len(session_messages) > 0:
self._latest_agent_message[agent.agent_id] = session_messages[-1]
agent.messages = [session_message.to_message() for session_message in session_messages]

agent.state = AgentState(session_agent.state)
# Resore the agents messages array including the optional prepend messages
agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages]
6 changes: 3 additions & 3 deletions src/strands/session/s3_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
└── agent_<agent_id>/
├── agent.json # Agent metadata
└── messages/
├── message_<created_timestamp>_<id1>.json
└── message_<created_timestamp>_<id2>.json
├── message_<id1>.json
└── message_<id2>.json

"""

Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(

def _get_session_path(self, session_id: str) -> str:
"""Get session S3 prefix."""
return f"{self.prefix}{SESSION_PREFIX}{session_id}/"
return f"{self.prefix}/{SESSION_PREFIX}{session_id}/"

def _get_agent_path(self, session_id: str, agent_id: str) -> str:
"""Get agent S3 prefix."""
Expand Down
10 changes: 9 additions & 1 deletion src/strands/session/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from ..hooks.events import AgentInitializedEvent, MessageAddedEvent
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
from ..hooks.registry import HookProvider, HookRegistry
from ..types.content import Message

Expand All @@ -22,10 +22,18 @@ class SessionManager(HookProvider, ABC):

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register hooks for persisting the agent to the session."""
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))

# For each message appended to the Agents messages, store that message in the session
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))

# Sync the agent into the session for each message in case the agent state was updated
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))

# After an agent was invoked, sync it with the session to capture any conversation manager state updates
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))

@abstractmethod
def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None:
"""Redact the message most recently appended to the agent in the session.
Expand Down
2 changes: 2 additions & 0 deletions src/strands/types/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class SessionAgent:

agent_id: str
state: Dict[str, Any]
conversation_manager_state: Dict[str, Any]
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

Expand All @@ -116,6 +117,7 @@ def from_agent(cls, agent: Agent) -> "SessionAgent":
raise ValueError("agent_id needs to be defined.")
return cls(
agent_id=agent.agent_id,
conversation_manager_state=agent.conversation_manager.get_state(),
state=agent.state.get(),
)

Expand Down
20 changes: 10 additions & 10 deletions tests/fixtures/mock_session_repository.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from strands.session.session_repository import SessionRepository
from strands.types.exceptions import SessionException
from strands.types.session import SessionAgent, SessionMessage


class MockedSessionRepository(SessionRepository):
Expand All @@ -11,21 +12,20 @@ def __init__(self):
self.agents = {}
self.messages = {}

def create_session(self, session):
def create_session(self, session) -> None:
"""Create a session."""
session_id = session.session_id
if session_id in self.sessions:
raise SessionException(f"Session {session_id} already exists")
self.sessions[session_id] = session
self.agents[session_id] = {}
self.messages[session_id] = {}
return session

def read_session(self, session_id):
def read_session(self, session_id) -> SessionAgent:
"""Read a session."""
return self.sessions.get(session_id)

def create_agent(self, session_id, session_agent):
def create_agent(self, session_id, session_agent) -> None:
"""Create an agent."""
agent_id = session_agent.agent_id
if session_id not in self.sessions:
Expand All @@ -36,13 +36,13 @@ def create_agent(self, session_id, session_agent):
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})
return session_agent

def read_agent(self, session_id, agent_id):
def read_agent(self, session_id, agent_id) -> SessionAgent:
"""Read an agent."""
if session_id not in self.sessions:
return None
return self.agents.get(session_id, {}).get(agent_id)

def update_agent(self, session_id, session_agent):
def update_agent(self, session_id, session_agent) -> None:
"""Update an agent."""
agent_id = session_agent.agent_id
if session_id not in self.sessions:
Expand All @@ -51,7 +51,7 @@ def update_agent(self, session_id, session_agent):
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
self.agents[session_id][agent_id] = session_agent

def create_message(self, session_id, agent_id, session_message):
def create_message(self, session_id, agent_id, session_message) -> None:
"""Create a message."""
message_id = session_message.message_id
if session_id not in self.sessions:
Expand All @@ -62,15 +62,15 @@ def create_message(self, session_id, agent_id, session_message):
raise SessionException(f"Message {message_id} already exists in agent {agent_id} in session {session_id}")
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})[message_id] = session_message

def read_message(self, session_id, agent_id, message_id):
def read_message(self, session_id, agent_id, message_id) -> SessionMessage:
"""Read a message."""
if session_id not in self.sessions:
return None
if agent_id not in self.agents.get(session_id, {}):
return None
return self.messages.get(session_id, {}).get(agent_id, {}).get(message_id)

def update_message(self, session_id, agent_id, session_message):
def update_message(self, session_id, agent_id, session_message) -> None:
"""Update a message."""

message_id = session_message.message_id
Expand All @@ -82,7 +82,7 @@ def update_message(self, session_id, agent_id, session_message):
raise SessionException(f"Message {message_id} does not exist in session {session_id}")
self.messages[session_id][agent_id][message_id] = session_message

def list_messages(self, session_id, agent_id, limit=None, offset=0):
def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[SessionMessage]:
"""List messages."""
if session_id not in self.sessions:
return []
Expand Down
Loading
Loading