1- import unittest .mock
2- from unittest .mock import call
1+ from unittest .mock import ANY , Mock , call , patch
32
43import pytest
54from pydantic import BaseModel
65
76import strands
87from strands import Agent
9- from strands .experimental .hooks import AgentInitializedEvent , EndRequestEvent , StartRequestEvent
8+ from strands .experimental .hooks import (
9+ AfterToolInvocation ,
10+ AgentInitializedEvent ,
11+ BeforeToolInvocation ,
12+ EndRequestEvent ,
13+ StartRequestEvent ,
14+ )
1015from strands .types .content import Messages
1116from tests .fixtures .mock_hook_provider import MockHookProvider
1217from tests .fixtures .mocked_model_provider import MockedModelProvider
1318
1419
1520@pytest .fixture
1621def hook_provider ():
17- return MockHookProvider ([AgentInitializedEvent , StartRequestEvent , EndRequestEvent ])
22+ return MockHookProvider (
23+ [AgentInitializedEvent , StartRequestEvent , EndRequestEvent , AfterToolInvocation , BeforeToolInvocation ]
24+ )
1825
1926
2027@pytest .fixture
@@ -71,7 +78,7 @@ class User(BaseModel):
7178 return User (name = "Jane Doe" , age = 30 )
7279
7380
74- @unittest . mock . patch ("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks" )
81+ @patch ("strands.experimental.hooks.registry.HookRegistry.invoke_callbacks" )
7582def test_agent__init__hooks (mock_invoke_callbacks ):
7683 """Verify that the AgentInitializedEvent is emitted on Agent construction."""
7784 agent = Agent ()
@@ -87,9 +94,19 @@ def test_agent__call__hooks(agent, hook_provider, agent_tool, tool_use):
8794 agent ("test message" )
8895
8996 events = hook_provider .get_events ()
90- assert len (events ) == 2
9197
98+ assert len (events ) == 4
9299 assert events .popleft () == StartRequestEvent (agent = agent )
100+ assert events .popleft () == BeforeToolInvocation (
101+ agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
102+ )
103+ assert events .popleft () == AfterToolInvocation (
104+ agent = agent ,
105+ selected_tool = agent_tool ,
106+ tool_use = tool_use ,
107+ kwargs = ANY ,
108+ result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
109+ )
93110 assert events .popleft () == EndRequestEvent (agent = agent )
94111
95112
@@ -105,16 +122,26 @@ async def test_agent_stream_async_hooks(agent, hook_provider, agent_tool, tool_u
105122 pass
106123
107124 events = hook_provider .get_events ()
108- assert len (events ) == 2
109125
126+ assert len (events ) == 4
110127 assert events .popleft () == StartRequestEvent (agent = agent )
128+ assert events .popleft () == BeforeToolInvocation (
129+ agent = agent , selected_tool = agent_tool , tool_use = tool_use , kwargs = ANY
130+ )
131+ assert events .popleft () == AfterToolInvocation (
132+ agent = agent ,
133+ selected_tool = agent_tool ,
134+ tool_use = tool_use ,
135+ kwargs = ANY ,
136+ result = {"content" : [{"text" : "!loot a dekovni I" }], "status" : "success" , "toolUseId" : "123" },
137+ )
111138 assert events .popleft () == EndRequestEvent (agent = agent )
112139
113140
114141def test_agent_structured_output_hooks (agent , hook_provider , user , agenerator ):
115142 """Verify that the correct hook events are emitted as part of structured_output."""
116143
117- agent .model .structured_output = unittest . mock . Mock (return_value = agenerator ([{"output" : user }]))
144+ agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
118145 agent .structured_output (type (user ), "example prompt" )
119146
120147 assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
@@ -124,7 +151,7 @@ def test_agent_structured_output_hooks(agent, hook_provider, user, agenerator):
124151async def test_agent_structured_async_output_hooks (agent , hook_provider , user , agenerator ):
125152 """Verify that the correct hook events are emitted as part of structured_output_async."""
126153
127- agent .model .structured_output = unittest . mock . Mock (return_value = agenerator ([{"output" : user }]))
154+ agent .model .structured_output = Mock (return_value = agenerator ([{"output" : user }]))
128155 await agent .structured_output_async (type (user ), "example prompt" )
129156
130157 assert hook_provider .events_received == [StartRequestEvent (agent = agent ), EndRequestEvent (agent = agent )]
0 commit comments