|
7 | 7 | import strands |
8 | 8 | import strands.telemetry |
9 | 9 | from strands.event_loop.event_loop import run_tool |
10 | | -from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookRegistry |
| 10 | +from strands.experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent, HookProvider, HookRegistry |
11 | 11 | from strands.telemetry.metrics import EventLoopMetrics |
12 | 12 | from strands.tools.registry import ToolRegistry |
13 | 13 | from strands.types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException |
@@ -821,6 +821,35 @@ def test_run_tool_hooks(agent, generate, hook_provider, tool_times_2): |
821 | 821 | ) |
822 | 822 |
|
823 | 823 |
|
| 824 | +def test_run_tool_hooks_on_missing_tool(agent, tool_registry, generate, hook_provider): |
| 825 | + """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" |
| 826 | + process = run_tool( |
| 827 | + agent=agent, |
| 828 | + tool={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, |
| 829 | + kwargs={}, |
| 830 | + ) |
| 831 | + |
| 832 | + _, result = generate(process) |
| 833 | + |
| 834 | + assert len(hook_provider.events_received) == 2 |
| 835 | + |
| 836 | + assert hook_provider.events_received[0] == BeforeToolInvocationEvent( |
| 837 | + agent=agent, |
| 838 | + selected_tool=None, |
| 839 | + tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, |
| 840 | + kwargs=ANY, |
| 841 | + ) |
| 842 | + |
| 843 | + assert hook_provider.events_received[1] == AfterToolInvocationEvent( |
| 844 | + agent=agent, |
| 845 | + selected_tool=None, |
| 846 | + tool_use={"input": {"x": 5}, "name": "missing_tool", "toolUseId": "test"}, |
| 847 | + kwargs=ANY, |
| 848 | + result={"content": [{"text": "Unknown tool: missing_tool"}], "status": "error", "toolUseId": "test"}, |
| 849 | + exception=None, |
| 850 | + ) |
| 851 | + |
| 852 | + |
824 | 853 | def test_run_tool_hook_after_tool_invocation_on_exception(agent, tool_registry, generate, hook_provider): |
825 | 854 | """Test that AfterToolInvocation hook is invoked even when tool throws exception.""" |
826 | 855 | error = ValueError("Tool failed") |
@@ -903,5 +932,79 @@ def modify_hook(event: AfterToolInvocationEvent): |
903 | 932 |
|
904 | 933 | _, result = generate(process) |
905 | 934 |
|
906 | | - # Should return modified result instead of original (5 * 2 = 10) |
907 | 935 | assert result == updated_result |
| 936 | + |
| 937 | + |
| 938 | +def test_run_tool_hook_after_tool_invocation_updates_with_missing_tool(agent, tool_times_2, generate, hook_registry): |
| 939 | + """Test that modifying properties on AfterToolInvocation takes effect.""" |
| 940 | + |
| 941 | + updated_result = {"toolUseId": "modified", "status": "success", "content": [{"text": "modified_result"}]} |
| 942 | + |
| 943 | + def modify_hook(event: AfterToolInvocationEvent): |
| 944 | + # Modify result to change the output |
| 945 | + event.result = updated_result |
| 946 | + |
| 947 | + hook_registry.add_callback(AfterToolInvocationEvent, modify_hook) |
| 948 | + |
| 949 | + process = run_tool( |
| 950 | + agent=agent, |
| 951 | + tool={"toolUseId": "test", "name": "missing_tool", "input": {"x": 5}}, |
| 952 | + kwargs={}, |
| 953 | + ) |
| 954 | + |
| 955 | + _, result = generate(process) |
| 956 | + |
| 957 | + assert result == updated_result |
| 958 | + |
| 959 | + |
| 960 | +def test_run_tool_hook_update_result_with_missing_tool(agent, generate, tool_registry, hook_registry): |
| 961 | + """Test that modifying properties on AfterToolInvocation takes effect.""" |
| 962 | + |
| 963 | + @strands.tool |
| 964 | + def test_quota(): |
| 965 | + return "9" |
| 966 | + |
| 967 | + tool_registry.register_tool(test_quota) |
| 968 | + |
| 969 | + class ExampleProvider(HookProvider): |
| 970 | + def register_hooks(self, registry: "HookRegistry") -> None: |
| 971 | + registry.add_callback(BeforeToolInvocationEvent, self.before_tool_call) |
| 972 | + registry.add_callback(AfterToolInvocationEvent, self.after_tool_call) |
| 973 | + |
| 974 | + def before_tool_call(self, event: BeforeToolInvocationEvent): |
| 975 | + if event.tool_use.get("name") == "test_quota": |
| 976 | + event.selected_tool = None |
| 977 | + |
| 978 | + def after_tool_call(self, event: AfterToolInvocationEvent): |
| 979 | + if event.tool_use.get("name") == "test_quota": |
| 980 | + event.result = { |
| 981 | + "status": "error", |
| 982 | + "toolUseId": "test", |
| 983 | + "content": [{"text": "This tool has been used too many times!"}], |
| 984 | + } |
| 985 | + |
| 986 | + hook_registry.add_hook(ExampleProvider()) |
| 987 | + |
| 988 | + with patch.object(strands.event_loop.event_loop, "logger") as mock_logger: |
| 989 | + process = run_tool( |
| 990 | + agent=agent, |
| 991 | + tool={"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, |
| 992 | + kwargs={}, |
| 993 | + ) |
| 994 | + |
| 995 | + _, result = generate(process) |
| 996 | + |
| 997 | + assert result == { |
| 998 | + "status": "error", |
| 999 | + "toolUseId": "test", |
| 1000 | + "content": [{"text": "This tool has been used too many times!"}], |
| 1001 | + } |
| 1002 | + |
| 1003 | + assert mock_logger.debug.call_args_list == [ |
| 1004 | + call("tool=<%s> | invoking", {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}), |
| 1005 | + call( |
| 1006 | + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", |
| 1007 | + "test_quota", |
| 1008 | + {"toolUseId": "test", "name": "test_quota", "input": {"x": 5}}, |
| 1009 | + ), |
| 1010 | + ] |
0 commit comments