Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ def __init__(
if self._session_manager:
self.hooks.add_hook(self._session_manager)

# Check if conversation_manager implements HookProvider protocol
if isinstance(self.conversation_manager, HookProvider):
self.hooks.add_hook(self.conversation_manager)

self.tool_executor = tool_executor or ConcurrentToolExecutor()

if hooks:
Expand Down
16 changes: 16 additions & 0 deletions src/strands/agent/conversation_manager/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

if TYPE_CHECKING:
from ...agent.agent import Agent
from ...hooks.registry import HookRegistry


class ConversationManager(ABC):
Expand All @@ -18,6 +19,9 @@ class ConversationManager(ABC):
- Manage memory usage
- Control context length
- Maintain relevant conversation state

Conversation managers can optionally implement the HookProvider protocol by overriding
register_hooks() to integrate with the agent's event lifecycle.
"""

def __init__(self) -> None:
Expand All @@ -30,6 +34,18 @@ def __init__(self) -> None:
"""
self.removed_message_count = 0

def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
"""Register hook callbacks for lifecycle event integration.

By default, this is a no-op. Subclasses can override this method to register
callbacks for agent lifecycle events (e.g., BeforeModelCallEvent, AfterToolCallEvent).

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
pass

def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
"""Restore the Conversation Manager's state from a session.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

if TYPE_CHECKING:
from ...agent.agent import Agent
from ...hooks.registry import HookRegistry

from ...hooks.events import BeforeModelCallEvent
from ...types.content import Messages
from ...types.exceptions import ContextWindowOverflowException
from .conversation_manager import ConversationManager
Expand All @@ -18,19 +20,107 @@ class SlidingWindowConversationManager(ConversationManager):

This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
invalid window states.

Supports proactive management during agent loop execution via the per_turn parameter.
"""

def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
def __init__(self, window_size: int = 40, should_truncate_results: bool = True, per_turn: bool | int = False):
"""Initialize the sliding window conversation manager.

Args:
window_size: Maximum number of messages to keep in the agent's history.
Defaults to 40 messages.
should_truncate_results: Truncate tool results when a message is too large for the model's context window
per_turn: Controls when to apply message management during agent execution.
- False (default): Only apply management at the end
- True: Apply management before every model call
- int (e.g., 3): Apply management before every N model calls

Reduces context size to speed up execution and reduce costs, but removes
conversation history the agent uses for decision-making.

Raises:
ValueError: If per_turn is 0 or a negative integer.
"""
super().__init__()

# Validate per_turn parameter
# Note: Must check bool before int since bool is a subclass of int in Python
if not isinstance(per_turn, bool) and isinstance(per_turn, int) and per_turn <= 0:
raise ValueError(f"per_turn must be True, False, or a positive integer, got {per_turn}")

self.window_size = window_size
self.should_truncate_results = should_truncate_results
self.per_turn = per_turn
self.model_call_count = 0

def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
"""Register hook callbacks for per-turn conversation management.

Overrides the base ConversationManager.register_hooks() to enable per-turn management
during agent execution. When per_turn is enabled, registers a callback for
BeforeModelCallEvent to apply message management during the agent loop.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
# Always register the callback - per_turn check happens in the callback
registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)

def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
"""Handle before model call event for per-turn management.

This callback is invoked before each model call when per_turn is enabled.
It tracks the model call count and applies message management based on the
per_turn configuration.

Args:
event: The before model call event containing the agent and model execution details.
"""
# Check if per_turn is enabled
if self.per_turn is False:
return

self.model_call_count += 1

# Determine if we should apply management
should_apply = False
if self.per_turn is True:
should_apply = True
elif isinstance(self.per_turn, int) and self.per_turn > 0:
should_apply = self.model_call_count % self.per_turn == 0

if should_apply:
logger.debug(
"model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management",
self.model_call_count,
self.per_turn,
)
self.apply_management(event.agent)

