Skip to content

Commit 062bc8b

Browse files
committed
feat(multiagent): Swarm - update unit tests
1 parent 0309a9d commit 062bc8b

File tree

2 files changed

+89
-75
lines changed

2 files changed

+89
-75
lines changed

src/strands/multiagent/swarm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ async def execute_async(self, task: str | list[ContentBlock]) -> SwarmResult:
261261
try:
262262
logger.info("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
263263
logger.info(
264-
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | SwarmConfig",
264+
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
265265
self.max_handoffs,
266266
self.max_iterations,
267267
self.execution_timeout,

tests/strands/multiagent/test_swarm.py

Lines changed: 88 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import time
23
from unittest.mock import MagicMock, Mock
34

@@ -62,41 +63,6 @@ async def mock_stream_async(*args, **kwargs):
6263
return agent
6364

6465

65-
def create_handoff_agent(name, target_agent_name, response_text="Handing off"):
66-
"""Create a mock agent that performs handoffs."""
67-
agent = create_mock_agent(name, response_text, complete_after_calls=999) # Never complete naturally
68-
69-
def create_handoff_result():
70-
agent._call_count += 1
71-
# Perform handoff after first call
72-
if agent._call_count == 1 and agent._swarm_ref:
73-
target_node = agent._swarm_ref.nodes.get(target_agent_name)
74-
if target_node:
75-
agent._swarm_ref._handle_handoff(
76-
target_node, f"Handing off to {target_agent_name}", {"handoff_context": "test_data"}
77-
)
78-
79-
return AgentResult(
80-
message={"role": "assistant", "content": [{"text": response_text}]},
81-
stop_reason="end_turn",
82-
state={},
83-
metrics=Mock(
84-
accumulated_usage={"inputTokens": 5, "outputTokens": 10, "totalTokens": 15},
85-
accumulated_metrics={"latencyMs": 50.0},
86-
),
87-
)
88-
89-
agent.return_value = create_handoff_result()
90-
agent.__call__ = Mock(side_effect=create_handoff_result)
91-
92-
async def mock_stream_async(*args, **kwargs):
93-
result = create_handoff_result()
94-
yield {"result": result}
95-
96-
agent.stream_async = MagicMock(side_effect=mock_stream_async)
97-
return agent
98-
99-
10066
@pytest.fixture
10167
def mock_agents():
10268
"""Create a set of mock agents for testing."""
@@ -126,8 +92,7 @@ def mock_swarm(mock_agents):
12692
return swarm
12793

12894

129-
@pytest.mark.asyncio
130-
async def test_swarm_structure_and_nodes(mock_swarm, mock_agents):
95+
def test_swarm_structure_and_nodes(mock_swarm, mock_agents):
13196
"""Test swarm structure and SwarmNode properties."""
13297
# Test swarm structure
13398
assert len(mock_swarm.nodes) == 3
@@ -181,32 +146,6 @@ def test_shared_context(mock_swarm):
181146
shared_context.add_context(coordinator_node, "key", lambda x: x)
182147

183148

184-
@pytest.mark.asyncio
185-
async def test_swarm_execution_async(mock_swarm, mock_agents):
186-
"""Test asynchronous swarm execution."""
187-
# Execute swarm with multi-modal content
188-
multi_modal_task = [ContentBlock(text="Analyze this task"), ContentBlock(text="Additional context")]
189-
result = await mock_swarm.execute_async(multi_modal_task)
190-
191-
# Verify execution results
192-
assert result.status == Status.COMPLETED
193-
assert result.execution_count >= 1
194-
assert len(result.results) >= 1
195-
assert result.execution_time >= 0
196-
197-
# Verify agent was called
198-
mock_agents["coordinator"].stream_async.assert_called()
199-
200-
# Verify metrics aggregation
201-
assert result.accumulated_usage["totalTokens"] >= 0
202-
assert result.accumulated_metrics["latencyMs"] >= 0
203-
204-
# Verify result type
205-
assert isinstance(result, SwarmResult)
206-
assert hasattr(result, "node_history")
207-
assert len(result.node_history) >= 1
208-
209-
210149
def test_swarm_state_should_continue(mock_swarm):
211150
"""Test SwarmState should_continue method with various scenarios."""
212151
coordinator_node = mock_swarm.nodes["coordinator"]
@@ -273,6 +212,31 @@ def test_swarm_state_should_continue(mock_swarm):
273212
assert "Repetitive handoff" in reason
274213

275214

215+
@pytest.mark.asyncio
216+
async def test_swarm_execution_async(mock_swarm, mock_agents):
217+
"""Test asynchronous swarm execution."""
218+
# Execute swarm
219+
task = [ContentBlock(text="Analyze this task"), ContentBlock(text="Additional context")]
220+
result = await mock_swarm.execute_async(task)
221+
222+
# Verify execution results
223+
assert result.status == Status.COMPLETED
224+
assert result.execution_count == 1
225+
assert len(result.results) == 1
226+
227+
# Verify agent was called
228+
mock_agents["coordinator"].stream_async.assert_called()
229+
230+
# Verify metrics aggregation
231+
assert result.accumulated_usage["totalTokens"] >= 0
232+
assert result.accumulated_metrics["latencyMs"] >= 0
233+
234+
# Verify result type
235+
assert isinstance(result, SwarmResult)
236+
assert hasattr(result, "node_history")
237+
assert len(result.node_history) == 1
238+
239+
276240
def test_swarm_synchronous_execution(mock_agents):
277241
"""Test synchronous swarm execution using __call__ method."""
278242
agents = list(mock_agents.values())
@@ -293,8 +257,8 @@ def test_swarm_synchronous_execution(mock_agents):
293257

294258
# Verify execution results
295259
assert result.status == Status.COMPLETED
296-
assert result.execution_count >= 1
297-
assert len(result.results) >= 1
260+
assert result.execution_count == 1
261+
assert len(result.results) == 1
298262
assert result.execution_time >= 0
299263

300264
# Verify agent was called
@@ -348,22 +312,72 @@ def test_swarm_builder_validation(mock_agents):
348312

349313
def test_swarm_handoff_functionality():
350314
"""Test swarm handoff functionality."""
351-
# Test handoff functionality - successful handoff during execution
315+
316+
# Create an agent that will hand off to another agent
317+
def create_handoff_agent(name, target_agent_name, response_text="Handing off"):
318+
"""Create a mock agent that performs handoffs."""
319+
agent = create_mock_agent(name, response_text, complete_after_calls=math.inf) # Never complete naturally
320+
agent._handoff_done = False # Track if handoff has been performed
321+
322+
def create_handoff_result():
323+
agent._call_count += 1
324+
# Perform handoff on first execution call (not setup calls)
325+
if not agent._handoff_done and agent._swarm_ref and hasattr(agent._swarm_ref.state, "completion_status"):
326+
target_node = agent._swarm_ref.nodes.get(target_agent_name)
327+
if target_node:
328+
agent._swarm_ref._handle_handoff(
329+
target_node, f"Handing off to {target_agent_name}", {"handoff_context": "test_data"}
330+
)
331+
agent._handoff_done = True
332+
333+
return AgentResult(
334+
message={"role": "assistant", "content": [{"text": response_text}]},
335+
stop_reason="end_turn",
336+
state={},
337+
metrics=Mock(
338+
accumulated_usage={"inputTokens": 5, "outputTokens": 10, "totalTokens": 15},
339+
accumulated_metrics={"latencyMs": 50.0},
340+
),
341+
)
342+
343+
agent.return_value = create_handoff_result()
344+
agent.__call__ = Mock(side_effect=create_handoff_result)
345+
346+
async def mock_stream_async(*args, **kwargs):
347+
result = create_handoff_result()
348+
yield {"result": result}
349+
350+
agent.stream_async = MagicMock(side_effect=mock_stream_async)
351+
return agent
352+
353+
# Create agents - first one hands off, second one completes
352354
handoff_agent = create_handoff_agent("handoff_agent", "completion_agent")
353-
completion_agent = create_mock_agent("completion_agent", complete_after_calls=1)
355+
completion_agent = create_mock_agent("completion_agent", "Task completed", complete_after_calls=1)
354356

355-
# Create a swarm with lower limits to avoid hitting max handoffs
356-
handoff_swarm = Swarm(nodes=[handoff_agent, completion_agent], max_handoffs=5, max_iterations=5)
357+
# Create a swarm with reasonable limits
358+
handoff_swarm = Swarm(nodes=[handoff_agent, completion_agent], max_handoffs=10, max_iterations=10)
357359
handoff_agent._swarm_ref = handoff_swarm
358360
completion_agent._swarm_ref = handoff_swarm
359361

360-
# Execute swarm - this will trigger handoff during execution (not when completed)
362+
# Execute swarm - this should hand off from first agent to second agent
361363
result = handoff_swarm("Test handoff during execution")
362-
# The handoff might still fail due to the complexity of the mock setup, but we've covered the handoff path
363-
# The important thing is that we've tested the handoff logic itself
364-
assert result.status in [Status.COMPLETED, Status.FAILED] # Either outcome is acceptable for coverage
365364

366-
# Test handoff when task is already completed (different path)
365+
# Verify the handoff occurred
366+
assert result.status == Status.COMPLETED
367+
assert result.execution_count == 2 # Both agents should have executed
368+
assert len(result.node_history) == 2
369+
370+
# Verify the handoff agent executed first
371+
assert result.node_history[0].node_id == "handoff_agent"
372+
373+
# Verify the completion agent executed after handoff
374+
assert result.node_history[1].node_id == "completion_agent"
375+
376+
# Verify both agents were called
377+
handoff_agent.stream_async.assert_called()
378+
completion_agent.stream_async.assert_called()
379+
380+
# Test handoff when task is already completed
367381
completed_swarm = Swarm(nodes=[handoff_agent, completion_agent])
368382
completed_swarm.state.completion_status = Status.COMPLETED
369383
completed_swarm._handle_handoff(completed_swarm.nodes["completion_agent"], "test message", {"key": "value"})

0 commit comments

Comments
 (0)