Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 75 additions & 22 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator

from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
from ..experimental.hooks.registry import get_registry
from ..telemetry.metrics import Trace
from ..telemetry.tracer import get_tracer
from ..tools.executor import run_tools, validate_and_prepare_tools
Expand Down Expand Up @@ -271,46 +273,97 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
The final tool result or an error response if the tool fails or is not found.
"""
logger.debug("tool_use=<%s> | streaming", tool_use)
tool_use_id = tool_use["toolUseId"]
tool_name = tool_use["name"]

# Get the tool info
tool_info = agent.tool_registry.dynamic_tools.get(tool_name)
tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name)

# Add standard arguments to kwargs for Python tools
kwargs.update(
{
"model": agent.model,
"system_prompt": agent.system_prompt,
"messages": agent.messages,
"tool_config": agent.tool_config,
}
)

before_event = get_registry(agent).invoke_callbacks(
BeforeToolInvocationEvent(
agent=agent,
selected_tool=tool_func,
tool_use=tool_use,
kwargs=kwargs,
)
)

try:
selected_tool = before_event.selected_tool
tool_use = before_event.tool_use

# Check if tool exists
if not tool_func:
logger.error(
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
tool_name,
list(agent.tool_registry.registry.keys()),
)
return {
"toolUseId": tool_use_id,
if not selected_tool:
if tool_func == selected_tool:
logger.error(
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
tool_name,
list(agent.tool_registry.registry.keys()),
)
else:
logger.debug(
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
tool_name,
str(tool_use.get("toolUseId")),
)

result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": f"Unknown tool: {tool_name}"}],
}
# Add standard arguments to kwargs for Python tools
kwargs.update(
{
"model": agent.model,
"system_prompt": agent.system_prompt,
"messages": agent.messages,
"tool_config": agent.tool_config,
}
)
# for every Before event call, we need to have an AfterEvent call
after_event = get_registry(agent).invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
tool_use=tool_use,
kwargs=kwargs,
result=result,
)
)
return after_event.result

result = yield from tool_func.stream(tool_use, **kwargs)
return result
result = yield from selected_tool.stream(tool_use, **kwargs)
after_event = get_registry(agent).invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
tool_use=tool_use,
kwargs=kwargs,
result=result,
)
)
return after_event.result

except Exception as e:
logger.exception("tool_name=<%s> | failed to process tool", tool_name)
return {
"toolUseId": tool_use_id,
error_result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": f"Error: {str(e)}"}],
}
after_event = get_registry(agent).invoke_callbacks(
AfterToolInvocationEvent(
agent=agent,
selected_tool=selected_tool,
tool_use=tool_use,
kwargs=kwargs,
result=error_result,
exception=e,
)
)
return after_event.result


async def _handle_tool_execution(
Expand Down
10 changes: 9 additions & 1 deletion src/strands/experimental/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,21 @@ def log_end(self, event: EndRequestEvent) -> None:
type-safe system that supports multiple subscribers per event type.
"""

from .events import AgentInitializedEvent, EndRequestEvent, StartRequestEvent
from .events import (
AfterToolInvocationEvent,
AgentInitializedEvent,
BeforeToolInvocationEvent,
EndRequestEvent,
StartRequestEvent,
)
from .registry import HookCallback, HookEvent, HookProvider, HookRegistry

