diff --git a/pyproject.toml b/pyproject.toml index 8bb55bd65..ee834b8c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,7 +156,7 @@ test-lint = [ "hatch fmt --linter --check" ] test = [ - "hatch test --cover --cov-report html --cov-report xml {args}" + "hatch test --cover --cov-report term-missing --cov-report html --cov-report xml {args}" ] test-integ = [ "hatch test tests-integ {args}" diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 4b11e81ce..542f624ff 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,12 +1,10 @@ """Sliding window conversation history management.""" -import json import logging -from typing import List, Optional, cast +from typing import List, Optional -from ...types.content import ContentBlock, Message, Messages +from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException -from ...types.tools import ToolResult from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -36,6 +34,34 @@ def is_assistant_message(message: Message) -> bool: return message["role"] == "assistant" +def has_tool_use(message: Message) -> bool: + """Check if a message contains toolUse content.""" + return any("toolUse" in content for content in message["content"]) + + +def has_tool_result(message: Message) -> bool: + """Check if a message contains toolResult content.""" + return any("toolResult" in content for content in message["content"]) + + +def get_tool_use_ids(message: Message) -> List[str]: + """Get all toolUse IDs from a message.""" + ids = [] + for content in message["content"]: + if "toolUse" in content: + ids.append(content["toolUse"]["toolUseId"]) + return ids + + +def get_tool_result_ids(message: Message) -> List[str]: + """Get all toolResult IDs from a message.""" + ids = [] + for content in message["content"]: + if "toolResult" in content: + ids.append(content["toolResult"]["toolUseId"]) + return ids + + class SlidingWindowConversationManager(ConversationManager): """Implements a sliding window strategy for managing conversation history. @@ -95,23 +121,23 @@ def _remove_dangling_messages(self, messages: Messages) -> None: """ # remove any dangling user messages with no ToolResult if len(messages) > 0 and is_user_message(messages[-1]): - if not any("toolResult" in content for content in messages[-1]["content"]): + if not has_tool_result(messages[-1]): messages.pop() # remove any dangling assistant messages with ToolUse if len(messages) > 0 and is_assistant_message(messages[-1]): - if any("toolUse" in content for content in messages[-1]["content"]): + if has_tool_use(messages[-1]): messages.pop() # remove remaining dangling user messages with no ToolResult after we popped off an assistant message if len(messages) > 0 and is_user_message(messages[-1]): - if not any("toolResult" in content for content in messages[-1]["content"]): + if not has_tool_result(messages[-1]): messages.pop() def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. - The method handles special cases where tool results need to be converted to regular content blocks to maintain - conversation coherence after trimming. + The method ensures that tool use/result pairs are preserved together. If a cut would separate + a toolUse from its corresponding toolResult, it adjusts the cut point to include both. Args: messages: The messages to reduce. @@ -120,58 +146,66 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N Raises: ContextWindowOverflowException: If the context cannot be reduced further. - Such as when the conversation is already minimal or when tool result messages cannot be properly - converted. """ - # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size + # Calculate basic trim index trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size # Throw if we cannot trim any messages from the conversation if trim_index >= len(messages): raise ContextWindowOverflowException("Unable to trim conversation context!") from e - # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the - # limitation of needing ToolUse and ToolResults to be paired. - if any("toolResult" in content for content in messages[trim_index]["content"]): - if len(messages[trim_index]["content"]) == 1: - messages[trim_index]["content"] = self._map_tool_result_content( - cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) - ) + # Find a safe cutting point that preserves tool use/result pairs + safe_trim_index = self._find_safe_trim_index(messages, trim_index) - # If there is more content than just one ToolResultContent, then we cannot cut at this index. - else: - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + # If we couldn't find a safe trim point within bounds, fall back to basic trim + if safe_trim_index >= len(messages): + logger.warning( + "safe_trim_index=<%d>, messages_length=<%d> | could not find safe trim point | " + "falling back to basic trim index", + safe_trim_index, + len(messages), + ) + safe_trim_index = trim_index # Overwrite message history - messages[:] = messages[trim_index:] + messages[:] = messages[safe_trim_index:] - def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: - """Convert a ToolResult to a list of standard ContentBlocks. + def _find_safe_trim_index(self, messages: Messages, initial_trim_index: int) -> int: + """Find a safe cutting point that preserves tool use/result pairs. - This method transforms tool result content into standard content blocks that can be preserved when trimming the - conversation history. + This method ensures that tool use/result pairs are not separated by the trim. + It adjusts the trim index to keep related tool interactions together. Args: - tool_result: The ToolResult to convert. + messages: The complete message history + initial_trim_index: The initial trim index based on window size Returns: - A list of content blocks representing the tool result. + A safe trim index that preserves tool use/result pairs """ - contents = [] - text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" - - for tool_result_content in tool_result["content"]: - if "text" in tool_result_content: - text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" - elif "json" in tool_result_content: - text_content = ( - "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" - ) - elif "image" in tool_result_content: - contents.append(ContentBlock(image=tool_result_content["image"])) - elif "document" in tool_result_content: - contents.append(ContentBlock(document=tool_result_content["document"])) - else: - logger.warning("unsupported content type") - contents.append(ContentBlock(text=text_content)) - return contents + # Build a map of tool IDs to their message indices + tool_use_indices = {} # toolUseId -> message index + tool_result_indices = {} # toolUseId -> message index + + for i, message in enumerate(messages): + for tool_id in get_tool_use_ids(message): + tool_use_indices[tool_id] = i + for tool_id in get_tool_result_ids(message): + tool_result_indices[tool_id] = i + + # Start from the initial trim index + safe_index = initial_trim_index + + # Adjust if we would cut in the middle of a tool use/result pair + for tool_id, use_idx in tool_use_indices.items(): + if tool_id in tool_result_indices: + result_idx = tool_result_indices[tool_id] + # If the pair would be split by the cut + if use_idx < safe_index <= result_idx: + # Move the cut to before the tool use to keep the pair together + safe_index = min(safe_index, use_idx) + elif result_idx < safe_index < use_idx: + # This shouldn't happen in valid conversations + logger.warning("tool_id=<%s> | found toolResult before toolUse", tool_id) + + return safe_index diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 2f6ee77de..46ec84d79 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,6 +1,12 @@ import pytest import strands +from strands.agent.conversation_manager.sliding_window_conversation_manager import ( + get_tool_result_ids, + get_tool_use_ids, + has_tool_result, + has_tool_use, +) from strands.types.exceptions import ContextWindowOverflowException @@ -127,7 +133,9 @@ def conversation_manager(request): [ { "role": "user", - "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}], + "content": [ + {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}} + ], }, ], ), @@ -142,7 +150,7 @@ def conversation_manager(request): {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, ], [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, ], ), # 9 - Message count above max window size - Preserve tool use/tool result pairs @@ -174,6 +182,7 @@ def conversation_manager(request): ], ), # 11 - Test sliding window with multiple tool pairs that need preservation + # Note: The manager prioritizes keeping tool pairs together, which may exceed window size ( {"window_size": 4}, [ @@ -184,8 +193,10 @@ def conversation_manager(request): {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, ], + # The manager keeps tool pairs together, resulting in 5 messages to preserve the first pair [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, @@ -232,3 +243,314 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc manager.reduce_context(messages, RuntimeError("test")) assert messages == original_messages + + +def test_has_tool_use(): + """Test has_tool_use helper function.""" + from strands.agent.conversation_manager.sliding_window_conversation_manager import has_tool_use + + message_with_tool_use = { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "test", "input": {}}}], + } + message_without_tool_use = {"role": "assistant", "content": [{"text": "Hello"}]} + assert has_tool_use(message_with_tool_use) is True + assert has_tool_use(message_without_tool_use) is False + + +def test_has_tool_result(): + """Test has_tool_result helper function.""" + from strands.agent.conversation_manager.sliding_window_conversation_manager import has_tool_result + + message_with_tool_result = { + "role": "user", + "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}], + } + message_without_tool_result = {"role": "user", "content": [{"text": "Hello"}]} + assert has_tool_result(message_with_tool_result) is True + assert has_tool_result(message_without_tool_result) is False + + +def test_get_tool_use_ids(): + """Test get_tool_use_ids helper function.""" + from strands.agent.conversation_manager.sliding_window_conversation_manager import get_tool_use_ids + + message_with_multiple_tool_uses = { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "123", "name": "test1", "input": {}}}, + {"toolUse": {"toolUseId": "456", "name": "test2", "input": {}}}, + {"text": "Some text"}, + ], + } + message_without_tool_use = {"role": "assistant", "content": [{"text": "Hello"}]} + assert get_tool_use_ids(message_with_multiple_tool_uses) == ["123", "456"] + assert get_tool_use_ids(message_without_tool_use) == [] + + +def test_get_tool_result_ids(): + """Test get_tool_result_ids helper function.""" + from strands.agent.conversation_manager.sliding_window_conversation_manager import get_tool_result_ids + + message_with_multiple_tool_results = { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}, + {"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}, + {"text": "Some text"}, + ], + } + message_without_tool_result = {"role": "user", "content": [{"text": "Hello"}]} + assert get_tool_result_ids(message_with_multiple_tool_results) == ["123", "456"] + assert get_tool_result_ids(message_without_tool_result) == [] + + +def test_reduce_context_edge_cases(): + """Test edge cases in reduce_context method.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # Test case 1: Messages with multiple content items including toolResult + # Preserves tool results as-is + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + { + "role": "user", + "content": [ + {"text": "Multiple items"}, + {"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}, + ], + }, + {"role": "assistant", "content": [{"text": "Response"}]}, + ] + + manager.reduce_context(messages) + + # Should keep the last 2 messages + assert len(messages) == 2 + assert messages[0]["content"][0]["text"] == "Multiple items" + assert "toolResult" in messages[0]["content"][1] + + # Test case 2: Extreme case - trim_index >= len(messages) + # Create a scenario where window_size is 2 but we need to trim from messages with length 1 + messages = [ + {"role": "user", "content": [{"text": "Only message"}]}, + ] + + # Force a reduce context when trim_index would be >= len(messages) + # Since len(messages) = 1 and window_size = 2, trim_index = 2 + with pytest.raises(ContextWindowOverflowException, match="Unable to trim conversation context!"): + manager.reduce_context(messages) + + +def test_find_safe_trim_index_orphaned_results(): + """Test _find_safe_trim_index with orphaned tool results.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=3) + + # Create messages with orphaned tool results (no matching toolUse) + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + # Orphaned tool result - no matching toolUse + {"role": "user", "content": [{"toolResult": {"toolUseId": "orphan1", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"text": "Another response"}]}, + # Another orphaned tool result + {"role": "user", "content": [{"toolResult": {"toolUseId": "orphan2", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"text": "Final response"}]}, + ] + + # Basic trim index would be 3 (6 messages - 3 window size) + initial_trim_index = 3 + safe_index = manager._find_safe_trim_index(messages, initial_trim_index) + + # The orphaned tool result at index 2 should be OK to cut at + # But the method should prefer index 3 (assistant message without tool use) + assert safe_index == 3 + + +def test_find_safe_trim_index_tool_result_before_use(): + """Test warning case where tool result appears before tool use.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=3) + + # Create an invalid scenario (shouldn't happen in practice) + messages = [ + {"role": "user", "content": [{"text": "Start"}]}, + # Tool result before tool use (invalid) + {"role": "user", "content": [{"toolResult": {"toolUseId": "backwards", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "backwards", "name": "tool", "input": {}}}]}, + {"role": "user", "content": [{"text": "End"}]}, + ] + + # This should trigger the warning in _find_safe_trim_index + initial_trim_index = 1 + safe_index = manager._find_safe_trim_index(messages, initial_trim_index) + # Should still return a valid index + assert safe_index >= 0 + + +def test_find_safe_trim_index_extreme_no_good_cut(): + """Test _find_safe_trim_index when initial trim is beyond messages.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # Create a scenario where initial trim index is already beyond messages length + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1", "content": [], "status": "success"}}]}, + ] + + # Initial trim index is 10 (way beyond the 2 messages) + initial_trim_index = 10 + safe_index = manager._find_safe_trim_index(messages, initial_trim_index) + + # Should still be beyond messages length since there's no good cut + assert safe_index >= len(messages) + + +def test_reduce_context_with_orphaned_tool_result_at_start(): + """Test reduce_context when the safe trim point starts with an orphaned tool result.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # Create messages where the trim will start with an orphaned tool result + messages = [ + {"role": "user", "content": [{"text": "First"}]}, + {"role": "assistant", "content": [{"text": "Response"}]}, + # This will be the first message after trim - it's an orphaned tool result + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "orphan", "content": [{"text": "Result text"}], "status": "success"}} + ], + }, + {"role": "assistant", "content": [{"text": "Final"}]}, + ] + + # Reduce context - preserves the toolResult + manager.reduce_context(messages) + + # Check that the messages are trimmed correctly + assert len(messages) == 2 + assert messages[0]["role"] == "user" + # The toolResult should be preserved + assert "toolResult" in messages[0]["content"][0] + assert messages[0]["content"][0]["toolResult"]["toolUseId"] == "orphan" + + +def test_reduce_context_safe_trim_beyond_messages(): + """Test reduce_context when it preserves tool use/result pairs.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # Create a scenario where there's a tool use/result pair plus an orphaned result + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1", "content": [], "status": "success"}}]}, + # This is an orphaned tool result (no matching toolUse) + { + "role": "user", + "content": [ + {"toolResult": {"toolUseId": "orphan", "content": [{"text": "orphaned"}], "status": "success"}} + ], + }, + ] + + # Reduce context - should preserve tool use/result pairs + manager.reduce_context(messages) + + # Should keep the tool use/result pair to maintain integrity + assert len(messages) >= 2 + # The pair should be preserved together + if len(messages) >= 2: + assert has_tool_use(messages[0]) + assert has_tool_result(messages[1]) + assert get_tool_use_ids(messages[0])[0] == get_tool_result_ids(messages[1])[0] + + +def test_find_safe_trim_index_fallback_to_basic_trim(): + """Test that we correctly handle the case where any trim breaks a tool pair.""" + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # All messages are tool pairs - any trim breaks a pair + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "2", "name": "tool2", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2", "content": [], "status": "success"}}]}, + ] + + # The basic trim index would be 2 (window_size=2, len=4) + # This cuts between the two tool pairs, which is a valid cut point + manager.reduce_context(messages) + + # Window size is respected - only 2 messages remain + assert len(messages) == 2 + + # First tool pair was removed, second pair preserved + assert messages[0]["content"][0]["toolUse"]["toolUseId"] == "2" + assert messages[1]["content"][0]["toolResult"]["toolUseId"] == "2" + + +def test_find_safe_trim_index_warning_scenario(): + """Test the warning scenario where safe_index >= len(messages).""" + from unittest.mock import patch + + # Create a custom manager with a specific _find_safe_trim_index implementation + # that returns an index >= len(messages) to trigger the warning + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=2) + + # Create a message list longer than window size + messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + {"role": "user", "content": [{"text": "How are you?"}]}, + {"role": "assistant", "content": [{"text": "I'm good"}]}, + ] + + # Mock _find_safe_trim_index to return a value >= len(messages) + with patch.object(manager, "_find_safe_trim_index", return_value=5): + with patch("strands.agent.conversation_manager.sliding_window_conversation_manager.logger") as mock_logger: + # This should trigger the warning and use basic trim + manager.reduce_context(messages) + + # Warning should be logged + mock_logger.warning.assert_called_with( + "safe_trim_index=<%d>, messages_length=<%d> | could not find safe trim point | " + "falling back to basic trim index", + 5, + 4, + ) + + # Should use basic trim index (2) instead of the invalid index (5) + # So 2 messages remain (last 2 based on window size) + assert len(messages) == 2 + assert messages[0]["content"][0]["text"] == "How are you?" + assert messages[1]["content"][0]["text"] == "I'm good" + + +def test_find_safe_trim_index_tool_result_before_use_with_warning(): + """Test that warning is logged when tool result appears before tool use.""" + from unittest.mock import patch + + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(window_size=3) + + # Create an invalid scenario where tool result comes before tool use + # and the safe_index will fall between them + messages = [ + {"role": "user", "content": [{"text": "Start"}]}, + # Tool result at index 1 (before its corresponding tool use) + {"role": "user", "content": [{"toolResult": {"toolUseId": "backwards", "content": [], "status": "success"}}]}, + {"role": "assistant", "content": [{"text": "Middle message"}]}, # Index 2 - this will be our safe_index + # Tool use at index 3 (after its corresponding tool result) + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "backwards", "name": "tool", "input": {}}}]}, + {"role": "user", "content": [{"text": "End"}]}, + ] + + # Initial trim index would be 2 (5 messages - 3 window size) + # This will make safe_index = 2, which is between result_idx (1) and use_idx (3) + initial_trim_index = 2 + + with patch("strands.agent.conversation_manager.sliding_window_conversation_manager.logger") as mock_logger: + safe_index = manager._find_safe_trim_index(messages, initial_trim_index) + + # Verify the warning was logged with the correct message + mock_logger.warning.assert_called_with("tool_id=<%s> | found toolResult before toolUse", "backwards") + + # The method should still return a valid index + assert safe_index == 2 # Should use the initial trim index