diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f81da9096..a6c901d28 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -9,6 +9,7 @@ from typing import Union from ..agent import AgentResult +from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage @@ -75,13 +76,11 @@ class MultiAgentBase(ABC): """ @abstractmethod - # TODO: for task - multi-modal input (Message), list of messages - async def execute_async(self, task: str) -> MultiAgentResult: + async def execute_async(self, task: str | list[ContentBlock]) -> MultiAgentResult: """Execute task asynchronously.""" raise NotImplementedError("execute_async not implemented") @abstractmethod - # TODO: for task - multi-modal input (Message), list of messages - def execute(self, task: str) -> MultiAgentResult: + def execute(self, task: str | list[ContentBlock]) -> MultiAgentResult: """Execute task synchronously.""" raise NotImplementedError("execute not implemented") diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 4795dfbfd..0a7641010 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -22,6 +22,7 @@ from typing import Any, Callable, Tuple, cast from ..agent import Agent, AgentResult +from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status @@ -42,12 +43,14 @@ class GraphState: Entry point nodes receive this task as their input if they have no dependencies. """ + # Task (with default empty string) + task: str | list[ContentBlock] = "" + # Execution state status: Status = Status.PENDING completed_nodes: set["GraphNode"] = field(default_factory=set) failed_nodes: set["GraphNode"] = field(default_factory=set) execution_order: list["GraphNode"] = field(default_factory=list) - task: str = "" # Results results: dict[str, NodeResult] = field(default_factory=dict) @@ -247,7 +250,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi self.entry_points = entry_points self.state = GraphState() - def execute(self, task: str) -> GraphResult: + def execute(self, task: str | list[ContentBlock]) -> GraphResult: """Execute task synchronously.""" def execute() -> GraphResult: @@ -257,7 +260,7 @@ def execute() -> GraphResult: future = executor.submit(execute) return future.result() - async def execute_async(self, task: str) -> GraphResult: + async def execute_async(self, task: str | list[ContentBlock]) -> GraphResult: """Execute the graph asynchronously.""" logger.debug("task=<%s> | starting graph execution", task) @@ -435,8 +438,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None: self.state.accumulated_metrics["latencyMs"] += node_result.accumulated_metrics.get("latencyMs", 0) self.state.execution_count += node_result.execution_count - def _build_node_input(self, node: GraphNode) -> str: - """Build input text for a node based on dependency outputs.""" + def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: + """Build input for a node based on dependency outputs.""" # Get satisfied dependencies dependency_results = {} for edge in self.edges: @@ -449,21 +452,36 @@ def _build_node_input(self, node: GraphNode) -> str: dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] if not dependency_results: - return self.state.task + # No dependencies - return task as ContentBlocks + if isinstance(self.state.task, str): + return [ContentBlock(text=self.state.task)] + else: + return self.state.task # Combine task with dependency outputs - input_parts = [f"Original Task: {self.state.task}", "\nInputs from previous nodes:"] + node_input = [] + + # Add original task + if isinstance(self.state.task, str): + node_input.append(ContentBlock(text=f"Original Task: {self.state.task}")) + else: + # Add task content blocks with a prefix + node_input.append(ContentBlock(text="Original Task:")) + node_input.extend(self.state.task) + + # Add dependency outputs + node_input.append(ContentBlock(text="\nInputs from previous nodes:")) for dep_id, node_result in dependency_results.items(): - input_parts.append(f"\nFrom {dep_id}:") + node_input.append(ContentBlock(text=f"\nFrom {dep_id}:")) # Get all agent results from this node (flattened if nested) agent_results = node_result.get_agent_results() for result in agent_results: agent_name = getattr(result, "agent_name", "Agent") result_text = str(result) - input_parts.append(f" - {agent_name}: {result_text}") + node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}")) - return "\n".join(input_parts) + return node_input def _build_result(self) -> GraphResult: """Build graph result from current state.""" diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 38eb3af1c..99700c964 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -273,10 +273,10 @@ async def test_graph_edge_cases(): builder.add_node(entry_agent, "entry_only") graph = builder.build() - result = await graph.execute_async("Original task") + result = await graph.execute_async([{"text": "Original task"}]) # Verify entry node was called with original task - entry_agent.stream_async.assert_called_once_with("Original task") + entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}]) assert result.status == Status.COMPLETED diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index 64d5aae53..2e5a5e626 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -2,6 +2,7 @@ from strands import Agent, tool from strands.multiagent.graph import GraphBuilder +from strands.types.content import ContentBlock @tool @@ -23,7 +24,6 @@ def math_agent(): model="us.amazon.nova-pro-v1:0", system_prompt="You are a mathematical assistant. Always provide clear, step-by-step calculations.", tools=[calculate_sum, multiply_numbers], - load_tools_from_directory=False, ) @@ -33,7 +33,6 @@ def analysis_agent(): return Agent( model="us.amazon.nova-pro-v1:0", system_prompt="You are a data analysis expert. Provide insights and interpretations of numerical results.", - load_tools_from_directory=False, ) @@ -43,7 +42,6 @@ def summary_agent(): return Agent( model="us.amazon.nova-lite-v1:0", system_prompt="You are a summarization expert. Create concise, clear summaries of complex information.", - load_tools_from_directory=False, ) @@ -53,7 +51,16 @@ def validation_agent(): return Agent( model="us.amazon.nova-pro-v1:0", system_prompt="You are a validation expert. Check results for accuracy and completeness.", - load_tools_from_directory=False, + ) + + +@pytest.fixture +def image_analysis_agent(): + """Create an agent specialized in image analysis.""" + return Agent( + system_prompt=( + "You are an image analysis expert. Describe what you see in images and provide detailed analysis." + ) ) @@ -74,7 +81,7 @@ def nested_computation_graph(math_agent, analysis_agent): @pytest.mark.asyncio -async def test_graph_execution(math_agent, summary_agent, validation_agent, nested_computation_graph): +async def test_graph_execution_with_string(math_agent, summary_agent, validation_agent, nested_computation_graph): # Define conditional functions def should_validate(state): """Condition to determine if validation should run.""" @@ -131,3 +138,43 @@ def proceed_to_second_summary(state): # Verify nested graph execution nested_result = result.results["computation_subgraph"].result assert nested_result.status.value == "completed" + + +@pytest.mark.asyncio +async def test_graph_execution_with_image(image_analysis_agent, summary_agent, yellow_img): + """Test graph execution with multi-modal image input.""" + builder = GraphBuilder() + + # Add agents to graph + builder.add_node(image_analysis_agent, "image_analyzer") + builder.add_node(summary_agent, "summarizer") + + # Connect them sequentially + builder.add_edge("image_analyzer", "summarizer") + builder.set_entry_point("image_analyzer") + + graph = builder.build() + + # Create content blocks with text and image + content_blocks: list[ContentBlock] = [ + {"text": "Analyze this image and describe what you see:"}, + {"image": {"format": "png", "source": {"bytes": yellow_img}}}, + ] + + # Execute the graph with multi-modal input + result = await graph.execute_async(content_blocks) + + # Verify results + assert result.status.value == "completed" + assert result.total_nodes == 2 + assert result.completed_nodes == 2 + assert result.failed_nodes == 0 + assert len(result.results) == 2 + + # Verify execution order + execution_order_ids = [node.node_id for node in result.execution_order] + assert execution_order_ids == ["image_analyzer", "summarizer"] + + # Verify both nodes completed + assert "image_analyzer" in result.results + assert "summarizer" in result.results