diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 28d877948..5502565ad 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -31,7 +31,7 @@ from ..types.content import ContentBlock, Message, Messages from ..types.exceptions import ContextWindowOverflowException from ..types.models import Model -from ..types.tools import ToolConfig, ToolResult, ToolUse +from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult from .conversation_manager import ( @@ -335,15 +335,6 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) - @property - def tool_config(self) -> ToolConfig: - """Get the tool configuration for this agent. - - Returns: - The complete tool configuration. - """ - return self.tool_registry.initialize_tool_config() - def __del__(self) -> None: """Clean up resources when Agent is garbage collected. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 6375b5d82..10c21a00b 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -11,7 +11,7 @@ import logging import time import uuid -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent from ..experimental.hooks.registry import get_registry @@ -21,7 +21,7 @@ from ..types.content import Message from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException from ..types.streaming import Metrics, StopReason -from ..types.tools import ToolGenerator, ToolResult, ToolUse +from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse from .message_processor import clean_orphaned_empty_tool_uses from .streaming import stream_messages @@ -112,10 +112,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener model_id=model_id, ) + tool_specs = agent.tool_registry.get_all_tool_specs() + try: # TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding # to the callback handler. This will be revisited when migrating to strongly typed events. - async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, agent.tool_config): + async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs): if "callback" in event: yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}} @@ -172,12 +174,6 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener # If the model is requesting to use tools if stop_reason == "tool_use": - if agent.tool_config is None: - raise EventLoopException( - Exception("Model requested tool use but no tool config provided"), - kwargs["request_state"], - ) - # Handle tool execution events = _handle_tool_execution( stop_reason, @@ -285,7 +281,10 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG "model": agent.model, "system_prompt": agent.system_prompt, "messages": agent.messages, - "tool_config": agent.tool_config, + "tool_config": ToolConfig( # for backwards compatability + tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()], + toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), + ), } ) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 6ecc3e270..777c3a064 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -19,7 +19,7 @@ StreamEvent, Usage, ) -from ..types.tools import ToolConfig, ToolUse +from ..types.tools import ToolSpec, ToolUse logger = logging.getLogger(__name__) @@ -304,7 +304,7 @@ async def stream_messages( model: Model, system_prompt: Optional[str], messages: Messages, - tool_config: Optional[ToolConfig], + tool_specs: list[ToolSpec], ) -> AsyncGenerator[dict[str, Any], None]: """Streams messages to the model and processes the response. @@ -312,7 +312,7 @@ async def stream_messages( model: Model provider. system_prompt: The system prompt to send. messages: List of messages to send. - tool_config: Configuration for the tools to use. + tool_specs: The list of tool specs. Returns: The reason for stopping, the final message, and the usage metrics @@ -320,8 +320,7 @@ async def stream_messages( logger.debug("model=<%s> | streaming messages", model) messages = remove_blank_messages_content_text(messages) - tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None - chunks = model.converse(messages, tool_specs, system_prompt) + chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt) async for event in process_stream(chunks, messages): yield event diff --git a/src/strands/tools/registry.py b/src/strands/tools/registry.py index 617f77cc1..b0d84946d 100644 --- a/src/strands/tools/registry.py +++ b/src/strands/tools/registry.py @@ -17,7 +17,7 @@ from strands.tools.decorator import DecoratedFunctionTool -from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec +from ..types.tools import AgentTool, ToolSpec from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec logger = logging.getLogger(__name__) @@ -472,20 +472,15 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None: for tool_name, error in tool_import_errors.items(): logger.debug("tool_name=<%s> | import error | %s", tool_name, error) - def initialize_tool_config(self) -> ToolConfig: - """Initialize tool configuration from tool handler with optional filtering. + def get_all_tool_specs(self) -> list[ToolSpec]: + """Get all the tool specs for all tools in this registry.. Returns: - Tool config. + A list of ToolSpecs. """ all_tools = self.get_all_tools_config() - - tools: List[Tool] = [{"toolSpec": tool_spec} for tool_spec in all_tools.values()] - - return ToolConfig( - tools=tools, - toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}), - ) + tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()] + return tools def validate_tool_spec(self, tool_spec: ToolSpec) -> None: """Validate tool specification against required schema. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 82283490f..83cb7ed77 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -180,7 +180,7 @@ def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_impor agent = Agent(tools=[tool_decorated, tool_module, tool_imported]) - tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"]) + tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs()) exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"] assert tru_tool_names == exp_tool_names @@ -191,7 +191,7 @@ def test_agent__init__tool_loader_dict(tool_module, tool_registry): agent = Agent(tools=[{"name": "tool_module", "path": tool_module}]) - tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"]) + tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs()) exp_tool_names = ["tool_module"] assert tru_tool_names == exp_tool_names diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 700608c2e..d7de187d7 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -35,11 +35,6 @@ def messages(): return [{"role": "user", "content": [{"text": "Hello"}]}] -@pytest.fixture -def tool_config(): - return {"tools": [{"toolSpec": {"name": "tool_for_testing"}}], "toolChoice": {"auto": {}}} - - @pytest.fixture def tool_registry(): return ToolRegistry() @@ -116,13 +111,12 @@ def hook_provider(hook_registry): @pytest.fixture -def agent(model, system_prompt, messages, tool_config, tool_registry, thread_pool, hook_registry): +def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry): mock = unittest.mock.Mock(name="agent") mock.config.cache_points = [] mock.model = model mock.system_prompt = system_prompt mock.messages = messages - mock.tool_config = tool_config mock.tool_registry = tool_registry mock.thread_pool = thread_pool mock.event_loop_metrics = EventLoopMetrics() @@ -298,6 +292,7 @@ async def test_event_loop_cycle_tool_result( system_prompt, messages, tool_stream, + tool_registry, agenerator, alist, ): @@ -353,7 +348,7 @@ async def test_event_loop_cycle_tool_result( }, {"role": "assistant", "content": [{"text": "test text"}]}, ], - [{"name": "tool_for_testing"}], + tool_registry.get_all_tool_specs(), "p1", ) diff --git a/tests/strands/event_loop/test_streaming.py b/tests/strands/event_loop/test_streaming.py index 7b64264e3..44c5b5a8e 100644 --- a/tests/strands/event_loop/test_streaming.py +++ b/tests/strands/event_loop/test_streaming.py @@ -549,7 +549,7 @@ async def test_stream_messages(agenerator, alist): mock_model, system_prompt="test prompt", messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}], - tool_config=None, + tool_specs=None, ) tru_events = await alist(stream) diff --git a/tests/strands/tools/test_registry.py b/tests/strands/tools/test_registry.py index 4d92be0c3..ebcba3fb1 100644 --- a/tests/strands/tools/test_registry.py +++ b/tests/strands/tools/test_registry.py @@ -6,6 +6,7 @@ import pytest +import strands from strands.tools import PythonAgentTool from strands.tools.decorator import DecoratedFunctionTool, tool from strands.tools.registry import ToolRegistry @@ -46,6 +47,23 @@ def test_register_tool_with_similar_name_raises(): ) +def test_get_all_tool_specs_returns_right_tool_specs(): + tool_1 = strands.tool(lambda a: a, name="tool_1") + tool_2 = strands.tool(lambda b: b, name="tool_2") + + tool_registry = ToolRegistry() + + tool_registry.register_tool(tool_1) + tool_registry.register_tool(tool_2) + + tool_specs = tool_registry.get_all_tool_specs() + + assert tool_specs == [ + tool_1.tool_spec, + tool_2.tool_spec, + ] + + def test_scan_module_for_tools(): @tool def tool_function_1(a):