Skip to content

Commit 62f4050

Browse files
committed
feat(multiagent): Swarm - PR feedback
1 parent 062bc8b commit 62f4050

File tree

8 files changed

+118
-60
lines changed

8 files changed

+118
-60
lines changed

src/strands/multiagent/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ class MultiAgentBase(ABC):
8282
"""
8383

8484
@abstractmethod
85-
async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
86-
"""Execute task asynchronously."""
87-
raise NotImplementedError("execute_async not implemented")
85+
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
86+
"""Invoke asynchronously."""
87+
raise NotImplementedError("invoke_async not implemented")
8888

8989
@abstractmethod
9090
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
91-
"""Execute task synchronously."""
91+
"""Invoke synchronously."""
9292
raise NotImplementedError("__call__ not implemented")

src/strands/multiagent/graph.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def __init__(self) -> None:
140140

141141
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
142142
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
143+
# Check for duplicate node instances
144+
seen_instances = {id(node.executor) for node in self.nodes.values()}
145+
if id(executor) in seen_instances:
146+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
147+
143148
# Auto-generate node_id if not provided
144149
if node_id is None:
145150
node_id = getattr(executor, "id", None) or getattr(executor, "name", None) or f"node_{len(self.nodes)}"
@@ -242,24 +247,27 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
242247
"""Initialize Graph."""
243248
super().__init__()
244249

250+
# Validate nodes for duplicate instances
251+
self._validate_graph(nodes)
252+
245253
self.nodes = nodes
246254
self.edges = edges
247255
self.entry_points = entry_points
248256
self.state = GraphState()
249257
self.tracer = get_tracer()
250258

251259
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
252-
"""Execute task synchronously."""
260+
"""Invoke the graph synchronously."""
253261

254262
def execute() -> GraphResult:
255-
return asyncio.run(self.execute_async(task))
263+
return asyncio.run(self.invoke_async(task))
256264

257265
with ThreadPoolExecutor() as executor:
258266
future = executor.submit(execute)
259267
return future.result()
260268

261-
async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
262-
"""Execute the graph asynchronously."""
269+
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
270+
"""Invoke the graph asynchronously."""
263271
logger.debug("task=<%s> | starting graph execution", task)
264272

265273
# Initialize state
@@ -287,6 +295,15 @@ async def execute_async(self, task: str | list[ContentBlock], **kwargs: Any) ->
287295
self.state.execution_time = round((time.time() - start_time) * 1000)
288296
return self._build_result()
289297

298+
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
299+
"""Validate graph nodes for duplicate instances."""
300+
# Check for duplicate node instances
301+
seen_instances = set()
302+
for node in nodes.values():
303+
if id(node.executor) in seen_instances:
304+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
305+
seen_instances.add(id(node.executor))
306+
290307
async def _execute_graph(self) -> None:
291308
"""Unified execution flow with conditional routing."""
292309
ready_nodes = list(self.entry_points)
@@ -349,7 +366,7 @@ async def _execute_node(self, node: GraphNode) -> None:
349366

350367
# Execute based on node type and create unified NodeResult
351368
if isinstance(node.executor, MultiAgentBase):
352-
multi_agent_result = await node.executor.execute_async(node_input)
369+
multi_agent_result = await node.executor.invoke_async(node_input)
353370

354371
# Create NodeResult with MultiAgentResult directly
355372
node_result = NodeResult(

src/strands/multiagent/swarm.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
import asyncio
16+
import copy
1617
import json
1718
import logging
1819
import time
@@ -21,8 +22,9 @@
2122
from typing import Any, Callable, Tuple, cast
2223

2324
from ..agent import Agent, AgentResult
25+
from ..agent.state import AgentState
2426
from ..tools.decorator import tool
25-
from ..types.content import ContentBlock
27+
from ..types.content import ContentBlock, Messages
2628
from ..types.event_loop import Metrics, Usage
2729
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
2830

@@ -35,6 +37,14 @@ class SwarmNode:
3537

3638
node_id: str
3739
executor: Agent
40+
_initial_messages: Messages = field(default_factory=list, init=False)
41+
_initial_state: AgentState = field(default_factory=AgentState, init=False)
42+
43+
def __post_init__(self) -> None:
44+
"""Capture initial executor state after initialization."""
45+
# Deep copy the initial messages and state to preserve them
46+
self._initial_messages = copy.deepcopy(self.executor.messages)
47+
self._initial_state = AgentState(self.executor.state.get())
3848

3949
def __hash__(self) -> int:
4050
"""Return hash for SwarmNode based on node_id."""
@@ -54,6 +64,11 @@ def __repr__(self) -> str:
5464
"""Return detailed representation of SwarmNode."""
5565
return f"SwarmNode(node_id='{self.node_id}')"
5666

67+
def reset_executor_state(self) -> None:
68+
"""Reset SwarmNode executor state to initial state when swarm was created."""
69+
self.executor.messages = copy.deepcopy(self._initial_messages)
70+
self.executor.state = AgentState(self._initial_state.get())
71+
5772

5873
@dataclass
5974
class SharedContext:
@@ -218,35 +233,19 @@ def __init__(
218233
self._setup_swarm(nodes)
219234
self._inject_swarm_tools()
220235

221-
def __call__(self, task: str | list[ContentBlock]) -> SwarmResult:
222-
"""Execute task synchronously.
223-
224-
Args:
225-
task: The task to execute, either as a string or a list of ContentBlock objects
226-
for multi-modal content.
227-
228-
Returns:
229-
SwarmResult containing execution results and metrics.
230-
"""
236+
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
237+
"""Invoke the swarm synchronously."""
231238

232239
def execute() -> SwarmResult:
233-
return asyncio.run(self.execute_async(task))
240+
return asyncio.run(self.invoke_async(task))
234241

235242
with ThreadPoolExecutor() as executor:
236243
future = executor.submit(execute)
237244
return future.result()
238245

239-
async def execute_async(self, task: str | list[ContentBlock]) -> SwarmResult:
240-
"""Execute the swarm asynchronously.
241-
242-
Args:
243-
task: The task to execute, either as a string or a list of ContentBlock objects
244-
for multi-modal content.
245-
246-
Returns:
247-
SwarmResult containing execution results and metrics.
248-
"""
249-
logger.info("starting swarm execution")
246+
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
247+
"""Invoke the swarm asynchronously."""
248+
logger.debug("starting swarm execution")
250249

251250
# Initialize swarm state with configuration
252251
initial_node = next(iter(self.nodes.values())) # First SwarmNode
@@ -259,8 +258,8 @@ async def execute_async(self, task: str | list[ContentBlock]) -> SwarmResult:
259258

260259
start_time = time.time()
261260
try:
262-
logger.info("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
263-
logger.info(
261+
logger.debug("current_node=<%s> | starting swarm execution with node", self.state.current_node.node_id)
262+
logger.debug(
264263
"max_handoffs=<%d>, max_iterations=<%d>, timeout=<%s>s | swarm execution config",
265264
self.max_handoffs,
266265
self.max_iterations,
@@ -279,12 +278,15 @@ async def execute_async(self, task: str | list[ContentBlock]) -> SwarmResult:
279278

280279
def _setup_swarm(self, nodes: list[Agent]) -> None:
281280
"""Initialize swarm configuration."""
281+
# Validate nodes before setup
282+
self._validate_swarm(nodes)
283+
282284
# Validate agents have names and create SwarmNode objects
283285
for i, node in enumerate(nodes):
284286
if not node.name:
285287
node_id = f"node_{i}"
286288
node.name = node_id
287-
logger.info("node_id=<%s> | agent has no name, dynamically generating one", node_id)
289+
logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id)
288290

289291
node_id = str(node.name)
290292

@@ -295,7 +297,16 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
295297
self.nodes[node_id] = SwarmNode(node_id=node_id, executor=node)
296298

297299
swarm_nodes = list(self.nodes.values())
298-
logger.info("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes])
300+
logger.debug("nodes=<%s> | initialized swarm with nodes", [node.node_id for node in swarm_nodes])
301+
302+
def _validate_swarm(self, nodes: list[Agent]) -> None:
303+
"""Validate swarm structure and nodes."""
304+
# Check for duplicate object instances
305+
seen_instances = set()
306+
for node in nodes:
307+
if id(node) in seen_instances:
308+
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
309+
seen_instances.add(id(node))
299310

300311
def _inject_swarm_tools(self) -> None:
301312
"""Add swarm coordination tools to each agent."""
@@ -324,7 +335,7 @@ def _inject_swarm_tools(self) -> None:
324335
# Use the agent's tool registry to process and register the tools
325336
node.executor.tool_registry.process_tools(swarm_tools)
326337

327-
logger.info(
338+
logger.debug(
328339
"tool_count=<%d>, node_count=<%d> | injected coordination tools into agents",
329340
len(swarm_tools),
330341
len(self.nodes),
@@ -388,7 +399,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
388399
"""Handle handoff to another agent."""
389400
# If task is already completed, don't allow further handoffs
390401
if self.state.completion_status != Status.EXECUTING:
391-
logger.info(
402+
logger.debug(
392403
"task_status=<%s> | ignoring handoff request - task already completed",
393404
self.state.completion_status,
394405
)
@@ -406,7 +417,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st
406417
for key, value in context.items():
407418
self.shared_context.add_context(previous_agent, key, value)
408419

409-
logger.info(
420+
logger.debug(
410421
"from_node=<%s>, to_node=<%s> | handed off from agent to agent",
411422
previous_agent.node_id,
412423
target_node.node_id,
@@ -416,7 +427,7 @@ def _handle_completion(self) -> None:
416427
"""Handle task completion."""
417428
self.state.completion_status = Status.COMPLETED
418429

419-
logger.info("swarm task completed")
430+
logger.debug("swarm task completed")
420431

421432
def _build_node_input(self, target_node: SwarmNode) -> str:
422433
"""Build input text for a node based on shared context and handoffs.
@@ -499,7 +510,7 @@ async def _execute_swarm(self) -> None:
499510
while True:
500511
if self.state.completion_status != Status.EXECUTING:
501512
reason = f"Completion status is: {self.state.completion_status}"
502-
logger.info("reason=<%s> | stopping execution", reason)
513+
logger.debug("reason=<%s> | stopping execution", reason)
503514
break
504515

