diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index eb9bc4dd9..b63b4329e 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, Sequence, cast from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics @@ -46,3 +46,40 @@ def __str__(self) -> str: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentResult": + """Rehydrate an AgentResult from persisted JSON. + + Args: + data: Dictionary containing the serialized AgentResult data + + Returns: + AgentResult instance + + Raises: + TypeError: If the data format is invalid + """ + if data.get("type") != "agent_result": + raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast(StopReason, data.get("stop_reason")) + + return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + + @classmethod + def to_dict(cls, data: "AgentResult") -> dict[str, Any]: + """Convert an AgentResult to JSON-serializable dictionary. + + Args: + data: AgentResult instance to serialize + + Returns: + Dictionary containing serialized AgentResult data + """ + return { + "type": "agent_result", + "message": data.message, + "stop_reason": data.stop_reason, + } diff --git a/src/strands/experimental/hooks/multiagent/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..83a62398b --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -0,0 +1,21 @@ +"""Multi-agent session management for persistent execution. 
+ +This package provides session persistence capabilities for multi-agent orchestrators, +enabling resumable execution after interruptions or failures. +""" + +from .events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) + +__all__ = [ + "AfterMultiAgentInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", +] diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py new file mode 100644 index 000000000..257cf57f0 --- /dev/null +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -0,0 +1,93 @@ +"""Multi-agent execution lifecycle events for hook system integration. + +These events are fired by orchestrators (Graph/Swarm) at key points so +hooks can persist, monitor, or debug execution. No intermediate state model +is used—hooks read from the orchestrator directly. +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from ....hooks import BaseHookEvent + +if TYPE_CHECKING: + from ....multiagent.base import MultiAgentBase + + +@dataclass +class MultiAgentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator is initialized. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class BeforeNodeCallEvent(BaseHookEvent): + """Event triggered before individual node execution starts. This event corresponds to the After event. 
+ + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that is about to execute + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterNodeCallEvent(BaseHookEvent): + """Event triggered after individual node execution completes. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + node_id: str + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered before orchestrator execution starts. This event corresponds to the After event. + + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + +@dataclass +class AfterMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. 
+ + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user passes in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0dbd85d81..bfe4b3b3f 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -4,6 +4,7 @@ """ import asyncio +import logging import warnings from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -15,6 +16,8 @@ from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +logger = logging.getLogger(__name__) + class Status(Enum): """Execution status for both graphs and nodes.""" @@ -59,6 +62,67 @@ def get_agent_results(self) -> list[AgentResult]: flattened.extend(nested_node_result.get_agent_results()) return flattened + def to_dict(self) -> dict[str, Any]: + """Convert NodeResult to JSON-serializable dict, ignoring state field.""" + if isinstance(self.result, Exception): + result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} + elif isinstance(self.result, AgentResult): + # Serialize AgentResult without state field + result_data = AgentResult.to_dict(self.result) + elif isinstance(self.result, MultiAgentResult): + result_data = self.result.to_dict() + else: + raise TypeError(f"Unsupported NodeResult.result type for serialization: {type(self.result).__name__}") + + return { + "result": result_data, + "execution_time": self.execution_time, + "status": self.status.value, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeResult": + """Rehydrate a NodeResult from persisted JSON.""" + if 
"result" not in data: + raise TypeError("NodeResult.from_dict: missing 'result'") + raw = data["result"] + + result: Union[AgentResult, "MultiAgentResult", Exception] + if isinstance(raw, dict) and raw.get("type") == "agent_result": + result = AgentResult.from_dict(raw) + elif isinstance(raw, dict) and raw.get("type") == "exception": + result = Exception(str(raw.get("message", "node failed"))) + elif isinstance(raw, dict) and raw.get("type") == "multiagent_result": + result = MultiAgentResult.from_dict(raw) + else: + raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + + usage_data = data.get("accumulated_usage", {}) + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + + metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) + + return cls( + result=result, + execution_time=int(data.get("execution_time", 0)), + status=Status(data.get("status", "pending")), + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + ) + @dataclass class MultiAgentResult: @@ -76,6 +140,49 @@ class MultiAgentResult: execution_count: int = 0 execution_time: int = 0 + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentResult to JSON-serializable dict.""" + return { + "type": "multiagent_result", + "status": self.status.value, + "results": {k: v.to_dict() for k, v in self.results.items()}, + "accumulated_usage": dict(self.accumulated_usage), + "accumulated_metrics": dict(self.accumulated_metrics), + "execution_count": self.execution_count, + "execution_time": self.execution_time, + } + + 
@classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": + """Rehydrate a MultiAgentResult from persisted JSON.""" + if data.get("type") != "multiagent_result": + raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}") + + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + usage_data = data.get("accumulated_usage", {}) + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + + metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) + + multiagent_result = cls( + status=Status(data.get("status", Status.PENDING.value)), + results=results, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + execution_time=int(data.get("execution_time", 0)), + ) + return multiagent_result + class MultiAgentBase(ABC): """Base class for multi-agent helpers. 
@@ -122,3 +229,11 @@ def execute() -> MultiAgentResult: with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() + + def serialize_state(self) -> dict[str, Any]: + """Return a JSON-serializable snapshot of the orchestrator state.""" + raise NotImplementedError + + def deserialize_state(self, payload: dict[str, Any]) -> None: + """Restore orchestrator state from a session dict.""" + raise NotImplementedError diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1dbbfc3af..278b354df 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -26,6 +26,14 @@ from ..agent import Agent from ..agent.state import AgentState +from ..experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from ..hooks import HookProvider, HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage @@ -217,6 +225,10 @@ def __init__(self) -> None: self._node_timeout: Optional[float] = None self._reset_on_revisit: bool = False + # session manager + self._session_manager: Optional[SessionManager] = None + self._hooks: Optional[list[HookProvider]] = None + def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" _validate_node_executor(executor, self.nodes) @@ -306,6 +318,24 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder": self._node_timeout = timeout return self + def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder": + """Set session manager for the graph. 
+ + Args: + session_manager: SessionManager instance + """ + self._session_manager = session_manager + return self + + def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": + """Set hook providers for the graph. + + Args: + hooks: Custom hook providers supplied by the user + """ + self._hooks = hooks + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -331,6 +361,8 @@ def build(self) -> "Graph": execution_timeout=self._execution_timeout, node_timeout=self._node_timeout, reset_on_revisit=self._reset_on_revisit, + session_manager=self._session_manager, + hooks=self._hooks, ) def _validate_graph(self) -> None: @@ -358,6 +390,8 @@ def __init__( execution_timeout: Optional[float] = None, node_timeout: Optional[float] = None, reset_on_revisit: bool = False, + session_manager: Optional[SessionManager] = None, + hooks: Optional[list[HookProvider]] = None, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -369,6 +403,8 @@ def __init__( execution_timeout: Total execution timeout in seconds (default: None - no limit) node_timeout: Individual node timeout in seconds (default: None - no limit) reset_on_revisit: Whether to reset node state when revisited (default: False) + session_manager: Optional session manager for persistence + hooks: Optional custom hook providers to add to the graph's hook registry. 
""" super().__init__() @@ -384,6 +420,18 @@ def __init__( self.reset_on_revisit = reset_on_revisit self.state = GraphState() self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = HookRegistry() + if self.session_manager is not None: + self.hooks.add_hook(self.session_manager) + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + # Resume flags + self._resume_from_session = False + self._resume_next_nodes: list[GraphNode] = [] + + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -423,17 +471,20 @@ async def invoke_async( logger.debug("task=<%s> | starting graph execution", task) - # Initialize state - start_time = time.time() - self.state = GraphState( - status=Status.EXECUTING, - task=task, - total_nodes=len(self.nodes), - edges=[(edge.from_node, edge.to_node) for edge in self.edges], - entry_points=list(self.entry_points), - start_time=start_time, - ) - + if not self._resume_from_session: + start_time = time.time() + # Initialize state + self.state = GraphState( + status=Status.EXECUTING, + task=task, + total_nodes=len(self.nodes), + edges=[(edge.from_node, edge.to_node) for edge in self.edges], + entry_points=list(self.entry_points), + start_time=start_time, + ) + else: + self.state.status = Status.EXECUTING + self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "graph") with trace_api.use_span(span, end_on_exit=True): try: @@ -459,7 +510,12 @@ async def invoke_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self.hooks.invoke_callbacks( + AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) + ) + self._resume_from_session = False + self._resume_next_nodes.clear() return 
self._build_result() def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: @@ -476,7 +532,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" - ready_nodes = list(self.entry_points) + ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -512,8 +568,11 @@ def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["G return newly_ready def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: - """Check if a node is ready considering conditional edges.""" - # Get incoming edges to this node + """Check if a node is ready considering conditional edges. + + A node is ready iff ALL incoming deps are completed AND their edge conditions pass, + and at least one of those deps was completed in THIS batch (so it's newly ready). 
+ """ incoming_edges = [edge for edge in self.edges if edge.to_node == node] # Check if at least one incoming edge condition is satisfied @@ -532,13 +591,37 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" - # Reset the node's state if reset_on_revisit is enabled and it's being revisited + # Reset the node's state if reset_on_revisit is enabled, and it's being revisited + + async def _record_failure(exception: Exception) -> None: + execution_time = round((time.time() - start_time) * 1000) + fail_result = NodeResult( + result=exception, + execution_time=execution_time, + status=Status.FAILED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + ) + + node.execution_status = Status.FAILED + node.result = fail_result + node.execution_time = execution_time + self.state.failed_nodes.add(node) + self.state.results[node.node_id] = fail_result + self.hooks.invoke_callbacks( + AfterNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) + + self.hooks.invoke_callbacks( + BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) + if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() # Remove from completed nodes since we're re-executing it self.state.completed_nodes.remove(node) - node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) @@ -547,109 +630,90 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Build node input from satisfied dependencies node_input = self._build_node_input(node) - # Execute with timeout 
protection (only if node_timeout is set) - try: - # Execute based on node type and create unified NodeResult - if isinstance(node.executor, MultiAgentBase): - if self.node_timeout is not None: - multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state), - timeout=self.node_timeout, - ) - else: - multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) - - # Create NodeResult with MultiAgentResult directly - node_result = NodeResult( - result=multi_agent_result, # type is MultiAgentResult - execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, - accumulated_usage=multi_agent_result.accumulated_usage, - accumulated_metrics=multi_agent_result.accumulated_metrics, - execution_count=multi_agent_result.execution_count, + # Execute based on node type and create unified NodeResult + if isinstance(node.executor, MultiAgentBase): + if self.node_timeout is not None: + multiagent_result = await asyncio.wait_for( + node.executor.invoke_async(node_input, invocation_state), + timeout=self.node_timeout, ) + else: + multiagent_result = await node.executor.invoke_async(node_input, invocation_state) + + # Create NodeResult with MultiAgentResult directly + node_result = NodeResult( + result=multiagent_result, # type is MultiAgentResult + execution_time=multiagent_result.execution_time, + status=Status.COMPLETED, + accumulated_usage=multiagent_result.accumulated_usage, + accumulated_metrics=multiagent_result.accumulated_metrics, + execution_count=multiagent_result.execution_count, + ) - elif isinstance(node.executor, Agent): - if self.node_timeout is not None: - agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state=invocation_state), - timeout=self.node_timeout, - ) - else: - agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) - - if agent_response.stop_reason == "interrupt": - 
node.executor.messages.pop() # remove interrupted tool use message - node.executor._interrupt_state.deactivate() - - raise RuntimeError( - "user raised interrupt from agent | interrupts are not yet supported in graphs" - ) - - # Extract metrics from agent response - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=0) - if hasattr(agent_response, "metrics") and agent_response.metrics: - if hasattr(agent_response.metrics, "accumulated_usage"): - usage = agent_response.metrics.accumulated_usage - if hasattr(agent_response.metrics, "accumulated_metrics"): - metrics = agent_response.metrics.accumulated_metrics - - node_result = NodeResult( - result=agent_response, # type is AgentResult - execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, - accumulated_usage=usage, - accumulated_metrics=metrics, - execution_count=1, + elif isinstance(node.executor, Agent): + if self.node_timeout is not None: + agent_response = await asyncio.wait_for( + node.executor.invoke_async(node_input, invocation_state=invocation_state), + timeout=self.node_timeout, ) else: - raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - - except asyncio.TimeoutError: - timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", - node.node_id, - self.node_timeout, + agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + + if agent_response.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + + # Extract metrics from agent response + usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + metrics = Metrics(latencyMs=0) + if 
hasattr(agent_response, "metrics") and agent_response.metrics: + if hasattr(agent_response.metrics, "accumulated_usage"): + usage = agent_response.metrics.accumulated_usage + if hasattr(agent_response.metrics, "accumulated_metrics"): + metrics = agent_response.metrics.accumulated_metrics + + node_result = NodeResult( + result=agent_response, # type is AgentResult + execution_time=round((time.time() - start_time) * 1000), + status=Status.COMPLETED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, ) - raise Exception(timeout_msg) from None + else: + raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - # Mark as completed node.execution_status = Status.COMPLETED node.result = node_result node.execution_time = node_result.execution_time self.state.completed_nodes.add(node) self.state.results[node.node_id] = node_result self.state.execution_order.append(node) - # Accumulate metrics self._accumulate_metrics(node_result) + self.hooks.invoke_callbacks( + AfterNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time ) - except Exception as e: - logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) - execution_time = round((time.time() - start_time) * 1000) - - # Create a NodeResult for the failed node - node_result = NodeResult( - result=e, # Store exception as result - execution_time=execution_time, - status=Status.FAILED, - accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), - accumulated_metrics=Metrics(latencyMs=execution_time), - execution_count=1, + except asyncio.TimeoutError: + timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" + logger.exception( + "node=<%s>, timeout=<%s>s | node execution timed out after timeout", + node.node_id, + self.node_timeout, ) + await 
_record_failure(asyncio.TimeoutError(timeout_msg)) + raise - node.execution_status = Status.FAILED - node.result = node_result - node.execution_time = execution_time - self.state.failed_nodes.add(node) - self.state.results[node.node_id] = node_result # Store in results for consistency - + except Exception as e: + logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) + await _record_failure(e) raise def _accumulate_metrics(self, node_result: NodeResult) -> None: @@ -696,7 +760,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return self.state.task # Combine task with dependency outputs - node_input = [] + node_input: list[ContentBlock] = [] # Add original task if isinstance(self.state.task, str): @@ -709,15 +773,11 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: # Add dependency outputs node_input.append(ContentBlock(text="\nInputs from previous nodes:")) - for dep_id, node_result in dependency_results.items(): + for dep_id, prev_result in dependency_results.items(): node_input.append(ContentBlock(text=f"\nFrom {dep_id}:")) - # Get all agent results from this node (flattened if nested) - agent_results = node_result.get_agent_results() - for result in agent_results: - agent_name = getattr(result, "agent_name", "Agent") - result_text = str(result) - node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}")) - + for agent_result in prev_result.get_agent_results(): + agent_name = getattr(agent_result, "agent_name", "Agent") + node_input.append(ContentBlock(text=f" - {agent_name}: {agent_result}")) return node_input def _build_result(self) -> GraphResult: @@ -736,3 +796,131 @@ def _build_result(self) -> GraphResult: edges=self.state.edges, entry_points=self.state.entry_points, ) + + # Persistence Helper Functions + + def _from_dict(self, payload: dict[str, Any]) -> None: + status_raw = payload.get("status", "pending") + self.state.status = Status(status_raw) + + # Hydrate completed nodes & 
results + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results + + # Restore completed nodes from persisted data + completed_node_ids = payload.get("completed_nodes") or [] + self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + + # Execution order (only nodes that still exist) + order_node_ids = payload.get("execution_order") or [] + self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] + + # Task + self.state.task = payload.get("current_task", self.state.task) + + def _to_dict(self) -> dict[str, Any]: + status_str = getattr(self.state.status, "value", str(self.state.status)) + next_nodes = [n.node_id for n in self._compute_ready_nodes_for_resume()] + return { + "type": "graph", + "status": status_str, + "completed_nodes": [n.node_id for n in self.state.completed_nodes], + "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, + "next_nodes_to_execute": next_nodes, + "current_task": self.state.task, + "execution_order": [n.node_id for n in self.state.execution_order], + } + + def _map_node_ids(self, node_ids: list[str] | None) -> list[GraphNode]: + if not node_ids: + return [] + mapped_nodes = [] + for node_id in node_ids: + node = self.nodes.get(node_id) + if node: + mapped_nodes.append(node) + return mapped_nodes + + def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: + ready_nodes: list[GraphNode] = [] + completed_nodes = set(self.state.completed_nodes) + for node in self.nodes.values(): + if node in completed_nodes: + continue + incoming = [e for e in self.edges if e.to_node is node] + if not 
incoming: + ready_nodes.append(node) + elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): + ready_nodes.append(node) + + return ready_nodes + + def serialize_state(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. + + Returns: + Dictionary containing the current graph state + """ + return self._to_dict() + + def deserialize_state(self, payload: dict) -> None: + """Restore orchestrator state from a session dict and prepare for execution. + + This method handles two scenarios: + 1. If the payload lists no next nodes to execute (`next_nodes_to_execute` is empty or missing), + resets all nodes and graph state to allow re-execution from the beginning. + 2. Otherwise, restores the persisted state and prepares to resume execution + from the next ready nodes. + + Args: + payload: Dictionary containing persisted state data including status, + completed nodes, results, and next nodes to execute. + """ + if not payload.get("next_nodes_to_execute"): + # Reset all nodes + for node in self.nodes.values(): + node.reset_executor_state() + # Reset graph state + self.state = GraphState() + self._resume_from_session = False + return + + try: + self._from_dict(payload) + self._resume_from_session = True + + next_node_ids = payload.get("next_nodes_to_execute") or [] + mapped = self._map_node_ids(next_node_ids) + valid_ready: list[GraphNode] = [] + completed = set(self.state.completed_nodes) + + for node in mapped: + if node in completed or node.execution_status == Status.COMPLETED: + continue + # only include if it's dependency-ready + incoming = [edge for edge in self.edges if edge.to_node == node] + if not incoming: + valid_ready.append(node) + continue + + if all(e.from_node in completed and e.should_traverse(self.state) for e in incoming): + valid_ready.append(node) + + if not valid_ready: + valid_ready = self._compute_ready_nodes_for_resume() + self._resume_next_nodes = valid_ready + logger.debug("Resumed from persisted state. 
Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) + except Exception: + logger.exception("Failed to apply resume payload") + self._resume_from_session = False + self._resume_next_nodes.clear() + raise diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 7542b1b85..c538cf8f3 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,12 +19,20 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api from ..agent import Agent, AgentResult from ..agent.state import AgentState +from ..experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from ..hooks import HookProvider, HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool from ..types.content import ContentBlock, Messages @@ -203,6 +211,8 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, + session_manager: Optional[SessionManager] = None, + hooks: Optional[list[HookProvider]] = None, ) -> None: """Initialize Swarm with agents and configuration. 
@@ -217,6 +227,8 @@ def __init__( Disabled by default (default: 0) repetitive_handoff_min_unique_agents: Minimum unique agents required in recent sequence Disabled by default (default: 0) + session_manager: Optional session manager for persistence + hooks: Optional list of hook providers to register for event handling """ super().__init__() @@ -237,9 +249,21 @@ def __init__( ) self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + if self.session_manager is not None: + self.hooks.add_hook(self.session_manager) + self._setup_swarm(nodes) self._inject_swarm_tools() + self._resume_from_session = False + + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) + def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: @@ -278,18 +302,18 @@ async def invoke_async( logger.debug("starting swarm execution") - # Initialize swarm state with configuration - if self.entry_point: - initial_node = self.nodes[str(self.entry_point.name)] - else: - initial_node = next(iter(self.nodes.values())) # First SwarmNode + if not self._resume_from_session: + initial_node = self._initial_node() - self.state = SwarmState( - current_node=initial_node, - task=task, - completion_status=Status.EXECUTING, - shared_context=self.shared_context, - ) + self.state = SwarmState( + current_node=initial_node, + task=task, + completion_status=Status.EXECUTING, + shared_context=self.shared_context, + ) + else: + self.state.completion_status = Status.EXECUTING + self.state.start_time = time.time() start_time = time.time() span = self.tracer.start_multiagent_span(task, "swarm") @@ -310,6 +334,10 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) + self.hooks.invoke_callbacks( + AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) + ) +
self._resume_from_session = False return self._build_result() @@ -458,6 +486,13 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st previous_agent.node_id, target_node.node_id, ) + # Persist the handoff message in case we lose it. + if self.session_manager is not None: + try: + self.session_manager.write_multi_agent_json(self) + except Exception as e: + logger.warning("Failed to persist swarm state after handoff: %s", e) + raise def _build_node_input(self, target_node: SwarmNode) -> str: """Build input text for a node based on shared context and handoffs. @@ -580,6 +615,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) + self.hooks.invoke_callbacks( + AfterNodeCallEvent(self, node_id=current_node.node_id, invocation_state=invocation_state) + ) + # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete if self.state.current_node == current_node: @@ -617,8 +656,11 @@ async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] ) -> AgentResult: """Execute swarm node.""" + self.hooks.invoke_callbacks( + BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) start_time = time.time() - node_name = node.node_id + node_id = node.node_id try: # Prepare context for node @@ -633,7 +675,6 @@ async def _execute_node( node_input = node_input + task # Execute node - result = None node.reset_executor_state() result = await node.executor.invoke_async(node_input, invocation_state=invocation_state) @@ -664,7 +705,7 @@ async def _execute_node( ) # Store result in state - self.state.results[node_name] = node_result + self.state.results[node_id] = node_result # Accumulate metrics self._accumulate_metrics(node_result) @@ -673,7 +714,7 @@ async def _execute_node( except Exception as e:
execution_time = round((time.time() - start_time) * 1000) - logger.exception("node=<%s> | node execution failed", node_name) + logger.exception("node=<%s> | node execution failed", node_id) # Create a NodeResult for the failed node node_result = NodeResult( @@ -686,7 +727,10 @@ async def _execute_node( ) # Store result in state - self.state.results[node_name] = node_result + self.state.results[node_id] = node_result + + # Persist failure here + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_id, invocation_state=invocation_state)) raise @@ -708,3 +752,129 @@ def _build_result(self) -> SwarmResult: execution_time=self.state.execution_time, node_history=self.state.node_history, ) + + # Persistence Helper function + + def _initial_node(self) -> SwarmNode: + if self.entry_point: + return self.nodes[str(self.entry_point.name)] + + return next(iter(self.nodes.values())) # First SwarmNode + + def _to_dict(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. + + Returns: + Dictionary containing the current swarm state + """ + status_str = getattr(self.state.completion_status, "value", str(self.state.completion_status)) + next_nodes = ( + [self.state.current_node.node_id] + if self.state.completion_status == Status.EXECUTING and self.state.current_node + else [] + ) + # Handoff: 'dict' object is not callable, that's why we need this. 
+ normalized_results: dict[str, NodeResult] = {} + for node_id, entry in (self.state.results or {}).items(): + if isinstance(entry, NodeResult): + normalized_results[node_id] = entry + else: + raise TypeError(f"Unexpected node_result type for {node_id}: {type(entry).__name__}") + return { + "type": "swarm", + "status": status_str, + "node_history": [n.node_id for n in self.state.node_history], + "node_results": {k: v.to_dict() for k, v in normalized_results.items()}, + "next_nodes_to_execute": next_nodes, + "current_task": self.state.task, + "context": { + "shared_context": getattr(self.state.shared_context, "context", {}) or {}, + "handoff_message": self.state.handoff_message, + }, + } + + def _from_dict(self, payload: dict) -> None: + """Restore orchestrator state from a session dict. + + Args: + payload: Dictionary containing persisted state data + """ + try: + status_raw = payload.get("status", "pending") + try: + self.state.completion_status = Status(status_raw) + except Exception: + self.state.completion_status = Status.PENDING + + ctx = payload.get("context") or {} + self.shared_context.context = ctx.get("shared_context") or {} + self.state.handoff_message = ctx.get("handoff_message") + + # node history and results + self.state.node_history = [ + self.nodes[nid] for nid in (payload.get("node_history") or []) if nid in self.nodes + ] + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results + self.state.task = payload.get("current_task", self.state.task) + + # Determine current node (if executing) + next_ids = list(payload.get("next_nodes_to_execute") or []) + if next_ids: + nid = next_ids[0] + found_node = self.nodes.get(nid) + 
self.state.current_node = found_node if found_node is not None else self._initial_node() + else: + # fallback to last executed or first node + last = (payload.get("node_history") or [])[-1:] or [] + if last: + found_node = self.nodes.get(last[0]) + self.state.current_node = found_node if found_node is not None else self._initial_node() + else: + self.state.current_node = self._initial_node() + except Exception as e: + logger.exception("Failed to apply persisted swarm state: %s", e) + raise + + def serialize_state(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. + + Returns: + Dictionary containing the current swarm state + """ + return self._to_dict() + + def deserialize_state(self, payload: dict) -> None: + """Restore orchestrator state from a session dict. + + Args: + payload: Dictionary containing persisted state data + """ + if not payload.get("next_nodes_to_execute"): + # Reset all nodes + for node in self.nodes.values(): + node.reset_executor_state() + # Reset swarm state + self.state = SwarmState( + current_node=SwarmNode("", Agent()), # Placeholder, will be set properly + task="", + completion_status=Status.PENDING, + ) + self._resume_from_session = False + return + else: + self._from_dict(payload) + self._resume_from_session = True + logger.debug( + "Resumed from persisted state. Current node: %s", + self.state.current_node.node_id if self.state.current_node else "None", + ) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 491f7ad60..954ccc416 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -5,14 +5,18 @@ import os import shutil import tempfile -from typing import Any, Optional, cast +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Optional, cast from ..
import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" @@ -37,19 +41,27 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): ``` """ - def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): + def __init__( + self, + session_id: str, + storage_dir: Optional[str] = None, + *, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize FileSession with filesystem storage. Args: session_id: ID for the session. ID is not allowed to contain path separators (e.g., a/b). storage_dir: Directory for local filesystem storage (defaults to temp dir). + session_type: single agent or multiagent. **kwargs: Additional keyword arguments for future extensibility. """ self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") os.makedirs(self.storage_dir, exist_ok=True) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_repository=self, session_type=session_type) def _get_session_path(self, session_id: str) -> str: """Get session directory path. 
@@ -107,8 +119,10 @@ def _read_file(self, path: str) -> dict[str, Any]: def _write_file(self, path: str, data: dict[str, Any]) -> None: """Write JSON file.""" os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as f: + tmp = f"{path}.tmp" + with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) + os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" @@ -118,7 +132,8 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: # Create directory structure os.makedirs(session_dir, exist_ok=True) - os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + if self.session_type == SessionType.AGENT: + os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) # Write session file session_file = os.path.join(session_dir, "session.json") @@ -239,3 +254,33 @@ def list_messages( messages.append(SessionMessage.from_dict(message_data)) return messages + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to filesystem. + + Args: + source: Multi-agent source object to persist + """ + state = source.serialize_state() + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + self._write_file(state_path, state) + + # Update session update metadata + session_dir = self._get_session_path(self.session.session_id) + session_file = os.path.join(session_dir, "session.json") + with open(session_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + self._write_file(session_file, metadata) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from filesystem. 
+ + Returns: + Multi-agent state dictionary or empty dict if not found + """ + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + if not os.path.exists(state_path): + return {} + state_data = self._read_file(state_path) + return state_data diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index e5075de93..01e23dfec 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -24,7 +24,13 @@ class RepositorySessionManager(SessionManager): """Session manager for persisting agents in a SessionRepository.""" - def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any): + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize the RepositorySessionManager. If no session with the specified session_id exists yet, it will be created @@ -34,22 +40,27 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa session_id: ID to use for the session. A new session with this id will be created if it does not exist in the repository yet session_repository: Underlying session repository to use to store the sessions state. + session_type: Type of session (AGENT or MULTI_AGENT) **kwargs: Additional keyword arguments for future extensibility.
""" + super().__init__(session_type=session_type) + self.session_repository = session_repository self.session_id = session_id session = session_repository.read_session(session_id) # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) - session = Session(session_id=session_id, session_type=SessionType.AGENT) + session = Session(session_id=session_id, session_type=session_type) session_repository.create_session(session) self.session = session + self.session_type = session.session_type - # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + if self.session_type == SessionType.AGENT: + # Keep track of the latest message of each agent in case we need to redact it. + self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index c6ce28d80..a28d1b2d7 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -10,10 +10,13 @@ from .. 
import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" @@ -46,6 +49,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -58,6 +62,7 @@ def __init__( boto_session: Optional boto3 session boto_client_config: Optional boto3 client configuration region_name: AWS region for S3 storage + session_type: Type of session (AGENT or MULTI_AGENT) **kwargs: Additional keyword arguments for future extensibility. """ self.bucket = bucket @@ -78,7 +83,12 @@ def __init__( client_config = BotocoreConfig(user_agent_extra="strands-agents") self.client = session.client(service_name="s3", config=client_config) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_type=session_type, session_repository=self) + + # This avoids leading // or / when self.prefix ="" + def _join_key(self, *parts: str) -> str: + cleaned = [part.strip("/ ") for part in parts if part and part.strip("/ ")] + return "/".join(cleaned) def _get_session_path(self, session_id: str) -> str: """Get session S3 prefix. @@ -90,7 +100,7 @@ def _get_session_path(self, session_id: str) -> str: ValueError: If session id contains a path separator. 
""" session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) - return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" + return f"{self._join_key(self.prefix, f'{SESSION_PREFIX}{session_id}')}/" def _get_agent_path(self, session_id: str, agent_id: str) -> str: """Get agent S3 prefix. @@ -294,3 +304,24 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to S3. + + Args: + source: Multi-agent source object to persist + """ + session_prefix = self._get_session_path(self.session_id) + state_key = self._join_key(session_prefix, "multi_agent_state.json") + state = source.serialize_state() + self._write_s3_object(state_key, state) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from S3. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + session_prefix = self._get_session_path(self.session_id) + state_key = self._join_key(session_prefix, "multi_agent_state.json") + return self._read_s3_object(state_key) or {} diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 66a07ea43..8b84e0790 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,14 +1,24 @@ """Session manager interface for agent session management.""" +import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, +) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message +from ..types.session import SessionType if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import
MultiAgentBase + +logger = logging.getLogger(__name__) class SessionManager(HookProvider, ABC): @@ -20,19 +30,39 @@ class SessionManager(HookProvider, ABC): for an agent, and should be persisted in the session. """ + def __init__(self, session_type: SessionType = SessionType.AGENT) -> None: + """Initialize SessionManager with session type. + + Args: + session_type: Type of session (AGENT or MULTI_AGENT) + """ + self.session_type: SessionType = session_type + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for persisting the agent to the session.""" - # After the normal Agent initialization behavior, call the session initialize function to restore the agent - registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + if not hasattr(self, "session_type"): + self.session_type = SessionType.AGENT + logger.debug("Session type not set, defaulting to AGENT") + + if self.session_type == SessionType.AGENT: + # After the normal Agent initialization behavior, call the session initialize function to restore the agent + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + + # For each message appended to the Agents messages, store that message in the session + registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) - # For each message appended to the Agents messages, store that message in the session - registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + # Sync the agent into the session for each message in case the agent state was updated + registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) - # Sync the agent into the session for each message in case the agent state was updated - registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + # After an agent was invoked, sync it with the session to capture 
any conversation manager state updates + registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) - # After an agent was invoked, sync it with the session to capture any conversation manager state updates - registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) + elif self.session_type == SessionType.MULTI_AGENT: + registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) + registry.add_callback(AfterNodeCallEvent, lambda event: self.write_multi_agent_json(event.source)) + registry.add_callback( + AfterMultiAgentInvocationEvent, lambda event: self.write_multi_agent_json(event.source) + ) @abstractmethod def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: @@ -71,3 +101,37 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent: Agent to initialize **kwargs: Additional keyword arguments for future extensibility. """ + + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: + """Write multi-agent state to persistent storage. + + Args: + source: Multi-agent source object to persist + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(write_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from persistent storage. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(read_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." 
+ ) + + def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: + """Initialization path: attempt to resume and then persist a fresh snapshot.""" + source: MultiAgentBase = event.source + payload = self.read_multi_agent_json() + # payload can be {} or Graph/Swarm state json + if payload: + source.deserialize_state(payload) + else: + self.write_multi_agent_json(source) diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 926480f2c..e0e8f396c 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -22,6 +22,7 @@ class SessionType(str, Enum): """ AGENT = "AGENT" + MULTI_AGENT = "MULTI_AGENT" def encode_bytes_values(obj: Any) -> Any: diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py new file mode 100644 index 000000000..727d28a48 --- /dev/null +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -0,0 +1,41 @@ +from typing import Iterator, Literal, Tuple, Type + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import ( + HookEvent, + HookProvider, + HookRegistry, +) + + +class MockMultiAgentHookProvider(HookProvider): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + MultiAgentInitializedEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, + AfterMultiAgentInvocationEvent, + ] + + self.events_received = [] + self.events_types = event_types + + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + return len(self.events_received), iter(self.events_received) + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self.add_event) + + def 
add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 409b08a2d..8806baae0 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -95,3 +95,31 @@ def test__str__non_dict_content(mock_metrics): message_string = str(result) assert message_string == "Valid text\nMore valid text\n" + + +def test_from_dict_valid_data(): + """Test that from_dict works with valid data.""" + data = { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + result = AgentResult.from_dict(data) + + assert result.message == data["message"] + assert result.stop_reason == data["stop_reason"] + assert isinstance(result.metrics, EventLoopMetrics) + assert result.state == {} + + +def test_from_dict_invalid_type(): + """Test that from_dict raises TypeError for invalid type.""" + data = { + "type": "invalid_type", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + with pytest.raises(TypeError, match="AgentResult.from_dict: unexpected type 'invalid_type'"): + AgentResult.from_dict(data) diff --git a/tests/strands/experimental/hooks/multiagent/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/hooks/multiagent/test_multiagent_events.py b/tests/strands/experimental/hooks/multiagent/test_multiagent_events.py new file mode 100644 index 000000000..6d2b22955 --- /dev/null +++ b/tests/strands/experimental/hooks/multiagent/test_multiagent_events.py @@ -0,0 +1,76 @@ +"""Tests for multi-agent execution lifecycle events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.hooks.multiagent.events import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, 
+ MultiAgentInitializedEvent, +) +from strands.hooks.registry import BaseHookEvent + + +@pytest.fixture +def orchestrator(): + """Mock orchestrator for testing.""" + return Mock() + + +def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): + """Test MultiAgentInitializedEvent creation with orchestrator only.""" + event = MultiAgentInitializedEvent(source=orchestrator) + + assert event.source is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_multi_agent_initialization_event_with_invocation_state(orchestrator): + """Test MultiAgentInitializedEvent creation with invocation state.""" + invocation_state = {"key": "value"} + event = MultiAgentInitializedEvent(source=orchestrator, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.invocation_state == invocation_state + + +def test_after_node_invocation_event_with_required_fields(orchestrator): + """Test AfterNodeCallEvent creation with required fields.""" + node_id = "node_1" + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_node_invocation_event_with_invocation_state(orchestrator): + """Test AfterNodeCallEvent creation with invocation state.""" + node_id = "node_2" + invocation_state = {"result": "success"} + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id, invocation_state=invocation_state) + + assert event.source is orchestrator + assert event.node_id == node_id + assert event.invocation_state == invocation_state + + +def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" + event = AfterMultiAgentInvocationEvent(source=orchestrator) + + assert event.source is orchestrator + assert 
event.invocation_state is None
+    assert isinstance(event, BaseHookEvent)
+
+
+def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator):
+    """Test AfterMultiAgentInvocationEvent creation with invocation state."""
+    invocation_state = {"final_state": "completed"}
+    event = AfterMultiAgentInvocationEvent(source=orchestrator, invocation_state=invocation_state)
+
+    assert event.source is orchestrator
+    assert event.invocation_state == invocation_state
diff --git a/tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py b/tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py
new file mode 100644
index 000000000..cfdbfe30f
--- /dev/null
+++ b/tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py
@@ -0,0 +1,127 @@
+import pytest
+
+from strands import Agent
+from strands.experimental.hooks.multiagent.events import (
+    AfterMultiAgentInvocationEvent,
+    AfterNodeCallEvent,
+    BeforeNodeCallEvent,
+    MultiAgentInitializedEvent,
+)
+from strands.multiagent.graph import Graph, GraphBuilder
+from strands.multiagent.swarm import Swarm
+from tests.fixtures.mock_multiagent_hook_provider import MockMultiAgentHookProvider
+from tests.fixtures.mocked_model_provider import MockedModelProvider
+
+
+@pytest.fixture
+def hook_provider():
+    return MockMultiAgentHookProvider(
+        [
+            AfterMultiAgentInvocationEvent,
+            AfterNodeCallEvent,
+            BeforeNodeCallEvent,
+            MultiAgentInitializedEvent,
+        ]
+    )
+
+
+@pytest.fixture
+def mock_model():
+    agent_messages = [
+        {"role": "assistant", "content": [{"text": "Task completed"}]},
+        {"role": "assistant", "content": [{"text": "Task completed by agent 2"}]},
+        {"role": "assistant", "content": [{"text": "Additional response"}]},
+    ]
+    return MockedModelProvider(agent_messages)
+
+
+@pytest.fixture
+def agent0(mock_model):
+    return Agent(model=mock_model, system_prompt="You are a helpful assistant.", name="agent0")
+
+
+@pytest.fixture
+def agent1(mock_model):
+    return Agent(model=mock_model, system_prompt="You are agent 1.", name="agent1")
+
+
+@pytest.fixture
+def agent2(mock_model):
+    return Agent(model=mock_model, system_prompt="You are agent 2.", name="agent2")
+
+
+@pytest.fixture
+def swarm(agent1, agent2, hook_provider):
+    swarm = Swarm(nodes=[agent1, agent2], hooks=[hook_provider])
+    return swarm
+
+
+@pytest.fixture
+def graph(agent1, agent2, hook_provider):
+    builder = GraphBuilder()
+    builder.add_node(agent1, "agent1")
+    builder.add_node(agent2, "agent2")
+    builder.add_edge("agent1", "agent2")
+    builder.set_entry_point("agent1")
+    graph = Graph(nodes=builder.nodes, edges=builder.edges, entry_points=builder.entry_points, hooks=[hook_provider])
+    return graph
+
+
+def test_swarm_complete_hook_lifecycle(swarm, hook_provider):
+    """E2E test verifying complete hook lifecycle for Swarm."""
+    result = swarm("test task")
+
+    length, events = hook_provider.get_events()
+    assert length == 4
+    assert result.status.value == "completed"
+
+    events_list = list(events)
+
+    # Check event order, source and node_id; invocation_state is intentionally not asserted
+    assert isinstance(events_list[0], MultiAgentInitializedEvent)
+    assert events_list[0].source == swarm
+
+    assert isinstance(events_list[1], BeforeNodeCallEvent)
+    assert events_list[1].source == swarm
+    assert events_list[1].node_id == "agent1"
+
+    assert isinstance(events_list[2], AfterNodeCallEvent)
+    assert events_list[2].source == swarm
+    assert events_list[2].node_id == "agent1"
+
+    assert isinstance(events_list[3], AfterMultiAgentInvocationEvent)
+    assert events_list[3].source == swarm
+
+
+def test_graph_complete_hook_lifecycle(graph, hook_provider):
+    """E2E test verifying complete hook lifecycle for Graph."""
+    result = graph("test task")
+
+    length, events = hook_provider.get_events()
+    assert length == 6
+    assert result.status.value == "completed"
+
+    events_list = list(events)
+
+    # Check event order, source and node_id; invocation_state is intentionally not asserted
+    assert isinstance(events_list[0], MultiAgentInitializedEvent)
+    assert events_list[0].source == graph
+
+    assert isinstance(events_list[1], BeforeNodeCallEvent)
+    assert events_list[1].source == graph
+    assert events_list[1].node_id == "agent1"
+
+    assert isinstance(events_list[2], AfterNodeCallEvent)
+    assert events_list[2].source == graph
+    assert events_list[2].node_id == "agent1"
+
+    assert isinstance(events_list[3], BeforeNodeCallEvent)
+    assert events_list[3].source == graph
+    assert events_list[3].node_id == "agent2"
+
+    assert isinstance(events_list[4], AfterNodeCallEvent)
+    assert events_list[4].source == graph
+    assert events_list[4].node_id == "agent2"
+
+    assert isinstance(events_list[5], AfterMultiAgentInvocationEvent)
+    assert events_list[5].source == graph
diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py
index ab55b2c84..0da071afa 100644
--- a/tests/strands/multiagent/test_base.py
+++ b/tests/strands/multiagent/test_base.py
@@ -28,6 +28,10 @@ def test_node_result_initialization_and_properties(agent_result):
     assert node_result.accumulated_metrics == {"latencyMs": 0.0}
     assert node_result.execution_count == 0
 
+    # Test default status
+    default_node = NodeResult(result=agent_result)
+    assert default_node.status == Status.PENDING
+
     # With custom metrics
     custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300}
     custom_metrics = {"latencyMs": 250.0}
@@ -95,6 +99,7 @@ def test_multi_agent_result_initialization(agent_result):
     assert result.accumulated_metrics == {"latencyMs": 0.0}
     assert result.execution_count == 0
     assert result.execution_time == 0
+    assert result.status == Status.PENDING
 
     # Custom values``
     node_result = NodeResult(result=agent_result)
@@ -141,6 +146,12 @@ class CompleteMultiAgent(MultiAgentBase):
         async def invoke_async(self, task: str) -> MultiAgentResult:
             return MultiAgentResult(results={})
 
+        def serialize_state(self) -> dict:
+            return {}
+
+        def deserialize_state(self, payload: dict) -> None:
+            pass
+
     # Should not
raise an exception - __call__ is provided by base class
     agent = CompleteMultiAgent()
     assert isinstance(agent, MultiAgentBase)
@@ -164,6 +175,15 @@ async def invoke_async(self, task, invocation_state, **kwargs):
                 status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)}
             )
 
+        def serialize_state(self) -> dict:
+            return {}
+
+        def deserialize_state(self, payload: dict) -> None:
+            pass
+
+        def attempt_resume(self, payload: dict) -> None:
+            pass
+
     agent = TestMultiAgent()
 
     # Test with string task
@@ -174,3 +194,57 @@ async def invoke_async(self, task, invocation_state, **kwargs):
     assert agent.received_invocation_state == {"param1": "value1", "param2": "value2", "value3": "value4"}
     assert isinstance(result, MultiAgentResult)
     assert result.status == Status.COMPLETED
+
+
+def test_node_result_to_dict(agent_result):
+    """Test NodeResult to_dict method."""
+    # Test with AgentResult
+    node_result = NodeResult(result=agent_result, execution_time=100, status=Status.COMPLETED)
+    result_dict = node_result.to_dict()
+
+    assert result_dict["execution_time"] == 100
+    assert result_dict["status"] == "completed"
+    assert result_dict["result"]["type"] == "agent_result"
+    assert result_dict["result"]["stop_reason"] == agent_result.stop_reason
+    assert result_dict["result"]["message"] == agent_result.message
+
+    # Test with Exception
+    exception_result = NodeResult(result=Exception("Test error"), status=Status.FAILED)
+    result_dict = exception_result.to_dict()
+
+    assert result_dict["result"]["type"] == "exception"
+    assert result_dict["result"]["message"] == "Test error"
+    assert result_dict["status"] == "failed"
+
+
+def test_multi_agent_result_to_dict(agent_result):
+    """Test MultiAgentResult to_dict method."""
+    node_result = NodeResult(result=agent_result)
+    multi_result = MultiAgentResult(status=Status.COMPLETED, results={"test_node": node_result}, execution_time=200)
+
+    result_dict = multi_result.to_dict()
+
+    assert result_dict["status"] == "completed"
+    assert result_dict["execution_time"] == 200
+    assert "test_node" in result_dict["results"]
+    assert result_dict["results"]["test_node"]["result"]["type"] == "agent_result"
+
+
+def test_serialize_node_result_for_persist(agent_result):
+    """Test that NodeResult.to_dict() exposes the top-level fields checked below."""
+
+    # Test with NodeResult containing AgentResult
+    node_result = NodeResult(result=agent_result)
+    serialized = node_result.to_dict()
+
+    # The serialized dict should contain these top-level keys
+    assert "result" in serialized
+    assert "execution_time" in serialized
+    assert "status" in serialized
+
+    # Test with NodeResult containing Exception
+    exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED)
+    serialized_exception = exception_node_result.to_dict()
+    assert "result" in serialized_exception
+    assert serialized_exception["result"]["type"] == "exception"
+    assert serialized_exception["result"]["message"] == "Test error"
diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py
index c4c1a664f..1e44d0022 100644
--- a/tests/strands/multiagent/test_graph.py
+++ b/tests/strands/multiagent/test_graph.py
@@ -640,7 +640,7 @@ async def timeout_invoke(*args, **kwargs):
     graph = builder.set_max_node_executions(50).set_execution_timeout(900.0).set_node_timeout(0.1).build()
 
     # Execute the graph - should raise Exception due to timeout
-    with pytest.raises(Exception, match="Node 'timeout_node' execution timed out after 0.1s"):
+    with pytest.raises(asyncio.TimeoutError):
         await graph.invoke_async("Test node timeout")
 
     mock_strands_tracer.start_multiagent_span.assert_called()
@@ -1341,3 +1341,54 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
         [{"text": "Test kwargs passing sync"}], invocation_state=test_invocation_state
     )
     assert result.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_graph_persisted(mock_strands_tracer, mock_use_span):
+    """Test graph persistence functionality."""
+    # Create mock session manager
+    session_manager = Mock(spec=SessionManager)
+    session_manager.read_multi_agent_json.return_value = None
+
+    # Create simple graph with session manager
+    builder = GraphBuilder()
+    agent = create_mock_agent("test_agent")
+    builder.add_node(agent, "test_node")
+    builder.set_entry_point("test_node")
+    builder.set_session_manager(session_manager)
+
+    graph = builder.build()
+
+    # Test serialize_state on a freshly built graph
+    state = graph.serialize_state()
+    assert state["type"] == "graph"
+    assert "status" in state
+    assert "completed_nodes" in state
+    assert "node_results" in state
+
+    # Test _from_dict with persisted state
+    persisted_state = {
+        "status": "executing",
+        "completed_nodes": [],
+        "node_results": {},
+        "current_task": "persisted task",
+        "execution_order": [],
+        "next_nodes_to_execute": ["test_node"],
+    }
+
+    graph._from_dict(persisted_state)
+    assert graph.state.task == "persisted task"
+
+    # Execute graph to test persistence integration
+    result = await graph.invoke_async("Test persistence")
+
+    # Verify execution completed
+    assert result.status == Status.COMPLETED
+    assert len(result.results) == 1
+    assert "test_node" in result.results
+
+    # Test state serialization after execution
+    final_state = graph.serialize_state()
+    assert final_state["status"] == "completed"
+    assert len(final_state["completed_nodes"]) == 1
+    assert "test_node" in final_state["node_results"]
diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py
index 0968fd30c..2960d01fc 100644
--- a/tests/strands/multiagent/test_swarm.py
+++ b/tests/strands/multiagent/test_swarm.py
@@ -574,3 +574,52 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
 
     assert kwargs_agent.invoke_async.call_args.kwargs == {"invocation_state": test_kwargs}
     assert result.status == Status.COMPLETED
+
+
+@pytest.mark.asyncio
+async def test_swarm_persistence(mock_strands_tracer, mock_use_span):
+    """Test swarm persistence functionality."""
+    # Create mock session manager
+    session_manager = Mock(spec=SessionManager)
+    session_manager.read_multi_agent_json.return_value = None
+
+    # Create simple swarm with session manager
+    agent = create_mock_agent("test_agent")
+    swarm = Swarm([agent], session_manager=session_manager)
+
+    # Test serialize_state on a freshly built swarm
+    state = swarm.serialize_state()
+    assert state["type"] == "swarm"
+    assert "status" in state
+    assert "node_history" in state
+    assert "node_results" in state
+    assert "context" in state
+
+    # Test _from_dict with persisted state
+    persisted_state = {
+        "status": "executing",
+        "node_history": [],
+        "node_results": {},
+        "current_task": "persisted task",
+        "next_nodes_to_execute": ["test_agent"],
+        "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"},
+    }
+
+    swarm._from_dict(persisted_state)
+    assert swarm.state.task == "persisted task"
+    assert swarm.state.handoff_message == "test handoff"
+    assert swarm.shared_context.context["test_agent"]["key"] == "value"
+
+    # Execute swarm to test persistence integration
+    result = await swarm.invoke_async("Test persistence")
+
+    # Verify execution completed
+    assert result.status == Status.COMPLETED
+    assert len(result.results) == 1
+    assert "test_agent" in result.results
+
+    # Test state serialization after execution
+    final_state = swarm.serialize_state()
+    assert final_state["status"] == "completed"
+    assert len(final_state["node_history"]) == 1
+    assert "test_agent" in final_state["node_results"]
diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py
index f124ddf58..61b20aba4 100644
--- a/tests/strands/session/test_file_session_manager.py
+++ b/tests/strands/session/test_file_session_manager.py
@@ -408,3 +408,52 @@ def test__get_message_path_invalid_message_id(message_id, file_manager):
     """Test that message_id that is not
an integer raises ValueError."""
     with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
         file_manager._get_message_path("session1", "agent1", message_id)
+
+
+def test_write_read_multi_agent_json(file_manager, sample_session):
+    """Test writing and reading multi-agent state."""
+    file_manager.create_session(sample_session)
+
+    # Minimal stand-in for a MultiAgentBase: only serialize_state() is defined
+    class MockMultiAgent:
+        def serialize_state(self):
+            return {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]}
+
+    mock_agent = MockMultiAgent()
+    expected_state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]}
+
+    # Write multi-agent state
+    file_manager.write_multi_agent_json(mock_agent)
+
+    # Read multi-agent state
+    result = file_manager.read_multi_agent_json()
+    assert result == expected_state
+
+
+def test_read_multi_agent_json_nonexistent(file_manager):
+    """Test reading multi-agent state when file doesn't exist."""
+    result = file_manager.read_multi_agent_json()
+    assert result == {}
+
+
+def test_list_messages_missing_directory(file_manager, sample_session, sample_agent):
+    """Test listing messages when messages directory is missing."""
+    file_manager.create_session(sample_session)
+    file_manager.create_agent(sample_session.session_id, sample_agent)
+
+    # Remove messages directory
+    messages_dir = os.path.join(
+        file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id), "messages"
+    )
+    os.rmdir(messages_dir)
+
+    with pytest.raises(SessionException, match="Messages directory missing"):
+        file_manager.list_messages(sample_session.session_id, sample_agent.agent_id)
+
+
+def test_create_existing_session(file_manager, sample_session):
+    """Test creating a session that already exists."""
+    file_manager.create_session(sample_session)
+
+    with pytest.raises(SessionException, match="already exists"):
+        file_manager.create_session(sample_session)
diff --git a/tests/strands/session/test_s3_session_manager.py
b/tests/strands/session/test_s3_session_manager.py
index c4d6a0154..d5f2b7d97 100644
--- a/tests/strands/session/test_s3_session_manager.py
+++ b/tests/strands/session/test_s3_session_manager.py
@@ -374,3 +374,27 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager):
     """Test that message_id that is not an integer raises ValueError."""
     with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"):
         s3_manager._get_message_path("session1", "agent1", message_id)
