Skip to content

Commit dd3478b

Browse files
committed
Support setting the current_tool to None to support method interception/limits
1 parent 8127c66 commit dd3478b

File tree

3 files changed

+179
-30
lines changed

3 files changed

+179
-30
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import TYPE_CHECKING, Any, AsyncGenerator
1515

1616
from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
17+
from ..experimental.hooks.registry import get_registry
1718
from ..telemetry.metrics import Trace
1819
from ..telemetry.tracer import get_tracer
1920
from ..tools.executor import run_tools, validate_and_prepare_tools
@@ -288,40 +289,61 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
288289
}
289290
)
290291

291-
before_event = BeforeToolInvocationEvent(
292-
agent=agent,
293-
selected_tool=tool_func,
294-
tool_use=tool_use,
295-
kwargs=kwargs,
292+
before_event = get_registry(agent).invoke_callbacks(
293+
BeforeToolInvocationEvent(
294+
agent=agent,
295+
selected_tool=tool_func,
296+
tool_use=tool_use,
297+
kwargs=kwargs,
298+
)
296299
)
297-
agent._hooks.invoke_callbacks(before_event)
298300

299301
try:
300302
selected_tool = before_event.selected_tool
301303
tool_use = before_event.tool_use
302304

303305
# Check if tool exists
304306
if not selected_tool:
305-
logger.error(
306-
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
307-
tool_name,
308-
list(agent.tool_registry.registry.keys()),
309-
)
310-
return {
307+
if tool_func == selected_tool:
308+
logger.error(
309+
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
310+
tool_name,
311+
list(agent.tool_registry.registry.keys()),
312+
)
313+
else:
314+
logger.debug(
315+
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
316+
tool_name,
317+
tool_use,
318+
)
319+
320+
result: ToolResult = {
311321
"toolUseId": str(tool_use.get("toolUseId")),
312322
"status": "error",
313323
"content": [{"text": f"Unknown tool: {tool_name}"}],
314324
}
325+
# for every Before event call, we need to have an AfterEvent call
326+
after_event = get_registry(agent).invoke_callbacks(
327+
AfterToolInvocationEvent(
328+
agent=agent,
329+
selected_tool=selected_tool,
330+
tool_use=tool_use,
331+
kwargs=kwargs,
332+
result=result,
333+
)
334+
)
335+
return after_event.result
315336

316337
result = yield from selected_tool.stream(tool_use, **kwargs)
317-
after_event = AfterToolInvocationEvent(
318-
agent=agent,
319-
selected_tool=selected_tool,
320-
tool_use=tool_use,
321-
kwargs=kwargs,
322-
result=result,
338+
after_event = get_registry(agent).invoke_callbacks(
339+
AfterToolInvocationEvent(
340+
agent=agent,
341+
selected_tool=selected_tool,
342+
tool_use=tool_use,
343+
kwargs=kwargs,
344+
result=result,
345+
)
323346
)
324-
agent._hooks.invoke_callbacks(after_event)
325347
return after_event.result
326348

327349
except Exception as e:
@@ -331,15 +353,16 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG
331353
"status": "error",
332354
"content": [{"text": f"Error: {str(e)}"}],
333355
}
334-
after_event = AfterToolInvocationEvent(
335-
agent=agent,
336-
selected_tool=selected_tool,
337-
tool_use=tool_use,
338-
kwargs=kwargs,
339-
result=error_result,
340-
exception=e,
356+
after_event = get_registry(agent).invoke_callbacks(
357+
AfterToolInvocationEvent(
358+
agent=agent,
359+
selected_tool=selected_tool,
360+
tool_use=tool_use,
361+
kwargs=kwargs,
362+
result=error_result,
363+
exception=e,
364+
)
341365
)
342-
agent._hooks.invoke_callbacks(after_event)
343366
return after_event.result
344367

345368

src/strands/experimental/hooks/registry.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def __setattr__(self, name: str, value: Any) -> None:
7272
T = TypeVar("T", bound=Callable)
7373
TEvent = TypeVar("TEvent", bound=HookEvent, contravariant=True)
7474

75+
# Non-contravariant generic for when invoking events
76+
TInvokeEvent = TypeVar("TInvokeEvent", bound=HookEvent)
77+
7578

7679
class HookProvider(Protocol):
7780
"""Protocol for objects that provide hook callbacks to an agent.
@@ -178,7 +181,7 @@ def register_hooks(self, registry: HookRegistry):
178181
"""
179182
hook.register_hooks(self)
180183

181-
def invoke_callbacks(self, event: TEvent) -> None:
184+
def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent:
182185
"""Invoke all registered callbacks for the given event.
183186
184187
This method finds all callbacks registered for the event's type and
@@ -191,6 +194,9 @@ def invoke_callbacks(self, event: TEvent) -> None:
191194
Raises:
192195
Any exceptions raised by callback functions will propagate to the caller.
193196
197+
Returns:
198+
The event dispatched to registered callbacks.
199+
194200
Example:
195201
```python
196202
event = StartRequestEvent(agent=my_agent)
@@ -200,6 +206,8 @@ def invoke_callbacks(self, event: TEvent) -> None:
200206
for callback in self.get_callbacks_for(event):
201207
callback(event)
202208

209+
return event
210+
203211
def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], None, None]:
204212
"""Get callbacks registered for the given event in the appropriate order.
205213
@@ -227,3 +235,18 @@ def get_callbacks_for(self, event: TEvent) -> Generator[HookCallback[TEvent], No
227235
yield from reversed(callbacks)
228236
else:
229237
yield from callbacks
238+
239+
240+
def get_registry(agent: "Agent") -> HookRegistry:
241+
"""*Experimental*: Get the hooks registry for the provided agent.
242+
243+
This function is available while hooks are in experimental preview.
244+
245+
Args:
246+
agent: The agent whose hook registry should be returned.
247+
248+
Returns:
249+
The HookRegistry for the given agent.
250+
251+
"""
252+
return agent._hooks

tests/strands/event_loop/test_event_loop.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import strands
88
import strands.telemetry
99
from strands.event_loop.event_loop import run_tool
10-
from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookRegistry
10+
from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry
1111
from strands.telemetry.metrics import EventLoopMetrics
1212
from strands.tools.registry import ToolRegistry
1313
from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
@@ -824,6 +824,35 @@ def test_run_tool_hooks(agent, generate, hook_provider, tool_times_2):
824824
)
825825

826826

827+
def test_run_tool_hooks_on_missing_tool(agent, tool_registry, generate, hook_provider):
828+
"""Test that AfterToolInvocation hook is invoked even when tool throws exception."""
829+
process = run_tool(
830+
agent=agent,
831+
tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
832+
kwargs={},
833+
)
834+
835+
_, result = generate(process)
836+
837+
assert len(hook_provider.events_received) == 2
838+
839+
assert hook_provider.events_received[0] == BeforeToolInvocationEvent(
840+
agent=agent,
841+
selected_tool=None,
842+
tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
843+
kwargs=ANY,
844+
)
845+
846+
assert hook_provider.events_received[1] == AfterToolInvocationEvent(
847+
agent=agent,
848+
selected_tool=None,
849+
tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"},
850+
kwargs=ANY,
851+
result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"},
852+
exception=None,
853+
)
854+
855+
827856
def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, generate, hook_provider):
828857
"""Test that AfterToolInvocation hook is invoked even when tool throws exception."""
829858
error = ValueError("Tool failed")
@@ -906,5 +935,79 @@ def modify_hook(event: AfterToolInvocationEvent):
906935

907936
_, result = generate(process)
908937

909-
# Should return modified result instead of original (5 * 2 = 10)
910938
assert result == updated_result
939+
940+
941+
def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, tool_times_2, generate, hook_registry):
942+
"""Test that modifying properties on AfterToolInvocation takes effect."""
943+
944+
updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]}
945+
946+
def modify_hook(event: AfterToolInvocationEvent):
947+
# Modify result to change the output
948+
event.result = updated_result
949+
950+
hook_registry.add_callback(AfterToolInvocationEvent, modify_hook)
951+
952+
process = run_tool(
953+
agent=agent,
954+
tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}},
955+
kwargs={},
956+
)
957+
958+
_, result = generate(process)
959+
960+
assert result == updated_result
961+
962+
963+
def test_run_tool_hook_update_result_with_missing_tool(agent, generate, tool_registry, hook_registry):
964+
"""Test that modifying properties on AfterToolInvocation takes effect."""
965+
966+
@strands.tool
967+
def test_quota():
968+
return "9"
969+
970+
tool_registry.register_tool(test_quota)
971+
972+
class ExampleProvider(HookProvider):
973+
def register_hooks(self, registry: "HookRegistry") -> None:
974+
registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call)
975+
registry.add_callback(AfterToolInvocationEvent, self.after_tool_call)
976+
977+
def before_tool_call(self, event: BeforeToolInvocationEvent):
978+
if event.tool_use.get("name") == "test_quota":
979+
event.selected_tool = None
980+
981+
def after_tool_call(self, event: AfterToolInvocationEvent):
982+
if event.tool_use.get("name") == "test_quota":
983+
event.result = {
984+
"status": "error",
985+
"toolUseId": "test",
986+
"content": [{"text": "This tool has been used too many times!"}],
987+
}
988+
989+
hook_registry.add_hook(ExampleProvider())
990+
991+
with patch.object(strands.event_loop.event_loop, "logger") as mock_logger:
992+
process = run_tool(
993+
agent=agent,
994+
tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}},
995+
kwargs={},
996+
)
997+
998+
_, result = generate(process)
999+
1000+
assert result == {
1001+
"status": "error",
1002+
"toolUseId": "test",
1003+
"content": [{"text": "This tool has been used too many times!"}],
1004+
}
1005+
1006+
assert mock_logger.debug.call_args_list == [
1007+
call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}),
1008+
call(
1009+
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
1010+
"test_quota",
1011+
{"toolUseId": "test", "name": "test_quota", "input": {"x": 5}},
1012+
),
1013+
]

0 commit comments

Comments
 (0)