Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ..types.content import ContentBlock, Message, Messages
from ..types.exceptions import ContextWindowOverflowException
from ..types.models import Model
from ..types.tools import ToolConfig, ToolResult, ToolUse
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .conversation_manager import (
Expand Down Expand Up @@ -335,15 +335,6 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

@property
def tool_config(self) -> ToolConfig:
"""Get the tool configuration for this agent.

Returns:
The complete tool configuration.
"""
return self.tool_registry.initialize_tool_config()

def __del__(self) -> None:
"""Clean up resources when Agent is garbage collected.

Expand Down
19 changes: 9 additions & 10 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast

from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ..experimental.hooks.registry import get_registry
Expand All @@ -21,7 +21,7 @@
from ..types.content import Message
from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
from ..types.streaming import Metrics, StopReason
from ..types.tools import ToolGenerator, ToolResult, ToolUse
from ..types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
from .message_processor import clean_orphaned_empty_tool_uses
from .streaming import stream_messages

Expand Down Expand Up @@ -112,10 +112,12 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener
model_id=model_id,
)

tool_specs = agent.tool_registry.get_all_tool_specs()

try:
# TODO: To maintain backwards compatibility, we need to combine the stream event with kwargs before yielding
# to the callback handler. This will be revisited when migrating to strongly typed events.
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, agent.tool_config):
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
if "callback" in event:
yield {"callback": {**event["callback"], **(kwargs if "delta" in event["callback"] else {})}}

Expand Down Expand Up @@ -172,12 +174,6 @@ async def event_loop_cycle(agent: "Agent", kwargs: dict[str, Any]) -> AsyncGener

# If the model is requesting to use tools
if stop_reason == "tool_use":
if agent.tool_config is None:
raise EventLoopException(
Exception("Model requested tool use but no tool config provided"),
kwargs["request_state"],
)

# Handle tool execution
events = _handle_tool_execution(
stop_reason,
Expand Down Expand Up @@ -285,7 +281,10 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
"model": agent.model,
"system_prompt": agent.system_prompt,
"messages": agent.messages,
"tool_config": agent.tool_config,
"tool_config": ToolConfig( # for backwards compatability
tools=[{"toolSpec": tool_spec} for tool_spec in agent.tool_registry.get_all_tool_specs()],
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
),
}
)

Expand Down
9 changes: 4 additions & 5 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
StreamEvent,
Usage,
)
from ..types.tools import ToolConfig, ToolUse
from ..types.tools import ToolSpec, ToolUse

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -304,24 +304,23 @@ async def stream_messages(
model: Model,
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
tool_specs: list[ToolSpec],
) -> AsyncGenerator[dict[str, Any], None]:
"""Streams messages to the model and processes the response.

Args:
model: Model provider.
system_prompt: The system prompt to send.
messages: List of messages to send.
tool_config: Configuration for the tools to use.
tool_specs: The list of tool specs.

Returns:
The reason for stopping, the final message, and the usage metrics
"""
logger.debug("model=<%s> | streaming messages", model)

messages = remove_blank_messages_content_text(messages)
tool_specs = [tool["toolSpec"] for tool in tool_config.get("tools", [])] or None if tool_config else None

chunks = model.converse(messages, tool_specs, system_prompt)
chunks = model.converse(messages, tool_specs if tool_specs else None, system_prompt)
async for event in process_stream(chunks, messages):
yield event
17 changes: 6 additions & 11 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from strands.tools.decorator import DecoratedFunctionTool

from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec
from ..types.tools import AgentTool, ToolSpec
from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -472,20 +472,15 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
for tool_name, error in tool_import_errors.items():
logger.debug("tool_name=<%s> | import error | %s", tool_name, error)

def initialize_tool_config(self) -> ToolConfig:
"""Initialize tool configuration from tool handler with optional filtering.
def get_all_tool_specs(self) -> list[ToolSpec]:
"""Get all the tool specs for all tools in this registry..

Returns:
Tool config.
A list of ToolSpecs.
"""
all_tools = self.get_all_tools_config()

tools: List[Tool] = [{"toolSpec": tool_spec} for tool_spec in all_tools.values()]

return ToolConfig(
tools=tools,
toolChoice=cast(ToolChoice, {"auto": ToolChoiceAuto()}),
)
tools: List[ToolSpec] = [tool_spec for tool_spec in all_tools.values()]
return tools

def validate_tool_spec(self, tool_spec: ToolSpec) -> None:
"""Validate tool specification against required schema.
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_agent__init__tool_loader_format(tool_decorated, tool_module, tool_impor

agent = Agent(tools=[tool_decorated, tool_module, tool_imported])

tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"])
tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs())
exp_tool_names = ["tool_decorated", "tool_imported", "tool_module"]

assert tru_tool_names == exp_tool_names
Expand All @@ -191,7 +191,7 @@ def test_agent__init__tool_loader_dict(tool_module, tool_registry):

agent = Agent(tools=[{"name": "tool_module", "path": tool_module}])

tru_tool_names = sorted(tool_spec["toolSpec"]["name"] for tool_spec in agent.tool_config["tools"])
tru_tool_names = sorted(tool_spec["name"] for tool_spec in agent.tool_registry.get_all_tool_specs())
exp_tool_names = ["tool_module"]

assert tru_tool_names == exp_tool_names
Expand Down
11 changes: 3 additions & 8 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ def messages():
return [{"role": "user", "content": [{"text": "Hello"}]}]


@pytest.fixture
def tool_config():
return {"tools": [{"toolSpec": {"name": "tool_for_testing"}}], "toolChoice": {"auto": {}}}


@pytest.fixture
def tool_registry():
return ToolRegistry()
Expand Down Expand Up @@ -116,13 +111,12 @@ def hook_provider(hook_registry):


@pytest.fixture
def agent(model, system_prompt, messages, tool_config, tool_registry, thread_pool, hook_registry):
def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_registry):
mock = unittest.mock.Mock(name="agent")
mock.config.cache_points = []
mock.model = model
mock.system_prompt = system_prompt
mock.messages = messages
mock.tool_config = tool_config
mock.tool_registry = tool_registry
mock.thread_pool = thread_pool
mock.event_loop_metrics = EventLoopMetrics()
Expand Down Expand Up @@ -298,6 +292,7 @@ async def test_event_loop_cycle_tool_result(
system_prompt,
messages,
tool_stream,
tool_registry,
agenerator,
alist,
):
Expand Down Expand Up @@ -353,7 +348,7 @@ async def test_event_loop_cycle_tool_result(
},
{"role": "assistant", "content": [{"text": "test text"}]},
],
[{"name": "tool_for_testing"}],
tool_registry.get_all_tool_specs(),
"p1",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/strands/event_loop/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ async def test_stream_messages(agenerator, alist):
mock_model,
system_prompt="test prompt",
messages=[{"role": "assistant", "content": [{"text": "a"}, {"text": " \n"}]}],
tool_config=None,
tool_specs=None,
)

tru_events = await alist(stream)
Expand Down
18 changes: 18 additions & 0 deletions tests/strands/tools/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest

import strands
from strands.tools import PythonAgentTool
from strands.tools.decorator import DecoratedFunctionTool, tool
from strands.tools.registry import ToolRegistry
Expand Down Expand Up @@ -46,6 +47,23 @@ def test_register_tool_with_similar_name_raises():
)


def test_get_all_tool_specs_returns_right_tool_specs():
tool_1 = strands.tool(lambda a: a, name="tool_1")
tool_2 = strands.tool(lambda b: b, name="tool_2")

tool_registry = ToolRegistry()

tool_registry.register_tool(tool_1)
tool_registry.register_tool(tool_2)

tool_specs = tool_registry.get_all_tool_specs()

assert tool_specs == [
tool_1.tool_spec,
tool_2.tool_spec,
]


def test_scan_module_for_tools():
@tool
def tool_function_1(a):
Expand Down