+
+
+def test_write_read_multi_agent_json(s3_manager, sample_session):
+    """Test multi-agent state persistence."""
+    s3_manager.create_session(sample_session)
+
+    # Minimal stand-in for a MultiAgentBase: only serialize_state() is defined
+    class MockMultiAgent:
+        def serialize_state(self):
+            return {"type": "graph", "status": "completed"}
+
+    mock_agent = MockMultiAgent()
+    expected_state = {"type": "graph", "status": "completed"}
+
+    s3_manager.write_multi_agent_json(mock_agent)
+
+    result = s3_manager.read_multi_agent_json()
+    assert result == expected_state
+
+
+def test_read_multi_agent_json_nonexistent(s3_manager):
+    """Test reading multi-agent state when file doesn't exist."""
+    result = s3_manager.read_multi_agent_json()
+    assert result == {}
diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py
index c2c13c443..b8329b49a 100644
--- a/tests_integ/test_multiagent_graph.py
+++ b/tests_integ/test_multiagent_graph.py
@@ -1,3 +1,7 @@
+import tempfile
+from unittest.mock import patch
+from uuid import uuid4
+
 import pytest
 
 from strands import Agent, tool
@@ -9,8 +13,11 @@
     BeforeModelCallEvent,
     MessageAddedEvent,
 )