__all__ = [
"AgentInitializedEvent",
"StartRequestEvent",
"EndRequestEvent",
"BeforeToolInvocationEvent",
"AfterToolInvocationEvent",
"HookEvent",
"HookProvider",
"HookCallback",
Expand Down
64 changes: 60 additions & 4 deletions src/strands/experimental/hooks/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""

from dataclasses import dataclass
from typing import Any, Optional

from ...types.tools import AgentTool, ToolResult, ToolUse
from .registry import HookEvent


Expand Down Expand Up @@ -56,9 +58,63 @@ class EndRequestEvent(HookEvent):

@property
def should_reverse_callbacks(self) -> bool:
"""Return True to invoke callbacks in reverse order for proper cleanup.
"""True to invoke callbacks in reverse order."""
return True


@dataclass
class BeforeToolInvocationEvent(HookEvent):
"""Event triggered before a tool is invoked.

This event is fired just before the agent executes a tool, allowing hook
providers to inspect, modify, or replace the tool that will be executed.
The selected_tool can be modified by hook callbacks to change which tool
gets executed.

Attributes:
selected_tool: The tool that will be invoked. Can be modified by hooks
to change which tool gets executed. This may be None if tool lookup failed.
tool_use: The tool parameters that will be passed to selected_tool.
kwargs: Keyword arguments that will be passed to the tool.
"""

selected_tool: Optional[AgentTool]
tool_use: ToolUse
kwargs: dict[str, Any]

def _can_write(self, name: str) -> bool:
return name in ["selected_tool", "tool_use"]


@dataclass
class AfterToolInvocationEvent(HookEvent):
"""Event triggered after a tool invocation completes.

Returns:
True, indicating callbacks should be invoked in reverse order.
"""
This event is fired after the agent has finished executing a tool,
regardless of whether the execution was successful or resulted in an error.
Hook providers can use this event for cleanup, logging, or post-processing.

Note: This event uses reverse callback ordering, meaning callbacks registered
later will be invoked first during cleanup.

Attributes:
selected_tool: The tool that was invoked. It may be None if tool lookup failed.
tool_use: The tool parameters that were passed to the tool invoked.
kwargs: Keyword arguments that were passed to the tool
result: The result of the tool invocation. Either a ToolResult on success
or an Exception if the tool execution failed.
"""

selected_tool: Optional[AgentTool]
tool_use: ToolUse
kwargs: dict[str, Any]
result: ToolResult
exception: Optional[Exception] = None

def _can_write(self, name: str) -> bool:
return name == "result"

@property
def should_reverse_callbacks(self) -> bool:
"""True to invoke callbacks in reverse order."""
return True
60 changes: 57 additions & 3 deletions src/strands/experimental/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Generator, Generic, Protocol, Type, TypeVar
from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar

if TYPE_CHECKING:
from ...agent import Agent
Expand All @@ -34,9 +34,43 @@ def should_reverse_callbacks(self) -> bool:
"""
return False

def _can_write(self, name: str) -> bool:
"""Check if the given property can be written to.

Args:
name: The name of the property to check.

Returns:
True if the property can be written to, False otherwise.
"""
return False

def __post_init__(self) -> None:
"""Disallow writes to non-approved properties."""
# This is needed as otherwise the class can't be initialized at all, so we trigger
# this after class initialization
super().__setattr__("_disallow_writes", True)

def __setattr__(self, name: str, value: Any) -> None:
"""Prevent setting attributes on hook events.

Raises:
AttributeError: Always raised to prevent setting attributes on hook events.
"""
# Allow setting attributes:
# - during init (when __dict__) doesn't exist
# - if the subclass specifically said the property is writable
if not hasattr(self, "_disallow_writes") or self._can_write(name):
return super().__setattr__(name, value)

raise AttributeError(f"Property {name} is not writable")


T = TypeVar("T", bound=Callable)
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
"""Generic for adding callback handlers - contravariant to allow adding handlers which take in base classes."""

TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent)
"""Generic for invoking events - non-contravariant to enable returning events."""


class HookProvider(Protocol):
Expand Down Expand Up @@ -144,7 +178,7 @@ def register_hooks(self, registry: HookRegistry):
"""
hook.register_hooks(self)

def invoke_callbacks(self, event: TEvent) -> None:
def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent:
"""Invoke all registered callbacks for the given event.

This method finds all callbacks registered for the event's type and
Expand All @@ -157,6 +191,9 @@ def invoke_callbacks(self, event: TEvent) -> None:
Raises:
Any exceptions raised by callback functions will propagate to the caller.

Returns:
The event dispatched to registered callbacks.

Example:
```python
event = StartRequestEvent(agent=my_agent)
Expand All @@ -166,6 +203,8 @@ def invoke_callbacks(self, event: TEvent) -> None:
for callback in self.get_callbacks_for(event):
callback(event)

return event

def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
"""Get callbacks registered for the given event in the appropriate order.

Expand Down Expand Up @@ -193,3 +232,18 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No
yield from reversed(callbacks)
else:
yield from callbacks


def get_registry(agent: "Agent") -> HookRegistry:
"""*Experimental*: Get the hooks registry for the provided agent.

This function is available while hooks are in experimental preview.

Args:
agent: The agent whose hook registry should be returned.

Returns:
The HookRegistry for the given agent.

"""
return agent._hooks
11 changes: 5 additions & 6 deletions tests/fixtures/mock_hook_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections import deque
from typing import Type
from typing import Iterator, Tuple, Type

from strands.experimental.hooks import HookEvent, HookProvider, HookRegistry

Expand All @@ -9,12 +8,12 @@ def __init__(self, event_types: list[Type]):
self.events_received = []
self.events_types = event_types

def get_events(self) -> deque[HookEvent]:
return deque(self.events_received)
def get_events(self) -> Tuple[int, Iterator[HookEvent]]:
return len(self.events_received), iter(self.events_received)

def register_hooks(self, registry: HookRegistry) -> None:
for event_type in self.events_types:
registry.add_callback(event_type, self._add_event)
registry.add_callback(event_type, self.add_event)

def _add_event(self, event: HookEvent) -> None:
def add_event(self, event: HookEvent) -> None:
self.events_received.append(event)
Loading