505516
should_continue, reason = self.state.should_continue(
@@ -511,7 +522,7 @@ async def _execute_swarm(self) -> None:
511522
)
512523
if not should_continue:
513524
self.state.completion_status = Status.FAILED
514-
logger.info("reason=<%s> | stopping execution", reason)
525+
logger.debug("reason=<%s> | stopping execution", reason)
515526
break
516527

517528
# Get current node
@@ -521,13 +532,14 @@ async def _execute_swarm(self) -> None:
521532
self.state.completion_status = Status.FAILED
522533
break
523534

524-
logger.info(
535+
logger.debug(
525536
"current_node=<%s>, iteration=<%d> | executing node",
526537
current_node.node_id,
527538
len(self.state.node_history) + 1,
528539
)
529540

530541
# Execute node with timeout protection
542+
# TODO: Implement cancellation token to stop _execute_node from continuing
531543
try:
532544
await asyncio.wait_for(
533545
self._execute_node(current_node, self.state.task),
@@ -536,11 +548,11 @@ async def _execute_swarm(self) -> None:
536548

537549
self.state.node_history.append(current_node)
538550

539-
logger.info("node=<%s> | node execution completed", current_node.node_id)
551+
logger.debug("node=<%s> | node execution completed", current_node.node_id)
540552

541553
# Immediate check for completion after node execution
542554
if self.state.completion_status != Status.EXECUTING:
543-
logger.info("status=<%s> | task completed with status", self.state.completion_status) # type: ignore[unreachable]
555+
logger.debug("status=<%s> | task completed with status", self.state.completion_status) # type: ignore[unreachable]
544556
break
545557

546558
except asyncio.TimeoutError:
@@ -562,8 +574,8 @@ async def _execute_swarm(self) -> None:
562574
self.state.completion_status = Status.FAILED
563575

564576
elapsed_time = time.time() - self.state.start_time
565-
logger.info("status=<%s> | swarm execution completed", self.state.completion_status)
566-
logger.info(
577+
logger.debug("status=<%s> | swarm execution completed", self.state.completion_status)
578+
logger.debug(
567579
"node_history_length=<%d>, time=<%s>s | metrics",
568580
len(self.state.node_history),
569581
f"{elapsed_time:.2f}",
@@ -588,7 +600,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
588600

589601
# Execute node
590602
result = None
591-
node.executor.messages = [] # Reset agent's messages to avoid polluting context
603+
node.reset_executor_state()
592604
async for event in node.executor.stream_async(node_input):
593605
if "result" in event:
594606
result = cast(AgentResult, event["result"])

tests/strands/multiagent/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class IncompleteMultiAgent(MultiAgentBase):
138138

139139
# Test that complete implementations can be instantiated
140140
class CompleteMultiAgent(MultiAgentBase):
141-
async def execute_async(self, task: str) -> MultiAgentResult:
141+
async def invoke_async(self, task: str) -> MultiAgentResult:
142142
return MultiAgentResult(results={})
143143

144144
def __call__(self, task: str) -> MultiAgentResult:

0 commit comments

Comments
 (0)