Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Union

from ..agent import AgentResult
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage


Expand Down Expand Up @@ -75,13 +76,11 @@ class MultiAgentBase(ABC):
"""

@abstractmethod
# TODO: for task - multi-modal input (Message), list of messages
async def execute_async(self, task: str) -> MultiAgentResult:
async def execute_async(self, task: str | list[ContentBlock]) -> MultiAgentResult:
"""Execute task asynchronously."""
raise NotImplementedError("execute_async not implemented")

@abstractmethod
# TODO: for task - multi-modal input (Message), list of messages
def execute(self, task: str) -> MultiAgentResult:
def execute(self, task: str | list[ContentBlock]) -> MultiAgentResult:
"""Execute task synchronously."""
raise NotImplementedError("execute not implemented")
38 changes: 28 additions & 10 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, Callable, Tuple, cast

from ..agent import Agent, AgentResult
from ..types.content import ContentBlock
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

Expand All @@ -42,12 +43,14 @@ class GraphState:
Entry point nodes receive this task as their input if they have no dependencies.
"""

# Task (with default empty string)
task: str | list[ContentBlock] = ""

# Execution state
status: Status = Status.PENDING
completed_nodes: set["GraphNode"] = field(default_factory=set)
failed_nodes: set["GraphNode"] = field(default_factory=set)
execution_order: list["GraphNode"] = field(default_factory=list)
task: str = ""

# Results
results: dict[str, NodeResult] = field(default_factory=dict)
Expand Down Expand Up @@ -247,7 +250,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
self.entry_points = entry_points
self.state = GraphState()

def execute(self, task: str) -> GraphResult:
def execute(self, task: str | list[ContentBlock]) -> GraphResult:
"""Execute task synchronously."""

def execute() -> GraphResult:
Expand All @@ -257,7 +260,7 @@ def execute() -> GraphResult:
future = executor.submit(execute)
return future.result()

async def execute_async(self, task: str) -> GraphResult:
async def execute_async(self, task: str | list[ContentBlock]) -> GraphResult:
"""Execute the graph asynchronously."""
logger.debug("task=<%s> | starting graph execution", task)

Expand Down Expand Up @@ -435,8 +438,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0)
self.state.execution_count += node_result.execution_count

def _build_node_input(self, node: GraphNode) -> str:
"""Build input text for a node based on dependency outputs."""
def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
"""Build input for a node based on dependency outputs."""
# Get satisfied dependencies
dependency_results = {}
for edge in self.edges:
Expand All @@ -449,21 +452,36 @@ def _build_node_input(self, node: GraphNode) -> str:
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]

if not dependency_results:
return self.state.task
# No dependencies - return task as ContentBlocks
if isinstance(self.state.task, str):
return [ContentBlock(text=self.state.task)]
else:
return self.state.task

# Combine task with dependency outputs
input_parts = [f"Original Task: {self.state.task}", "\nInputs from previous nodes:"]
node_input = []

# Add original task
if isinstance(self.state.task, str):
node_input.append(ContentBlock(text=f"Original Task: {self.state.task}"))
else:
# Add task content blocks with a prefix
node_input.append(ContentBlock(text="Original Task:"))
node_input.extend(self.state.task)

# Add dependency outputs
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))

for dep_id, node_result in dependency_results.items():
input_parts.append(f"\nFrom {dep_id}:")
node_input.append(ContentBlock(text=f"\nFrom {dep_id}:"))
# Get all agent results from this node (flattened if nested)
agent_results = node_result.get_agent_results()
for result in agent_results:
agent_name = getattr(result, "agent_name", "Agent")
result_text = str(result)
input_parts.append(f" - {agent_name}: {result_text}")
node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}"))

return "\n".join(input_parts)
return node_input

def _build_result(self) -> GraphResult:
"""Build graph result from current state."""
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ async def test_graph_edge_cases():
builder.add_node(entry_agent, "entry_only")
graph = builder.build()

result = await graph.execute_async("Original task")
result = await graph.execute_async([{"text": "Original task"}])

# Verify entry node was called with original task
entry_agent.stream_async.assert_called_once_with("Original task")
entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}])
assert result.status == Status.COMPLETED


Expand Down
57 changes: 52 additions & 5 deletions tests_integ/test_multiagent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from strands import Agent, tool
from strands.multiagent.graph import GraphBuilder
from strands.types.content import ContentBlock


@tool
Expand All @@ -23,7 +24,6 @@ def math_agent():
model="us.amazon.nova-pro-v1:0",
system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.",
tools=[calculate_sum, multiply_numbers],
load_tools_from_directory=False,
)


Expand All @@ -33,7 +33,6 @@ def analysis_agent():
return Agent(
model="us.amazon.nova-pro-v1:0",
system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.",
load_tools_from_directory=False,
)


Expand All @@ -43,7 +42,6 @@ def summary_agent():
return Agent(
model="us.amazon.nova-lite-v1:0",
system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.",
load_tools_from_directory=False,
)


Expand All @@ -53,7 +51,16 @@ def validation_agent():
return Agent(
model="us.amazon.nova-pro-v1:0",
system_prompt="You are a validation expert. Check results for accuracy and completeness.",
load_tools_from_directory=False,
)


@pytest.fixture
def image_analysis_agent():
"""Create an agent specialized in image analysis."""
return Agent(
system_prompt=(
"You are an image analysis expert. Describe what you see in images and provide detailed analysis."
)
)


Expand All @@ -74,7 +81,7 @@ def nested_computation_graph(math_agent, analysis_agent):


@pytest.mark.asyncio
async def test_graph_execution(math_agent, summary_agent, validation_agent, nested_computation_graph):
async def test_graph_execution_with_string(math_agent, summary_agent, validation_agent, nested_computation_graph):
# Define conditional functions
def should_validate(state):
"""Condition to determine if validation should run."""
Expand Down Expand Up @@ -131,3 +138,43 @@ def proceed_to_second_summary(state):
# Verify nested graph execution
nested_result = result.results["computation_subgraph"].result
assert nested_result.status.value == "completed"


@pytest.mark.asyncio
async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img):
"""Test graph execution with multi-modal image input."""
builder = GraphBuilder()

# Add agents to graph
builder.add_node(image_analysis_agent, "image_analyzer")
builder.add_node(summary_agent, "summarizer")

# Connect them sequentially
builder.add_edge("image_analyzer", "summarizer")
builder.set_entry_point("image_analyzer")

graph = builder.build()

# Create content blocks with text and image
content_blocks: list[ContentBlock] = [
{"text": "Analyze this image and describe what you see:"},
{"image": {"format": "png", "source": {"bytes": yellow_img}}},
]

# Execute the graph with multi-modal input
result = await graph.execute_async(content_blocks)

# Verify results
assert result.status.value == "completed"
assert result.total_nodes == 2
assert result.completed_nodes == 2
assert result.failed_nodes == 0
assert len(result.results) == 2

# Verify execution order
execution_order_ids = [node.node_id for node in result.execution_order]
assert execution_order_ids == ["image_analyzer", "summarizer"]

# Verify both nodes completed
assert "image_analyzer" in result.results
assert "summarizer" in result.results
Loading