diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 6a9489809..bed9e52e6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -457,27 +457,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str Returns: The result of the event loop cycle. """ - kwargs.pop("agent", None) - kwargs.pop("model", None) - kwargs.pop("system_prompt", None) - kwargs.pop("tool_execution_handler", None) - kwargs.pop("event_loop_metrics", None) - kwargs.pop("callback_handler", None) - kwargs.pop("tool_handler", None) - kwargs.pop("messages", None) - kwargs.pop("tool_config", None) + # Extract parameters with fallbacks to instance values + system_prompt = kwargs.pop("system_prompt", self.system_prompt) + model = kwargs.pop("model", self.model) + tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper) + event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics) + callback_handler_override = kwargs.pop("callback_handler", callback_handler) + tool_handler = kwargs.pop("tool_handler", self.tool_handler) + messages = kwargs.pop("messages", self.messages) + tool_config = kwargs.pop("tool_config", self.tool_config) + kwargs.pop("agent", None) # Remove agent to avoid conflicts try: # Execute the main event loop cycle stop_reason, message, metrics, state = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - callback_handler=callback_handler, - tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, - event_loop_metrics=self.event_loop_metrics, + model=model, + system_prompt=system_prompt, + messages=messages, # will be modified by event_loop_cycle + tool_config=tool_config, + callback_handler=callback_handler_override, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + event_loop_metrics=event_loop_metrics, agent=self, event_loop_parent_span=self.trace_span, **kwargs, @@ -488,8 +489,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self.messages, e=e) - return self._execute_event_loop_cycle(callback_handler, kwargs) + self.conversation_manager.reduce_context(messages, e=e) + return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5c7d11e46..ff70089bd 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -337,17 +337,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, ], ] + override_system_prompt = "Override system prompt" + override_model = unittest.mock.Mock() + override_tool_execution_handler = unittest.mock.Mock() + override_event_loop_metrics = unittest.mock.Mock() + override_callback_handler = unittest.mock.Mock() + override_tool_handler = unittest.mock.Mock() + override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] + override_tool_config = {"test": "config"} + def check_kwargs(some_value, **kwargs): assert some_value == "a_value" assert kwargs is not None + assert kwargs["system_prompt"] == override_system_prompt + assert kwargs["model"] == override_model + assert kwargs["tool_execution_handler"] == override_tool_execution_handler + assert kwargs["event_loop_metrics"] == override_event_loop_metrics + assert kwargs["callback_handler"] == override_callback_handler + assert kwargs["tool_handler"] == override_tool_handler + assert kwargs["messages"] == override_messages + assert kwargs["tool_config"] == override_tool_config + assert kwargs["agent"] == agent # Return expected values from event_loop_cycle return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {} mock_event_loop_cycle.side_effect = check_kwargs - agent("test message", some_value="a_value") - assert mock_event_loop_cycle.call_count == 1 + agent( + "test message", + some_value="a_value", + system_prompt=override_system_prompt, + model=override_model, + tool_execution_handler=override_tool_execution_handler, + event_loop_metrics=override_event_loop_metrics, + callback_handler=override_callback_handler, + tool_handler=override_tool_handler, + messages=override_messages, + tool_config=override_tool_config, + ) + + mock_event_loop_cycle.assert_called_once() def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):