Skip to content

Commit ec758cd

Browse files
committed
Support setting the current_tool to None to support method interception/limits
1 parent 64fc331 commit ec758cd

File tree

3 files changed

+179
-30
lines changed

3 files changed

+179
-30
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

1616
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
17+
from ..experimental.hooks.registry import get_registry
1718
from ..telemetry.metrics import Trace
1819
from ..telemetry.tracer import get_tracer
1920
from ..tools.executor import run_tools, validate_and_prepare_tools
@@ -288,40 +289,61 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
288289
}
289290
)
290291

291-
before_event = BeforeToolInvocationEvent(
292-
agent=agent,
293-
selected_tool=tool_func,
294-
tool_use=tool,
295-
kwargs=kwargs,
292+
before_event = get_registry(agent).invoke_callbacks(
293+
BeforeToolInvocationEvent(
294+
agent=agent,
295+
selected_tool=tool_func,
296+
tool_use=tool,
297+
kwargs=kwargs,
298+
)
296299
)
297-
agent._hooks.invoke_callbacks(before_event)
298300

299301
try:
300302
selected_tool = before_event.selected_tool
301303
tool_use = before_event.tool_use
302304

303305
# Check if tool exists
304306
if not selected_tool:
305-
logger.error(
306-
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
307-
tool_name,
308-
list(agent.tool_registry.registry.keys()),
309-
)
310-
return {
307+
if tool_func == selected_tool:
308+
logger.error(
309+
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
310+
tool_name,
311+
list(agent.tool_registry.registry.keys()),
312+
)
313+
else:
314+
logger.debug(
315+
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
316+
tool_name,
317+
tool_use,
318+
)
319+
320+
result: ToolResult = {
311321
"toolUseId": str(tool_use.get("toolUseId")),
312322
"status": "error",
313323
"content": [{"text": f"Unknown tool: {tool_name}"}],
314324
}
325+
# for every Before event call, we need to have an AfterEvent call
326+
after_event = get_registry(agent).invoke_callbacks(
327+
AfterToolInvocationEvent(
328+
agent=agent,
329+
selected_tool=selected_tool,
330+
tool_use=tool_use,
331+
kwargs=kwargs,
332+
result=result,
333+
)
334+
)
335+
return after_event.result
315336

316337
result = selected_tool.invoke(tool_use, **kwargs)
317-
after_event = AfterToolInvocationEvent(
318-
agent=agent,
319-
selected_tool=selected_tool,
320-
tool_use=tool_use,
321-
kwargs=kwargs,
322-
result=result,
338+
after_event = get_registry(agent).invoke_callbacks(
339+
AfterToolInvocationEvent(
340+
agent=agent,
341+
selected_tool=selected_tool,
342+
tool_use=tool_use,
343+
kwargs=kwargs,
344+
result=result,
345+
)
323346
)
324-
agent._hooks.invoke_callbacks(after_event)
325347
yield {
326348
"result": after_event.result
327349
} # Placeholder until tool_func becomes a generator from which we can yield from
@@ -334,15 +356,16 @@ def run_tool(agent: "Agent", kwargs: dict[str, Any], tool: ToolUse) -> ToolGener
334356
"status": "error",
335357
"content": [{"text": f"Error: {str(e)}"}],
336358
}
337-
after_event = AfterToolInvocationEvent(
338-
agent=agent,
339-
selected_tool=selected_tool,
340-
tool_use=tool_use,
341-
kwargs=kwargs,
342-
result=error_result,
343-
exception=e,
359+
after_event = get_registry(agent).invoke_callbacks(
360+
AfterToolInvocationEvent(
361+
agent=agent,
362+
selected_tool=selected_tool,
363+
tool_use=tool_use,
364+
kwargs=kwargs,
365+
result=error_result,
366+
exception=e,
367+
)
344368
)
345-
agent._hooks.invoke_callbacks(after_event)
346369
return after_event.result
347370

348371

src/strands/experimental/hooks/registry.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def __setattr__(self, name: str, value: Any) -> None:
7272
T = TypeVar("T", bound=Callable)
7373
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
7474

75+
# Non-contravariant generic for when invoking events
76+
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent)
77+
7578

7679
class HookProvider(Protocol):
7780
"""Protocol for objects that provide hook callbacks to an agent.
@@ -178,7 +181,7 @@ def register_hooks(self, registry: HookRegistry):
178181
"""
179182
hook.register_hooks(self)
180183

181-
def invoke_callbacks(self, event: TEvent) -> None:
184+
def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent:
182185
"""Invoke all registered callbacks for the given event.
183186
184187
This method finds all callbacks registered for the event's type and
@@ -191,6 +194,9 @@ def invoke_callbacks(self, event: TEvent) -> None:
191194
Raises:
192195
Any exceptions raised by callback functions will propagate to the caller.
193196
197+
Returns:
198+
The event dispatched to registered callbacks.
199+
194200
Example:
195201
```python
196202
event = StartRequestEvent(agent=my_agent)
@@ -200,6 +206,8 @@ def invoke_callbacks(self, event: TEvent) -> None:
200206
for callback in self.get_callbacks_for(event):
201207
callback(event)
202208

209+
return event
210+
203211
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
204212
"""Get callbacks registered for the given event in the appropriate order.
205213
@@ -227,3 +235,18 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No
227235
yield from reversed(callbacks)
228236
else:
229237
yield from callbacks
238+
239+
240+
def get_registry(agent: "Agent") -> HookRegistry:
241+
"""*Experimental*: Get the hooks registry for the provided agent.
242+
243+
This function is available while hooks are in experimental preview.
244+
245+
Args:
246+
agent: The agent whose hook registry should be returned.
247+
248+
Returns:
249+
The HookRegistry for the given agent.
250+
251+
"""
252+
return agent._hooks

tests/strands/event_loop/test_event_loop.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import strands
88
import strands.telemetry
99
from strands.event_loop.event_loop import run_tool
10-
from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookRegistry
10+
from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry
1111
from strands.telemetry.metrics import EventLoopMetrics
1212
from strands.tools.registry import ToolRegistry
1313
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
@@ -821,6 +821,35 @@ def test_run_tool_hooks(agent, generate, hook_provider, tool_times_2):
821821
)
822822

823823

824+
def test_run_tool_hooks_on_missing_tool(agent, tool_registry, generate, hook_provider):
825+
"""Test that AfterToolInvocation hook is invoked even when tool throws exception."""
826+
process = run_tool(
827+
agent=agent,
828+
tool={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
829+
kwargs={},
830+
)
831+
832+
_, result = generate(process)
833+
834+
assert len(hook_provider.events_received) == 2
835+
836+
assert hook_provider.events_received[0] == BeforeToolInvocationEvent(
837+
agent=agent,
838+
selected_tool=None,
839+
tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
840+
kwargs=ANY,
841+
)
842+
843+
assert hook_provider.events_received[1] == AfterToolInvocationEvent(
844+
agent=agent,
845+
selected_tool=None,
846+
tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
847+
kwargs=ANY,
848+
result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"},
849+
exception=None,
850+
)
851+
852+
824853
def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, generate, hook_provider):
825854
"""Test that AfterToolInvocation hook is invoked even when tool throws exception."""
826855
error = ValueError("Tool failed")
@@ -903,5 +932,79 @@ def modify_hook(event: AfterToolInvocationEvent):
903932

904933
_, result = generate(process)
905934

906-
# Should return modified result instead of original (5 * 2 = 10)
907935
assert result == updated_result
936+
937+
938+
def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, tool_times_2, generate, hook_registry):
939+
"""Test that modifying properties on AfterToolInvocation takes effect."""
940+
941+
updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]}
942+
943+
def modify_hook(event: AfterToolInvocationEvent):
944+
# Modify result to change the output
945+
event.result = updated_result
946+
947+
hook_registry.add_callback(AfterToolInvocationEvent, modify_hook)
948+
949+
process = run_tool(
950+
agent=agent,
951+
tool={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
952+
kwargs={},
953+
)
954+
955+
_, result = generate(process)
956+
957+
assert result == updated_result
958+
959+
960+
def test_run_tool_hook_update_result_with_missing_tool(agent, generate, tool_registry, hook_registry):
961+
"""Test that modifying properties on AfterToolInvocation takes effect."""
962+
963+
@strands.tool
964+
def test_quota():
965+
return "9"
966+
967+
tool_registry.register_tool(test_quota)
968+
969+
class ExampleProvider(HookProvider):
970+
def register_hooks(self, registry: "HookRegistry") -> None:
971+
registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call)
972+
registry.add_callback(AfterToolInvocationEvent, self.after_tool_call)
973+
974+
def before_tool_call(self, event: BeforeToolInvocationEvent):
975+
if event.tool_use.get("name") == "test_quota":
976+
event.selected_tool = None
977+
978+
def after_tool_call(self, event: AfterToolInvocationEvent):
979+
if event.tool_use.get("name") == "test_quota":
980+
event.result = {
981+
"status": "error",
982+
"toolUseId": "test",
983+
"content": [{"text": "This tool has been used too many times!"}],
984+
}
985+
986+
hook_registry.add_hook(ExampleProvider())
987+
988+
with patch.object(strands.event_loop.event_loop, "logger") as mock_logger:
989+
process = run_tool(
990+
agent=agent,
991+
tool={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}},
992+
kwargs={},
993+
)
994+
995+
_, result = generate(process)
996+
997+
assert result == {
998+
"status": "error",
999+
"toolUseId": "test",
1000+
"content": [{"text": "This tool has been used too many times!"}],
1001+
}
1002+
1003+
assert mock_logger.debug.call_args_list == [
1004+
call("tool=<%s> | invoking", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}),
1005+
call(
1006+
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
1007+
"test_quota",
1008+
{"toolUseId": "test", "name": "test_quota", "input": {"x": 5}},
1009+
),
1010+
]

0 commit comments

Comments
 (0)