|
7 | 7 | import strands |
8 | 8 | import strands.telemetry |
9 | 9 | from strands.event_loop.event_loop import run_tool |
10 | | -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookRegistry |
| 10 | +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry |
11 | 11 | from strands.telemetry.metrics import EventLoopMetrics |
12 | 12 | from strands.tools.registry import ToolRegistry |
13 | 13 | from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException |
@@ -824,6 +824,35 @@ def test_run_tool_hooks(agent, generate, hook_provider, tool_times_2): |
824 | 824 | ) |
825 | 825 |
|
826 | 826 |
|
| 827 | +def test_run_tool_hooks_on_missing_tool(agent, tool_registry, generate, hook_provider): |
| 828 | + """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" |
| 829 | + process = run_tool( |
| 830 | + agent=agent, |
| 831 | + tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, |
| 832 | + kwargs={}, |
| 833 | + ) |
| 834 | + |
| 835 | + _, result = generate(process) |
| 836 | + |
| 837 | + assert len(hook_provider.events_received) == 2 |
| 838 | + |
| 839 | + assert hook_provider.events_received[0] == BeforeToolInvocationEvent( |
| 840 | + agent=agent, |
| 841 | + selected_tool=None, |
| 842 | + tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, |
| 843 | + kwargs=ANY, |
| 844 | + ) |
| 845 | + |
| 846 | + assert hook_provider.events_received[1] == AfterToolInvocationEvent( |
| 847 | + agent=agent, |
| 848 | + selected_tool=None, |
| 849 | + tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, |
| 850 | + kwargs=ANY, |
| 851 | + result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"}, |
| 852 | + exception=None, |
| 853 | + ) |
| 854 | + |
| 855 | + |
827 | 856 | def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, generate, hook_provider): |
828 | 857 | """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" |
829 | 858 | error = ValueError("Tool failed") |
@@ -906,5 +935,79 @@ def modify_hook(event: AfterToolInvocationEvent): |
906 | 935 |
|
907 | 936 | _, result = generate(process) |
908 | 937 |
|
909 | | - # Should return modified result instead of original (5 * 2 = 10) |
910 | 938 | assert result == updated_result |
| 939 | + |
| 940 | + |
| 941 | +def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, tool_times_2, generate, hook_registry): |
| 942 | + """Test that modifying properties on AfterToolInvocation takes effect.""" |
| 943 | + |
| 944 | + updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} |
| 945 | + |
| 946 | + def modify_hook(event: AfterToolInvocationEvent): |
| 947 | + # Modify result to change the output |
| 948 | + event.result = updated_result |
| 949 | + |
| 950 | + hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) |
| 951 | + |
| 952 | + process = run_tool( |
| 953 | + agent=agent, |
| 954 | + tool_use={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, |
| 955 | + kwargs={}, |
| 956 | + ) |
| 957 | + |
| 958 | + _, result = generate(process) |
| 959 | + |
| 960 | + assert result == updated_result |
| 961 | + |
| 962 | + |
| 963 | +def test_run_tool_hook_update_result_with_missing_tool(agent, generate, tool_registry, hook_registry): |
| 964 | + """Test that modifying properties on AfterToolInvocation takes effect.""" |
| 965 | + |
| 966 | + @strands.tool |
| 967 | + def test_quota(): |
| 968 | + return "9" |
| 969 | + |
| 970 | + tool_registry.register_tool(test_quota) |
| 971 | + |
| 972 | + class ExampleProvider(HookProvider): |
| 973 | + def register_hooks(self, registry: "HookRegistry") -> None: |
| 974 | + registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call) |
| 975 | + registry.add_callback(AfterToolInvocationEvent, self.after_tool_call) |
| 976 | + |
| 977 | + def before_tool_call(self, event: BeforeToolInvocationEvent): |
| 978 | + if event.tool_use.get("name") == "test_quota": |
| 979 | + event.selected_tool = None |
| 980 | + |
| 981 | + def after_tool_call(self, event: AfterToolInvocationEvent): |
| 982 | + if event.tool_use.get("name") == "test_quota": |
| 983 | + event.result = { |
| 984 | + "status": "error", |
| 985 | + "toolUseId": "test", |
| 986 | + "content": [{"text": "This tool has been used too many times!"}], |
| 987 | + } |
| 988 | + |
| 989 | + hook_registry.add_hook(ExampleProvider()) |
| 990 | + |
| 991 | + with patch.object(strands.event_loop.event_loop, "logger") as mock_logger: |
| 992 | + process = run_tool( |
| 993 | + agent=agent, |
| 994 | + tool_use={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, |
| 995 | + kwargs={}, |
| 996 | + ) |
| 997 | + |
| 998 | + _, result = generate(process) |
| 999 | + |
| 1000 | + assert result == { |
| 1001 | + "status": "error", |
| 1002 | + "toolUseId": "test", |
| 1003 | + "content": [{"text": "This tool has been used too many times!"}], |
| 1004 | + } |
| 1005 | + |
| 1006 | + assert mock_logger.debug.call_args_list == [ |
| 1007 | + call("tool_use=<%s> | streaming", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), |
| 1008 | + call( |
| 1009 | + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", |
| 1010 | + "test_quota", |
| 1011 | + {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, |
| 1012 | + ), |
| 1013 | + ] |
0 commit comments