+from strands.multiagent.base import Status
 from strands.multiagent.graph import GraphBuilder
+from strands.session.file_session_manager import FileSessionManager
 from strands.types.content import ContentBlock
+from strands.types.session import SessionType
 from tests.fixtures.mock_hook_provider import MockHookProvider
 
 
@@ -83,6 +90,13 @@ def
image_analysis_agent(hook_provider):
     )
+
+
+@pytest.fixture
+def temp_dir():
+    """Create a temporary directory for testing."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        yield temp_dir
+
+
 @pytest.fixture
 def nested_computation_graph(math_agent, analysis_agent):
     """Create a nested graph for mathematical computation and analysis."""
@@ -218,3 +232,74 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y
 
     assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events
     assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events
+
+
+@pytest.mark.asyncio
+async def test_graph_interrupt_and_resume():
+    """Test graph interruption and resume functionality with FileSessionManager."""
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        session_id = str(uuid4())
+
+        # Create real agents
+        agent1 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 1", name="agent1")
+        agent2 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 2", name="agent2")
+        agent3 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 3", name="agent3")
+
+        session_manager = FileSessionManager(
+            session_id=session_id, storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT
+        )
+
+        builder = GraphBuilder()
+        builder.add_node(agent1, "node1")
+        builder.add_node(agent2, "node2")
+        builder.add_node(agent3, "node3")
+        builder.add_edge("node1", "node2")
+        builder.add_edge("node2", "node3")
+        builder.set_entry_point("node1")
+        builder.set_session_manager(session_manager)
+
+        graph = builder.build()
+
+        # Mock agent2 to fail on first execution
+        async def failing_invoke(*args, **kwargs):
+            raise Exception("Simulated failure in agent2")
+
+        with patch.object(agent2, "invoke_async", side_effect=failing_invoke):
+            # First execution - must raise at agent2; pytest.raises fails the test otherwise
+            with pytest.raises(Exception, match="Simulated failure in agent2"):
+                await graph.invoke_async("Test task")
+
+        # Verify partial execution was persisted
+        persisted_state = session_manager.read_multi_agent_json()
+        assert persisted_state is not None
+        assert persisted_state["type"] == "graph"
+        assert persisted_state["status"] == "failed"
+        assert len(persisted_state["completed_nodes"]) == 1  # Only node1 completed
+        assert "node1" in persisted_state["completed_nodes"]
+        # NOTE(review): the graph unit test in this change uses the plural key "next_nodes_to_execute";
+        # confirm which key serialize_state() actually emits.
+        assert "node2" in persisted_state["next_node_to_execute"]
+
+        # Track execution count before resume
+        initial_execution_count = graph.state.execution_count
+
+        # Execute graph again
+        result = await graph.invoke_async("Test task")
+
+        # Verify successful completion
+        assert result.status == Status.COMPLETED
+        assert len(result.results) == 3
+
+        execution_order_ids = [node.node_id for node in result.execution_order]
+        assert execution_order_ids == ["node1", "node2", "node3"]
+
+        # Verify only 2 additional nodes were executed
+        assert result.execution_count == initial_execution_count + 2
+
+        final_state = session_manager.read_multi_agent_json()
+        assert final_state["status"] == "completed"
+        assert len(final_state["completed_nodes"]) == 3
+
+        # Clean up
+        session_manager.delete_session(session_id)
diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py
index 9a8c79bf8..a80b97f40 100644
--- a/tests_integ/test_multiagent_swarm.py
+++ b/tests_integ/test_multiagent_swarm.py
@@ -1,3 +1,7 @@
+import tempfile
+from unittest.mock import patch
+from uuid import uuid4
+
 import pytest
 
 from strands import Agent, tool
@@ -10,8 +14,11 @@
     BeforeToolCallEvent,
     MessageAddedEvent,
 )
+from strands.multiagent.base import Status
 from strands.multiagent.swarm import Swarm
+from strands.session.file_session_manager import FileSessionManager
 from strands.types.content import ContentBlock
+from strands.types.session import SessionType
 from tests.fixtures.mock_hook_provider import MockHookProvider
 
 
@@ -134,3 +141,54 @@ async def test_swarm_execution_with_image(researcher_agent,
analyst_agent, write
 
     # Verify agent history - at least one agent should have been used
     assert len(result.node_history) > 0
+
+
+@pytest.mark.asyncio
+async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent):
+    """Test swarm interruption after analyst_agent and resume functionality."""
+    with tempfile.TemporaryDirectory() as temp_dir:
+        session_id = str(uuid4())
+
+        # Create session manager
+        session_manager = FileSessionManager(
+            session_id=session_id, storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT
+        )
+
+        # Create swarm with session manager
+        swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager)
+
+        # Mock analyst_agent to fail
+        async def failing_invoke(*args, **kwargs):
+            raise Exception("Simulated failure in analyst")
+
+        with patch.object(analyst_agent, "invoke_async", side_effect=failing_invoke):
+            # First execution - should fail at analyst
+            result = await swarm.invoke_async("Research AI trends and create a brief report")
+            assert result.status == Status.FAILED
+
+        # Verify partial execution was persisted
+        persisted_state = session_manager.read_multi_agent_json()
+        assert persisted_state is not None
+        assert persisted_state["type"] == "swarm"
+        assert persisted_state["status"] == "failed"
+        assert len(persisted_state["node_history"]) == 1  # Exactly one node recorded in history
+
+        # Track persisted node_history length before resume
+        initial_execution_count = len(persisted_state["node_history"])
+
+        # Execute swarm again - should automatically resume from saved state
+        result = await swarm.invoke_async("Research AI trends and create a brief report")
+
+        # Verify successful completion
+        assert result.status == Status.COMPLETED
+        assert len(result.results) > 0
+
+        assert len(result.node_history) >= initial_execution_count + 1
+
+        node_names = [node.node_id for node in result.node_history]
+        assert "researcher" in node_names
+        # Either analyst or writer (or both) should have executed to complete the task
+        assert "analyst" in node_names or "writer" in node_names
+
+        # Clean up
+        session_manager.delete_session(session_id)