def get_state(self) -> dict[str, Any]:
"""Get the current state of the conversation manager.

Returns:
Dictionary containing the manager's state, including model call count for per-turn tracking.
"""
state = super().get_state()
state["model_call_count"] = self.model_call_count
return state

def restore_from_session(self, state: dict[str, Any]) -> Optional[list]:
"""Restore the conversation manager's state from a session.

Args:
state: Previous state of the conversation manager

Returns:
Optional list of messages to prepend to the agent's messages.
"""
result = super().restore_from_session(state)
self.model_call_count = state.get("model_call_count", 0)
return result

def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.
Expand Down
3 changes: 2 additions & 1 deletion src/strands/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable

from ..interrupt import Interrupt, InterruptException

Expand Down Expand Up @@ -84,6 +84,7 @@ class HookEvent(BaseHookEvent):
"""Generic for invoking events - non-contravariant to enable returning events."""


@runtime_checkable
class HookProvider(Protocol):
"""Protocol for objects that provide hook callbacks to an agent.

Expand Down
89 changes: 89 additions & 0 deletions tests/strands/agent/test_conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from strands.agent.agent import Agent
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.hooks.events import BeforeModelCallEvent
from strands.hooks.registry import HookRegistry
from strands.types.exceptions import ContextWindowOverflowException
from tests.fixtures.mocked_model_provider import MockedModelProvider


@pytest.fixture
Expand Down Expand Up @@ -246,3 +249,89 @@ def test_null_conversation_does_not_restore_with_incorrect_state():

with pytest.raises(ValueError):
manager.restore_from_session({})


# ==============================================================================
# Per-Turn Management Tests
# ==============================================================================


@pytest.mark.parametrize(
"per_turn_value,is_valid",
[
(False, True),
(True, True),
(1, True),
(3, True),
(0, False),
(-1, False),
],
)
def test_per_turn_parameter_validation(per_turn_value, is_valid):
"""Test per_turn parameter validation for various values."""
if is_valid:
manager = SlidingWindowConversationManager(per_turn=per_turn_value)
assert manager.per_turn == per_turn_value
else:
with pytest.raises(ValueError):
SlidingWindowConversationManager(per_turn=per_turn_value)


@pytest.mark.parametrize(
"per_turn,num_responses,expected_min_calls",
[
(False, 1, 1), # Only finally block
(True, 1, 2), # Before model call + finally
(2, 1, 1), # Not reached, only finally
(3, 1, 1), # Not reached, only finally
],
)
def test_per_turn_management_frequency(per_turn, num_responses, expected_min_calls):
"""Test apply_management call frequency with different per_turn values."""
from unittest.mock import patch

manager = SlidingWindowConversationManager(per_turn=per_turn, window_size=100)
responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(num_responses)]
model = MockedModelProvider(responses)
agent = Agent(model=model, conversation_manager=manager)

with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
agent("Test")
assert mock.call_count >= expected_min_calls


def test_per_turn_dynamic_change():
"""Test that per_turn can be changed dynamically during execution."""
from unittest.mock import MagicMock, patch

manager = SlidingWindowConversationManager(per_turn=False)
registry = HookRegistry()
manager.register_hooks(registry)

mock_agent = MagicMock()
mock_agent.messages = []
event = BeforeModelCallEvent(agent=mock_agent)

# Initially disabled
with patch.object(manager, "apply_management") as mock_apply:
registry.invoke_callbacks(event)
assert mock_apply.call_count == 0

# Enable dynamically
manager.per_turn = True
with patch.object(manager, "apply_management") as mock_apply:
registry.invoke_callbacks(event)
assert mock_apply.call_count == 1


def test_per_turn_state_persistence():
"""Test that model_call_count is persisted and restored from state."""
manager = SlidingWindowConversationManager(per_turn=3)
manager.model_call_count = 7

state = manager.get_state()
assert state["model_call_count"] == 7

new_manager = SlidingWindowConversationManager(per_turn=3)
new_manager.restore_from_session(state)
assert new_manager.model_call_count == 7