From c2d49f66ce51a7e953a18c84276d6aac1180a558 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Thu, 25 Sep 2025 10:33:41 -0400 Subject: [PATCH 01/27] feat: multiagent session interface --- .../multiagent_session/__init__.py | 26 +++ .../multiagent_session/multiagent_events.py | 87 +++++++ .../multiagent_session/multiagent_state.py | 110 +++++++++ .../multiagent_state_adapter.py | 213 ++++++++++++++++++ .../multiagent_session/persistence_hooks.py | 106 +++++++++ 5 files changed, 542 insertions(+) create mode 100644 src/strands/experimental/multiagent_session/__init__.py create mode 100644 src/strands/experimental/multiagent_session/multiagent_events.py create mode 100644 src/strands/experimental/multiagent_session/multiagent_state.py create mode 100644 src/strands/experimental/multiagent_session/multiagent_state_adapter.py create mode 100644 src/strands/experimental/multiagent_session/persistence_hooks.py diff --git a/src/strands/experimental/multiagent_session/__init__.py b/src/strands/experimental/multiagent_session/__init__.py new file mode 100644 index 000000000..fbc0e24b4 --- /dev/null +++ b/src/strands/experimental/multiagent_session/__init__.py @@ -0,0 +1,26 @@ +"""Multi-agent session management for persistent execution. + +This package provides session persistence capabilities for multi-agent orchestrators, +enabling resumable execution after interruptions or failures. 
+""" + +from .multiagent_events import ( + AfterGraphInvocationEvent, + AfterNodeInvocationEvent, + BeforeGraphInvocationEvent, + BeforeNodeInvocationEvent, + MultiAgentInitializationEvent, +) +from .multiagent_state import MultiAgentState, MultiAgentType +from .multiagent_state_adapter import MultiAgentAdapter + +__all__ = [ + "BeforeGraphInvocationEvent", + "AfterGraphInvocationEvent", + "MultiAgentInitializationEvent", + "BeforeNodeInvocationEvent", + "AfterNodeInvocationEvent", + "MultiAgentState", + "MultiAgentAdapter", + "MultiAgentType", +] diff --git a/src/strands/experimental/multiagent_session/multiagent_events.py b/src/strands/experimental/multiagent_session/multiagent_events.py new file mode 100644 index 000000000..ebfb9e48f --- /dev/null +++ b/src/strands/experimental/multiagent_session/multiagent_events.py @@ -0,0 +1,87 @@ +"""Multi-agent execution lifecycle events for hook system integration. + +This module defines event classes that are triggered at key points during +multi-agent orchestrator execution, enabling hooks to respond to lifecycle +events for purposes like persistence, monitoring, and debugging. + +Event Types: +- Initialization: When orchestrator starts up +- Before/After Graph: Start/end of overall execution +- Before/After Node: Start/end of individual node execution +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ...hooks.registry import HookEvent +from .multiagent_state import MultiAgentState + +if TYPE_CHECKING: + from ...multiagent.base import MultiAgentBase + + +@dataclass +class MultiAgentInitializationEvent(HookEvent): + """Event triggered when multi-agent orchestrator initializes. + + Attributes: + orchestrator: The multi-agent orchestrator instance + state: Current state of the orchestrator + """ + + orchestrator: "MultiAgentBase" + state: MultiAgentState + + +@dataclass +class BeforeGraphInvocationEvent(HookEvent): + """Event triggered before orchestrator execution begins. 
+ + Attributes: + orchestrator: The multi-agent orchestrator instance + state: Current state before execution starts + """ + + orchestrator: "MultiAgentBase" + state: MultiAgentState + + +@dataclass +class BeforeNodeInvocationEvent(HookEvent): + """Event triggered before individual node execution. + + Attributes: + orchestrator: The multi-agent orchestrator instance + next_node_to_execute: ID of the node about to be executed + """ + + orchestrator: "MultiAgentBase" + next_node_to_execute: str + + +@dataclass +class AfterNodeInvocationEvent(HookEvent): + """Event triggered after individual node execution completes. + + Attributes: + orchestrator: The multi-agent orchestrator instance + executed_node: ID of the node that just completed execution + state: Updated state after node execution + """ + + orchestrator: "MultiAgentBase" + executed_node: str + state: MultiAgentState + + +@dataclass +class AfterGraphInvocationEvent(HookEvent): + """Event triggered after orchestrator execution completes. + + Attributes: + orchestrator: The multi-agent orchestrator instance + state: Final state after execution completes + """ + + orchestrator: "MultiAgentBase" + state: MultiAgentState diff --git a/src/strands/experimental/multiagent_session/multiagent_state.py b/src/strands/experimental/multiagent_session/multiagent_state.py new file mode 100644 index 000000000..60bbef6cb --- /dev/null +++ b/src/strands/experimental/multiagent_session/multiagent_state.py @@ -0,0 +1,110 @@ +"""Multi-agent state data structures for session persistence. + +This module defines the core data structures used to represent the state +of multi-agent orchestrators in a serializable format for session persistence. 
+ +Key Components: +- MultiAgentType: Enum for orchestrator types (Graph/Swarm) +- MultiAgentState: Serializable state container with conversion methods +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +from ...types.content import ContentBlock + +if TYPE_CHECKING: + from ...multiagent.base import Status + + +# TODO: Move to Base after experimental +class MultiAgentType(Enum): + """Enumeration of supported multi-agent orchestrator types. + + Attributes: + SWARM: Collaborative agent swarm orchestrator + GRAPH: Directed graph-based agent orchestrator + """ + + SWARM = "swarm" + GRAPH = "graph" + + +@dataclass +class MultiAgentState: + """Serializable state container for multi-agent orchestrators. + + This class represents the complete execution state of a multi-agent + orchestrator (Graph or Swarm) in a format suitable for persistence + and restoration across sessions. + + Attributes: + completed_nodes: Set of node IDs that have completed execution + node_results: Dictionary mapping node IDs to their execution results + status: Current execution status of the orchestrator + next_node_to_execute: List of node IDs ready for execution + current_task: The original task being executed + execution_order: Ordered list of executed node IDs + error_message: Optional error message if execution failed + type: Type of orchestrator (Graph or Swarm) + context: Additional context data (primarily for Swarm) + """ + + # Mutual + completed_nodes: Set[str] = field(default_factory=set) + node_results: Dict[str, Any] = field(default_factory=dict) + status: "Status" = "pending" + next_node_to_execute: Optional[List[str]] = None + current_task: Optional[str | List[ContentBlock]] = None + execution_order: list[str] = field(default_factory=list) + error_message: Optional[str] = None + type: Optional[MultiAgentType] = field(default=MultiAgentType.GRAPH) + # Swarm + context: Optional[dict] = 
field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentState to JSON-serializable dictionary. + + Returns: + Dictionary representation suitable for JSON serialization + """ + + def _serialize(v: Any) -> Any: + if isinstance(v, (str, int, float, bool)) or v is None: + return v + if isinstance(v, set): + return list(v) + if isinstance(v, dict): + return {str(k): _serialize(val) for k, val in v.items()} + if isinstance(v, (list, tuple)): + return [_serialize(x) for x in v] + if hasattr(v, "to_dict"): + return v.to_dict() + # last resort: stringize anything non-serializable (locks, objects, etc.) + return str(v) + + return { + "status": self.status, + "completed_nodes": list(self.completed_nodes), + "next_node_to_execute": list(self.next_node_to_execute) if self.next_node_to_execute else [], + "node_results": _serialize(self.node_results), + "current_task": self.current_task, + "error_message": self.error_message, + "execution_order": self.execution_order, + "type": self.type, + "context": _serialize(self.context), + } + + @classmethod + def from_dict(cls, data: dict): + """Create MultiAgentState from dictionary data. + + Args: + data: Dictionary containing state data + + Returns: + MultiAgentState instance + """ + data["completed_nodes"] = set(data.get("completed_nodes", [])) + return cls(**data) diff --git a/src/strands/experimental/multiagent_session/multiagent_state_adapter.py b/src/strands/experimental/multiagent_session/multiagent_state_adapter.py new file mode 100644 index 000000000..6d9cbb145 --- /dev/null +++ b/src/strands/experimental/multiagent_session/multiagent_state_adapter.py @@ -0,0 +1,213 @@ +"""Multi-agent state adapter for session persistence. + +This module provides bidirectional conversion between multi-agent orchestrator +runtime state and serializable MultiAgentState objects for session persistence. 
+ +Key Features: +- State serialization for Graph and Swarm orchestrators +- State restoration from persisted sessions +- Node result summarization for efficient storage +- Type-safe state conversion with error handling +""" + +import ast +import logging +from typing import Any + +from .multiagent_state import MultiAgentState, MultiAgentType + +logger = logging.getLogger(__name__) + + +class MultiAgentAdapter: + """Adapter for converting between orchestrator runtime state and persistent state. + + This class provides static methods for bidirectional conversion between + multi-agent orchestrator objects (Graph/Swarm) and serializable MultiAgentState. + """ + + @staticmethod + def apply_multi_agent_state(orchestrator: object, multi_agent_state: MultiAgentState): + """Apply persisted state to a multi-agent orchestrator. + + Args: + orchestrator: Graph or Swarm instance to restore state to + multi_agent_state: Persisted state to apply + + Raises: + ValueError: If state type is incompatible with orchestrator + """ + from ...multiagent.base import Status + from ...multiagent.graph import Graph + from ...multiagent.swarm import Swarm + + state_type = getattr(multi_agent_state, "type", None) + type_val = str(getattr(state_type, "value", state_type)) + if isinstance(orchestrator, Graph) and type_val == "graph": + graph = orchestrator + graph.state.status = Status(multi_agent_state.status) + graph.state.completed_nodes = { + graph.nodes[node_id] for node_id in multi_agent_state.completed_nodes if node_id in graph.nodes + } + graph.state.results = { + node_id: result for node_id, result in getattr(multi_agent_state, "node_results", {}).items() + } + execution_node_ids = getattr(multi_agent_state, "execution_order", []) or [] + graph.state.execution_order = [ + graph.nodes[node_id] + for node_id in execution_node_ids + if node_id in graph.nodes and graph.nodes[node_id] in graph.state.completed_nodes + ] + + graph.state.task = getattr(multi_agent_state, "current_task", "") + 
for node in graph.state.completed_nodes: + node.execution_status = Status.COMPLETED + + return + + elif isinstance(orchestrator, Swarm) and type_val == "swarm": + swarm = orchestrator + + swarm.state.completion_status = Status(multi_agent_state.status) + + swarm.state.node_history = [ + swarm.nodes[nid] for nid in (multi_agent_state.completed_nodes or []) if nid in swarm.nodes + ] + + completed_ids = {n.node_id for n in swarm.state.node_history} + saved_results = getattr(multi_agent_state, "node_results", {}) or {} + swarm.state.results = {k: v for k, v in saved_results.items() if k in completed_ids} + + # current_node + next_ids = getattr(multi_agent_state, "next_node_to_execute", []) or [] + swarm.state.current_node = swarm.nodes.get(next_ids[0]) if next_ids else None + + swarm.state.task = multi_agent_state.current_task + # hydrate context + context = getattr(multi_agent_state, "context", {}) or {} + shared_context = context.get("shared_context") or {} + swarm.shared_context.context = shared_context + swarm.state.handoff_message = context.get("handoff_message") + + else: + raise ValueError("Persisted state type incompatible with current orchestrator") + + @staticmethod + def create_multi_agent_state(orchestrator: object, msg: str = None) -> MultiAgentState | None: + """Create serializable state from multi-agent orchestrator. 
+ + Args: + orchestrator: Graph or Swarm instance to extract state from + msg: Optional error message to include in state + + Returns: + MultiAgentState object ready for persistence, or None if unsupported type + """ + from ...multiagent.base import Status + from ...multiagent.graph import Graph + from ...multiagent.swarm import Swarm + + if isinstance(orchestrator, Graph): + graph = orchestrator + + serialized_results = { + node_id: MultiAgentAdapter.summarize_node_result_for_persist(node_result) + for node_id, node_result in (graph.state.results or {}).items() + } + + inflight = [n.node_id for n in graph.nodes.values() if n.execution_status == Status.EXECUTING] + next_nodes_ids = inflight or [n.node_id for n in graph._compute_ready_nodes_for_resume()] + + return MultiAgentState( + type=MultiAgentType.GRAPH, + status=graph.state.status, + completed_nodes={node.node_id for node in graph.state.completed_nodes}, + node_results=serialized_results, + next_node_to_execute=next_nodes_ids, + current_task=graph.state.task, + error_message=msg, + execution_order=[n.node_id for n in graph.state.execution_order], + context={}, + ) + + elif isinstance(orchestrator, Swarm): + swarm = orchestrator + current_executing_node = ( + [swarm.state.current_node.node_id] if swarm.state.completion_status == Status.EXECUTING else [] + ) + + serialized_results = { + node_id: MultiAgentAdapter.summarize_node_result_for_persist(res) or str(res) + for node_id, res in (swarm.state.results or {}).items() + } + + shared_ctx = {} + if hasattr(swarm, "shared_context") and swarm.shared_context is not None: + shared_ctx = getattr(swarm.shared_context, "context", {}) or {} + + return MultiAgentState( + type=MultiAgentType.SWARM, + status=swarm.state.completion_status, + completed_nodes={node.node_id for node in swarm.state.node_history}, + node_results=serialized_results, + next_node_to_execute=current_executing_node, + current_task=swarm.state.task, + error_message=msg, + 
execution_order=[node.node_id for node in swarm.state.node_history], + context={ + "shared_context": shared_ctx, + "handoff_message": getattr(swarm.state, "handoff_message", None), + }, + ) + return None + + @staticmethod + def summarize_node_result_for_persist(raw: Any) -> dict[str, Any]: + """Summarize node execution result for efficient persistence. + + Args: + raw: Raw node result (NodeResult, dict, string, or other) + + Returns: + Normalized dict with 'agent_outputs' key containing string results + """ + if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): + try: + results = raw.get_agent_results() + texts = [str(r) for r in results] + return {"agent_outputs": texts} + except Exception: + pass + + # Already a dict + if isinstance(raw, dict): + return MultiAgentAdapter._normalize_persisted_like_dict(raw) + + # String that might itself be a dict represent + if isinstance(raw, str): + try: + parsed = ast.literal_eval(raw) + if isinstance(parsed, dict): + return MultiAgentAdapter._normalize_persisted_like_dict(parsed) + except Exception as e: + logger.debug("Failed to parse persisted node result: %s", e) + return {"agent_outputs": [raw]} + + return {"agent_outputs": [str(raw)]} + + @staticmethod + def _normalize_persisted_like_dict(data: dict[str, Any]) -> dict[str, Any]: + """Normalize dictionary data to standard agent_outputs format. 
+ + Args: + data: Dictionary containing result data + + Returns: + Normalized dict with 'agent_outputs' key + """ + if "agent_outputs" in data and isinstance(data["agent_outputs"], list): + return {"agent_outputs": [str(x) for x in data["agent_outputs"]]} + + if "summary" in data: + return {"agent_outputs": [str(data["summary"])]} + return {"agent_outputs": [str(data)]} diff --git a/src/strands/experimental/multiagent_session/persistence_hooks.py b/src/strands/experimental/multiagent_session/persistence_hooks.py new file mode 100644 index 000000000..565342aef --- /dev/null +++ b/src/strands/experimental/multiagent_session/persistence_hooks.py @@ -0,0 +1,106 @@ +"""Multi-agent session persistence hook implementation. + +This module provides automatic session persistence for multi-agent orchestrators +(Graph and Swarm) by hooking into their execution lifecycle events. + +Key Features: +- Automatic state persistence at key execution points +- Thread-safe persistence operations +- Support for both Graph and Swarm orchestrators +- Seamless integration with SessionManager +""" + +import threading +from typing import Optional + +from ...hooks.registry import HookProvider, HookRegistry +from ...multiagent.base import MultiAgentBase +from ...session import SessionManager +from .multiagent_events import ( + AfterGraphInvocationEvent, + AfterNodeInvocationEvent, + BeforeGraphInvocationEvent, + BeforeNodeInvocationEvent, + MultiAgentInitializationEvent, + MultiAgentState, +) +from .multiagent_state_adapter import MultiAgentAdapter + + +def _get_multiagent_state( + multiagent_state: Optional[MultiAgentState], + orchestrator: MultiAgentBase, +) -> MultiAgentState: + if multiagent_state is not None: + return multiagent_state + + return MultiAgentAdapter.create_multi_agent_state(orchestrator=orchestrator) + + +class MultiAgentHook(HookProvider): + """Hook provider for automatic multi-agent session persistence. 
+ + This hook automatically persists multi-agent orchestrator state at key + execution points to enable resumable execution after interruptions. + + Args: + session_manager: SessionManager instance for state persistence + session_id: Unique identifier for the session + """ + + def __init__(self, session_manager: SessionManager, session_id: str): + """Initialize the multi-agent persistence hook. + + Args: + session_manager: SessionManager instance for state persistence + session_id: Unique identifier for the session + """ + self._session_manager = session_manager + self._session_id = session_id + self._lock = threading.RLock() + + def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: + """Register persistence callbacks for multi-agent execution events. + + Args: + registry: Hook registry to register callbacks with + **kwargs: Additional keyword arguments (unused) + """ + registry.add_callback(MultiAgentInitializationEvent, self._on_initialization) + registry.add_callback(BeforeGraphInvocationEvent, self._on_before_graph) + registry.add_callback(BeforeNodeInvocationEvent, self._on_before_node) + registry.add_callback(AfterNodeInvocationEvent, self._on_after_node) + registry.add_callback(AfterGraphInvocationEvent, self._on_after_graph) + + def _on_initialization(self, event: MultiAgentInitializationEvent): + """Persist state when multi-agent orchestrator initializes.""" + self._persist(_get_multiagent_state(event.state, event.orchestrator)) + + def _on_before_graph(self, event: BeforeGraphInvocationEvent): + """Hook called before graph execution starts.""" + pass + + def _on_before_node(self, event: BeforeNodeInvocationEvent): + """Hook called before individual node execution.""" + pass + + def _on_after_node(self, event: AfterNodeInvocationEvent): + """Persist state after each node completes execution.""" + multi_agent_state = _get_multiagent_state(multiagent_state=event.state, orchestrator=event.orchestrator) + self._persist(multi_agent_state) 
+ + def _on_after_graph(self, event: AfterGraphInvocationEvent): + """Persist final state after graph execution completes.""" + multiagent_state = _get_multiagent_state(multiagent_state=event.state, orchestrator=event.orchestrator) + self._persist(multiagent_state) + + def _persist(self, multiagent_state: MultiAgentState) -> None: + """Persist the provided MultiAgentState using the configured SessionManager. + + This method is synchronized across threads/tasks to avoid write races. + + Args: + multiagent_state: State to persist + """ + with self._lock: + self._session_manager.write_multi_agent_state(multiagent_state) From 49253bc83e5f01e506a68af3db7b7f0cdb4d983c Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Sun, 28 Sep 2025 21:21:49 -0400 Subject: [PATCH 02/27] feat: enable multiagent session persistence # Conflicts: # src/strands/session/s3_session_manager.py --- .../__init__.py | 15 +- .../multiagent_events.py | 45 +- .../persistence_hooks.py | 61 +-- .../multiagent_session/multiagent_state.py | 110 ----- .../multiagent_state_adapter.py | 213 --------- src/strands/multiagent/base.py | 94 +++- src/strands/multiagent/graph.py | 418 +++++++++++++----- src/strands/multiagent/swarm.py | 194 +++++++- src/strands/session/file_session_manager.py | 56 ++- .../session/repository_session_manager.py | 19 +- src/strands/session/s3_session_manager.py | 44 +- src/strands/session/session_manager.py | 45 +- src/strands/types/session.py | 1 + tests/fixtures/mock_session_repository.py | 13 + tests/strands/agent/test_agent.py | 23 +- .../experimental/multiagent_hooks/__init__.py | 0 .../test_multiagent_events.py | 122 +++++ .../test_persistence_hooks.py | 131 ++++++ tests/strands/multiagent/test_base.py | 66 +++ tests/strands/multiagent/test_graph.py | 80 +++- tests/strands/multiagent/test_swarm.py | 50 +++ .../session/test_file_session_manager.py | 42 ++ .../test_repository_session_manager.py | 9 +- .../session/test_s3_session_manager.py | 17 + 24 files changed, 1289 
insertions(+), 579 deletions(-) rename src/strands/experimental/{multiagent_session => multiagent_hooks}/__init__.py (59%) rename src/strands/experimental/{multiagent_session => multiagent_hooks}/multiagent_events.py (53%) rename src/strands/experimental/{multiagent_session => multiagent_hooks}/persistence_hooks.py (54%) delete mode 100644 src/strands/experimental/multiagent_session/multiagent_state.py delete mode 100644 src/strands/experimental/multiagent_session/multiagent_state_adapter.py create mode 100644 tests/strands/experimental/multiagent_hooks/__init__.py create mode 100644 tests/strands/experimental/multiagent_hooks/test_multiagent_events.py create mode 100644 tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py diff --git a/src/strands/experimental/multiagent_session/__init__.py b/src/strands/experimental/multiagent_hooks/__init__.py similarity index 59% rename from src/strands/experimental/multiagent_session/__init__.py rename to src/strands/experimental/multiagent_hooks/__init__.py index fbc0e24b4..0567effad 100644 --- a/src/strands/experimental/multiagent_session/__init__.py +++ b/src/strands/experimental/multiagent_hooks/__init__.py @@ -5,22 +5,19 @@ """ from .multiagent_events import ( - AfterGraphInvocationEvent, + AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeGraphInvocationEvent, + BeforeMultiAgentInvocationEvent, BeforeNodeInvocationEvent, MultiAgentInitializationEvent, ) -from .multiagent_state import MultiAgentState, MultiAgentType -from .multiagent_state_adapter import MultiAgentAdapter +from .persistence_hooks import PersistentHook __all__ = [ - "BeforeGraphInvocationEvent", - "AfterGraphInvocationEvent", + "BeforeMultiAgentInvocationEvent", + "AfterMultiAgentInvocationEvent", "MultiAgentInitializationEvent", "BeforeNodeInvocationEvent", "AfterNodeInvocationEvent", - "MultiAgentState", - "MultiAgentAdapter", - "MultiAgentType", + "PersistentHook", ] diff --git 
a/src/strands/experimental/multiagent_session/multiagent_events.py b/src/strands/experimental/multiagent_hooks/multiagent_events.py similarity index 53% rename from src/strands/experimental/multiagent_session/multiagent_events.py rename to src/strands/experimental/multiagent_hooks/multiagent_events.py index ebfb9e48f..988799a3f 100644 --- a/src/strands/experimental/multiagent_session/multiagent_events.py +++ b/src/strands/experimental/multiagent_hooks/multiagent_events.py @@ -1,87 +1,82 @@ """Multi-agent execution lifecycle events for hook system integration. -This module defines event classes that are triggered at key points during -multi-agent orchestrator execution, enabling hooks to respond to lifecycle -events for purposes like persistence, monitoring, and debugging. - -Event Types: -- Initialization: When orchestrator starts up -- Before/After Graph: Start/end of overall execution -- Before/After Node: Start/end of individual node execution +These events are fired by orchestrators (Graph/Swarm) at key points so +hooks can persist, monitor, or debug execution. No intermediate state model +is used—hooks read from the orchestrator directly. """ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from ...hooks.registry import HookEvent -from .multiagent_state import MultiAgentState +from ...hooks.registry import BaseHookEvent if TYPE_CHECKING: from ...multiagent.base import MultiAgentBase @dataclass -class MultiAgentInitializationEvent(HookEvent): +class MultiAgentInitializationEvent(BaseHookEvent): """Event triggered when multi-agent orchestrator initializes. 
Attributes: orchestrator: The multi-agent orchestrator instance - state: Current state of the orchestrator + invocation_state: Configuration that user pass in """ orchestrator: "MultiAgentBase" - state: MultiAgentState + invocation_state: dict[str, Any] | None = None @dataclass -class BeforeGraphInvocationEvent(HookEvent): +class BeforeMultiAgentInvocationEvent(BaseHookEvent): """Event triggered before orchestrator execution begins. Attributes: orchestrator: The multi-agent orchestrator instance - state: Current state before execution starts + invocation_state: Configuration that user pass in """ orchestrator: "MultiAgentBase" - state: MultiAgentState + invocation_state: dict[str, Any] | None = None @dataclass -class BeforeNodeInvocationEvent(HookEvent): +class BeforeNodeInvocationEvent(BaseHookEvent): """Event triggered before individual node execution. Attributes: orchestrator: The multi-agent orchestrator instance - next_node_to_execute: ID of the node about to be executed + invocation_state: Configuration that user pass in """ orchestrator: "MultiAgentBase" next_node_to_execute: str + invocation_state: dict[str, Any] | None = None @dataclass -class AfterNodeInvocationEvent(HookEvent): +class AfterNodeInvocationEvent(BaseHookEvent): """Event triggered after individual node execution completes. Attributes: orchestrator: The multi-agent orchestrator instance executed_node: ID of the node that just completed execution - state: Updated state after node execution + invocation_state: Configuration that user pass in """ orchestrator: "MultiAgentBase" executed_node: str - state: MultiAgentState + invocation_state: dict[str, Any] | None = None @dataclass -class AfterGraphInvocationEvent(HookEvent): +class AfterMultiAgentInvocationEvent(BaseHookEvent): """Event triggered after orchestrator execution completes. 
Attributes: orchestrator: The multi-agent orchestrator instance - state: Final state after execution completes + invocation_state: Configuration that user pass in """ orchestrator: "MultiAgentBase" - state: MultiAgentState + invocation_state: dict[str, Any] | None = None diff --git a/src/strands/experimental/multiagent_session/persistence_hooks.py b/src/strands/experimental/multiagent_hooks/persistence_hooks.py similarity index 54% rename from src/strands/experimental/multiagent_session/persistence_hooks.py rename to src/strands/experimental/multiagent_hooks/persistence_hooks.py index 565342aef..cbd168e65 100644 --- a/src/strands/experimental/multiagent_session/persistence_hooks.py +++ b/src/strands/experimental/multiagent_hooks/persistence_hooks.py @@ -11,52 +11,37 @@ """ import threading -from typing import Optional +from typing import TYPE_CHECKING from ...hooks.registry import HookProvider, HookRegistry -from ...multiagent.base import MultiAgentBase from ...session import SessionManager from .multiagent_events import ( - AfterGraphInvocationEvent, + AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeGraphInvocationEvent, + BeforeMultiAgentInvocationEvent, BeforeNodeInvocationEvent, MultiAgentInitializationEvent, - MultiAgentState, ) -from .multiagent_state_adapter import MultiAgentAdapter +if TYPE_CHECKING: + from ...multiagent.base import MultiAgentBase -def _get_multiagent_state( - multiagent_state: Optional[MultiAgentState], - orchestrator: MultiAgentBase, -) -> MultiAgentState: - if multiagent_state is not None: - return multiagent_state - return MultiAgentAdapter.create_multi_agent_state(orchestrator=orchestrator) - - -class MultiAgentHook(HookProvider): +class PersistentHook(HookProvider): """Hook provider for automatic multi-agent session persistence. This hook automatically persists multi-agent orchestrator state at key execution points to enable resumable execution after interruptions. 
- Args: - session_manager: SessionManager instance for state persistence - session_id: Unique identifier for the session """ - def __init__(self, session_manager: SessionManager, session_id: str): + def __init__(self, session_manager: SessionManager): """Initialize the multi-agent persistence hook. Args: session_manager: SessionManager instance for state persistence - session_id: Unique identifier for the session """ self._session_manager = session_manager - self._session_id = session_id self._lock = threading.RLock() def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: @@ -67,40 +52,40 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: **kwargs: Additional keyword arguments (unused) """ registry.add_callback(MultiAgentInitializationEvent, self._on_initialization) - registry.add_callback(BeforeGraphInvocationEvent, self._on_before_graph) + registry.add_callback(BeforeMultiAgentInvocationEvent, self._on_before_multiagent) registry.add_callback(BeforeNodeInvocationEvent, self._on_before_node) registry.add_callback(AfterNodeInvocationEvent, self._on_after_node) - registry.add_callback(AfterGraphInvocationEvent, self._on_after_graph) + registry.add_callback(AfterMultiAgentInvocationEvent, self._on_after_multiagent) - def _on_initialization(self, event: MultiAgentInitializationEvent): + # TODO: We can add **kwarg or invocation_state later if we need to persist + def _on_initialization(self, event: MultiAgentInitializationEvent) -> None: """Persist state when multi-agent orchestrator initializes.""" - self._persist(_get_multiagent_state(event.state, event.orchestrator)) + self._persist(event.orchestrator) - def _on_before_graph(self, event: BeforeGraphInvocationEvent): - """Hook called before graph execution starts.""" + def _on_before_multiagent(self, event: BeforeMultiAgentInvocationEvent) -> None: + """Persist state when multi-agent orchestrator initializes.""" pass - def _on_before_node(self, event: 
BeforeNodeInvocationEvent): + def _on_before_node(self, event: BeforeNodeInvocationEvent) -> None: """Hook called before individual node execution.""" pass - def _on_after_node(self, event: AfterNodeInvocationEvent): + def _on_after_node(self, event: AfterNodeInvocationEvent) -> None: """Persist state after each node completes execution.""" - multi_agent_state = _get_multiagent_state(multiagent_state=event.state, orchestrator=event.orchestrator) - self._persist(multi_agent_state) + self._persist(event.orchestrator) - def _on_after_graph(self, event: AfterGraphInvocationEvent): + def _on_after_multiagent(self, event: AfterMultiAgentInvocationEvent) -> None: """Persist final state after graph execution completes.""" - multiagent_state = _get_multiagent_state(multiagent_state=event.state, orchestrator=event.orchestrator) - self._persist(multiagent_state) + self._persist(event.orchestrator) - def _persist(self, multiagent_state: MultiAgentState) -> None: + def _persist(self, orchestrator: "MultiAgentBase") -> None: """Persist the provided MultiAgentState using the configured SessionManager. This method is synchronized across threads/tasks to avoid write races. Args: - multiagent_state: State to persist + orchestrator: State to persist """ + current_state = orchestrator.get_state_from_orchestrator() with self._lock: - self._session_manager.write_multi_agent_state(multiagent_state) + self._session_manager.write_multi_agent_json(current_state) diff --git a/src/strands/experimental/multiagent_session/multiagent_state.py b/src/strands/experimental/multiagent_session/multiagent_state.py deleted file mode 100644 index 60bbef6cb..000000000 --- a/src/strands/experimental/multiagent_session/multiagent_state.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Multi-agent state data structures for session persistence. - -This module defines the core data structures used to represent the state -of multi-agent orchestrators in a serializable format for session persistence. 
- -Key Components: -- MultiAgentType: Enum for orchestrator types (Graph/Swarm) -- MultiAgentState: Serializable state container with conversion methods -""" - -from dataclasses import dataclass, field -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set - -from ...types.content import ContentBlock - -if TYPE_CHECKING: - from ...multiagent.base import Status - - -# TODO: Move to Base after experimental -class MultiAgentType(Enum): - """Enumeration of supported multi-agent orchestrator types. - - Attributes: - SWARM: Collaborative agent swarm orchestrator - GRAPH: Directed graph-based agent orchestrator - """ - - SWARM = "swarm" - GRAPH = "graph" - - -@dataclass -class MultiAgentState: - """Serializable state container for multi-agent orchestrators. - - This class represents the complete execution state of a multi-agent - orchestrator (Graph or Swarm) in a format suitable for persistence - and restoration across sessions. - - Attributes: - completed_nodes: Set of node IDs that have completed execution - node_results: Dictionary mapping node IDs to their execution results - status: Current execution status of the orchestrator - next_node_to_execute: List of node IDs ready for execution - current_task: The original task being executed - execution_order: Ordered list of executed node IDs - error_message: Optional error message if execution failed - type: Type of orchestrator (Graph or Swarm) - context: Additional context data (primarily for Swarm) - """ - - # Mutual - completed_nodes: Set[str] = field(default_factory=set) - node_results: Dict[str, Any] = field(default_factory=dict) - status: "Status" = "pending" - next_node_to_execute: Optional[List[str]] = None - current_task: Optional[str | List[ContentBlock]] = None - execution_order: list[str] = field(default_factory=list) - error_message: Optional[str] = None - type: Optional[MultiAgentType] = field(default=MultiAgentType.GRAPH) - # Swarm - context: Optional[dict] = 
field(default_factory=dict) - - def to_dict(self) -> dict[str, Any]: - """Convert MultiAgentState to JSON-serializable dictionary. - - Returns: - Dictionary representation suitable for JSON serialization - """ - - def _serialize(v: Any) -> Any: - if isinstance(v, (str, int, float, bool)) or v is None: - return v - if isinstance(v, set): - return list(v) - if isinstance(v, dict): - return {str(k): _serialize(val) for k, val in v.items()} - if isinstance(v, (list, tuple)): - return [_serialize(x) for x in v] - if hasattr(v, "to_dict"): - return v.to_dict() - # last resort: stringize anything non-serializable (locks, objects, etc.) - return str(v) - - return { - "status": self.status, - "completed_nodes": list(self.completed_nodes), - "next_node_to_execute": list(self.next_node_to_execute) if self.next_node_to_execute else [], - "node_results": _serialize(self.node_results), - "current_task": self.current_task, - "error_message": self.error_message, - "execution_order": self.execution_order, - "type": self.type, - "context": _serialize(self.context), - } - - @classmethod - def from_dict(cls, data: dict): - """Create MultiAgentState from dictionary data. - - Args: - data: Dictionary containing state data - - Returns: - MultiAgentState instance - """ - data["completed_nodes"] = set(data.get("completed_nodes", [])) - return cls(**data) diff --git a/src/strands/experimental/multiagent_session/multiagent_state_adapter.py b/src/strands/experimental/multiagent_session/multiagent_state_adapter.py deleted file mode 100644 index 6d9cbb145..000000000 --- a/src/strands/experimental/multiagent_session/multiagent_state_adapter.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Multi-agent state adapter for session persistence. - -This module provides bidirectional conversion between multi-agent orchestrator -runtime state and serializable MultiAgentState objects for session persistence. 
- -Key Features: -- State serialization for Graph and Swarm orchestrators -- State restoration from persisted sessions -- Node result summarization for efficient storage -- Type-safe state conversion with error handling -""" - -import ast -import logging -from typing import Any - -from .multiagent_state import MultiAgentState, MultiAgentType - -logger = logging.getLogger(__name__) - - -class MultiAgentAdapter: - """Adapter for converting between orchestrator runtime state and persistent state. - - This class provides static methods for bidirectional conversion between - multi-agent orchestrator objects (Graph/Swarm) and serializable MultiAgentState. - """ - - @staticmethod - def apply_multi_agent_state(orchestrator: object, multi_agent_state: MultiAgentState): - """Apply persisted state to a multi-agent orchestrator. - - Args: - orchestrator: Graph or Swarm instance to restore state to - multi_agent_state: Persisted state to apply - - Raises: - ValueError: If state type is incompatible with orchestrator - """ - from ...multiagent.base import Status - from ...multiagent.graph import Graph - from ...multiagent.swarm import Swarm - - state_type = getattr(multi_agent_state, "type", None) - type_val = str(getattr(state_type, "value", state_type)) - if isinstance(orchestrator, Graph) and type_val == "graph": - graph = orchestrator - graph.state.status = Status(multi_agent_state.status) - graph.state.completed_nodes = { - graph.nodes[node_id] for node_id in multi_agent_state.completed_nodes if node_id in graph.nodes - } - graph.state.results = { - node_id: result for node_id, result in getattr(multi_agent_state, "node_results", {}).items() - } - execution_node_ids = getattr(multi_agent_state, "execution_order", []) or [] - graph.state.execution_order = [ - graph.nodes[node_id] - for node_id in execution_node_ids - if node_id in graph.nodes and graph.nodes[node_id] in graph.state.completed_nodes - ] - - graph.state.task = getattr(multi_agent_state, "current_task", "") - 
for node in graph.state.completed_nodes: - node.execution_status = Status.COMPLETED - - return - - elif isinstance(orchestrator, Swarm) and type_val == "swarm": - swarm = orchestrator - - swarm.state.completion_status = Status(multi_agent_state.status) - - swarm.state.node_history = [ - swarm.nodes[nid] for nid in (multi_agent_state.completed_nodes or []) if nid in swarm.nodes - ] - - completed_ids = {n.node_id for n in swarm.state.node_history} - saved_results = getattr(multi_agent_state, "node_results", {}) or {} - swarm.state.results = {k: v for k, v in saved_results.items() if k in completed_ids} - - # current_node - next_ids = getattr(multi_agent_state, "next_node_to_execute", []) or [] - swarm.state.current_node = swarm.nodes.get(next_ids[0]) if next_ids else None - - swarm.state.task = multi_agent_state.current_task - # hydrate context - context = getattr(multi_agent_state, "context", {}) or {} - shared_context = context.get("shared_context") or {} - swarm.shared_context.context = shared_context - swarm.state.handoff_message = context.get("handoff_message") - - else: - raise ValueError("Persisted state type incompatible with current orchestrator") - - @staticmethod - def create_multi_agent_state(orchestrator: object, msg: str = None) -> MultiAgentState | None: - """Create serializable state from multi-agent orchestrator. 
- - Args: - orchestrator: Graph or Swarm instance to extract state from - msg: Optional error message to include in state - - Returns: - MultiAgentState object ready for persistence, or None if unsupported type - """ - from ...multiagent.base import Status - from ...multiagent.graph import Graph - from ...multiagent.swarm import Swarm - - if isinstance(orchestrator, Graph): - graph = orchestrator - - serialized_results = { - node_id: MultiAgentAdapter.summarize_node_result_for_persist(node_result) - for node_id, node_result in (graph.state.results or {}).items() - } - - inflight = [n.node_id for n in graph.nodes.values() if n.execution_status == Status.EXECUTING] - next_nodes_ids = inflight or [n.node_id for n in graph._compute_ready_nodes_for_resume()] - - return MultiAgentState( - type=MultiAgentType.GRAPH, - status=graph.state.status, - completed_nodes={node.node_id for node in graph.state.completed_nodes}, - node_results=serialized_results, - next_node_to_execute=next_nodes_ids, - current_task=graph.state.task, - error_message=msg, - execution_order=[n.node_id for n in graph.state.execution_order], - context={}, - ) - - elif isinstance(orchestrator, Swarm): - swarm = orchestrator - current_executing_node = ( - [swarm.state.current_node.node_id] if swarm.state.completion_status == Status.EXECUTING else [] - ) - - serialized_results = { - node_id: MultiAgentAdapter.summarize_node_result_for_persist(res) or str(res) - for node_id, res in (swarm.state.results or {}).items() - } - - shared_ctx = {} - if hasattr(swarm, "shared_context") and swarm.shared_context is not None: - shared_ctx = getattr(swarm.shared_context, "context", {}) or {} - - return MultiAgentState( - type=MultiAgentType.SWARM, - status=swarm.state.completion_status, - completed_nodes={node.node_id for node in swarm.state.node_history}, - node_results=serialized_results, - next_node_to_execute=current_executing_node, - current_task=swarm.state.task, - error_message=msg, - 
execution_order=[node.node_id for node in swarm.state.node_history], - context={ - "shared_context": shared_ctx, - "handoff_message": getattr(swarm.state, "handoff_message", None), - }, - ) - return None - - @staticmethod - def summarize_node_result_for_persist(raw: Any) -> dict[str, Any]: - """Summarize node execution result for efficient persistence. - - Args: - raw: Raw node result (NodeResult, dict, string, or other) - - Returns: - Normalized dict with 'agent_outputs' key containing string results - """ - if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): - try: - results = raw.get_agent_results() - texts = [str(r) for r in results] - return {"agent_outputs": texts} - except Exception: - pass - - # Already a dict - if isinstance(raw, dict): - return MultiAgentAdapter._normalize_persisted_like_dict(raw) - - # String that might itself be a dict represent - if isinstance(raw, str): - try: - parsed = ast.literal_eval(raw) - if isinstance(parsed, dict): - return MultiAgentAdapter._normalize_persisted_like_dict(parsed) - except Exception as e: - logger.debug("Failed to parse persisted node result: %s", e) - return {"agent_outputs": [raw]} - - return {"agent_outputs": [str(raw)]} - - @staticmethod - def _normalize_persisted_like_dict(data: dict[str, Any]) -> dict[str, Any]: - """Normalize dictionary data to standard agent_outputs format. 
- - Args: - data: Dictionary containing result data - - Returns: - Normalized dict with 'agent_outputs' key - """ - if "agent_outputs" in data and isinstance(data["agent_outputs"], list): - return {"agent_outputs": [str(x) for x in data["agent_outputs"]]} - - if "summary" in data: - return {"agent_outputs": [str(data["summary"])]} - return {"agent_outputs": [str(data)]} diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 03d7de9b4..dcebbe5da 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -4,6 +4,7 @@ """ import asyncio +import logging from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field @@ -14,6 +15,8 @@ from ..types.content import ContentBlock from ..types.event_loop import Metrics, Usage +logger = logging.getLogger(__name__) + class Status(Enum): """Execution status for both graphs and nodes.""" @@ -34,7 +37,7 @@ class NodeResult: """ # Core result data - single AgentResult, nested MultiAgentResult, or Exception - result: Union[AgentResult, "MultiAgentResult", Exception] + result: Union["AgentResult", "MultiAgentResult", Exception] # Execution metadata execution_time: int = 0 @@ -45,19 +48,33 @@ class NodeResult: accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 - def get_agent_results(self) -> list[AgentResult]: + def get_agent_results(self) -> list["AgentResult"]: """Get all AgentResult objects from this node, flattened if nested.""" if isinstance(self.result, Exception): return [] # No agent results for exceptions elif isinstance(self.result, AgentResult): return [self.result] - else: - # Flatten nested results from MultiAgentResult - flattened = [] - for nested_node_result in self.result.results.values(): - flattened.extend(nested_node_result.get_agent_results()) + # else: + # # Flatten nested results from MultiAgentResult + # flattened = [] + # for 
nested_node_result in self.result.results.values(): + # if isinstance(nested_node_result, NodeResult): + # flattened.extend(nested_node_result.get_agent_results()) + # return flattened + + if getattr(self.result, "__class__", None) and self.result.__class__.__name__ == "AgentResult": + return [self.result] # type: ignore[list-item] + + # If this is a nested MultiAgentResult, flatten children + if hasattr(self.result, "results") and isinstance(self.result.results, dict): + flattened: list["AgentResult"] = [] + for nested in self.result.results.values(): + if isinstance(nested, NodeResult): + flattened.extend(nested.get_agent_results()) return flattened + return [] + @dataclass class MultiAgentResult: @@ -117,3 +134,66 @@ def execute() -> MultiAgentResult: with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() + + def _call_hook_safely(self, event_object: object) -> None: + """Invoke hook callbacks and swallow hook errors. + + Args: + event_object: The event to dispatch to registered callbacks. + """ + try: + self.hooks.invoke_callbacks(event_object) # type: ignore + except Exception as e: + logger.exception("Hook invocation failed for %s: %s", type(event_object).__name__, e) + + @abstractmethod + def get_state_from_orchestrator(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state.""" + raise NotImplementedError + + @abstractmethod + def apply_state_from_dict(self, payload: dict) -> None: + """Restore orchestrator state from a session dict.""" + raise NotImplementedError + + def summarize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: + """Summarize node result for efficient persistence. 
+ + Args: + raw: Raw node result to summarize + + Returns: + Normalized dict with 'agent_outputs' key containing string results + """ + + def _extract_text_from_agent_result(ar: Any) -> str: + try: + msg = getattr(ar, "message", None) + if isinstance(msg, dict): + blocks = msg.get("content") or [] + texts = [] + for b in blocks: + t = b.get("text") + if t: + texts.append(str(t)) + if texts: + return "\n".join(texts) + return str(ar) + except Exception: + return str(ar) + + # If it's a NodeResult with AgentResults, flatten + if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): + try: + ars = raw.get_agent_results() # list[AgentResult] + if ars: + return {"agent_outputs": [_extract_text_from_agent_result(a) for a in ars]} + except Exception: + pass + + # If already normalized + if isinstance(raw, dict) and isinstance(raw.get("agent_outputs"), list): + return {"agent_outputs": [str(x) for x in raw["agent_outputs"]]} + + # Fallback + return {"agent_outputs": [str(raw)]} diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 738dc4d4c..9f66c30ff 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -14,6 +14,7 @@ - Supports nested graphs (Graph as a node in another Graph) """ +import ast import asyncio import copy import logging @@ -26,6 +27,15 @@ from ..agent import Agent from ..agent.state import AgentState +from ..experimental.multiagent_hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeMultiAgentInvocationEvent, + MultiAgentInitializationEvent, +) +from ..experimental.multiagent_hooks.persistence_hooks import PersistentHook +from ..hooks import HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..types.content import ContentBlock, Messages from ..types.event_loop import Metrics, Usage @@ -93,6 +103,23 @@ def should_continue( return True, "Continuing" + @staticmethod + def _normalize_persisted_like_dict(data: 
dict[str, Any]) -> dict[str, Any]: + """Normalize dictionary data to standard agent_outputs format. + + Args: + data: Dictionary containing result data + + Returns: + Normalized dict with 'agent_outputs' key + """ + if "agent_outputs" in data and isinstance(data["agent_outputs"], list): + return {"agent_outputs": [str(x) for x in data["agent_outputs"]]} + + if "summary" in data: + return {"agent_outputs": [str(data["summary"])]} + return {"agent_outputs": [str(data)]} + @dataclass class GraphResult(MultiAgentResult): @@ -195,12 +222,6 @@ def _validate_node_executor( if id(executor) in seen_instances: raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.") - # Validate Agent-specific constraints - if isinstance(executor, Agent): - # Check for session persistence - if executor._session_manager is not None: - raise ValueError("Session persistence is not supported for Graph agents yet.") - class GraphBuilder: """Builder pattern for constructing graphs.""" @@ -217,6 +238,9 @@ def __init__(self) -> None: self._node_timeout: Optional[float] = None self._reset_on_revisit: bool = False + # session manager + self._session_manager: Optional[SessionManager] = None + def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" _validate_node_executor(executor, self.nodes) @@ -306,6 +330,15 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder": self._node_timeout = timeout return self + def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder": + """Set session manager for the graph. 
+ + Args: + session_manager: SessionManager instance + """ + self._session_manager = session_manager + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -323,6 +356,11 @@ def build(self) -> "Graph": # Validate entry points and check for cycles self._validate_graph() + # Build hook auto-inject + hooks = HookRegistry() + if self._session_manager is not None: + hooks.add_hook(PersistentHook(session_manager=self._session_manager)) + return Graph( nodes=self.nodes.copy(), edges=self.edges.copy(), @@ -331,6 +369,8 @@ def build(self) -> "Graph": execution_timeout=self._execution_timeout, node_timeout=self._node_timeout, reset_on_revisit=self._reset_on_revisit, + session_manager=self._session_manager, + hooks=hooks, ) def _validate_graph(self) -> None: @@ -346,6 +386,33 @@ def _validate_graph(self) -> None: logger.warning("Graph without execution limits may run indefinitely if cycles exist") +def _iterate_previous_outputs(raw: Any) -> list[tuple[str, str]]: + """Return a list of (agent_name, text) from NodeResult or persisted dict.""" + # Live NodeResult + if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): + try: + return [(getattr(r, "agent_name", "Agent"), str(r)) for r in raw.get_agent_results()] + except Exception: + pass + + # Convert string to dict if possible + data = raw + if isinstance(raw, str): + try: + data = ast.literal_eval(raw) + except Exception: + return [("Agent", raw)] + + # Extract from dict + if isinstance(data, dict): + if isinstance(data.get("agent_outputs"), list): + return [("Agent", str(x)) for x in data["agent_outputs"]] + if "summary" in data: + return [("Agent", str(data["summary"]))] + + return [("Agent", str(raw))] + + class Graph(MultiAgentBase): """Directed Graph multi-agent orchestration with configurable revisit behavior.""" @@ -358,6 +425,8 @@ def __init__( execution_timeout: Optional[float] = None, node_timeout: Optional[float] = None, 
reset_on_revisit: bool = False, + session_manager: Optional[SessionManager] = None, + hooks: Optional[HookRegistry] = None, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -369,6 +438,8 @@ def __init__( execution_timeout: Total execution timeout in seconds (default: None - no limit) node_timeout: Individual node timeout in seconds (default: None - no limit) reset_on_revisit: Whether to reset node state when revisited (default: False) + session_manager: Optional session manager for persistence + hooks: Optional hook registry for event handling """ super().__init__() @@ -384,6 +455,16 @@ def __init__( self.reset_on_revisit = reset_on_revisit self.state = GraphState() self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = hooks or HookRegistry() + + # Concurrency lock + self._lock = asyncio.Lock() + # Resume flag + self._resume_from_persisted = False + self._resume_next_nodes: list[GraphNode] = [] + + self._load_and_apply_persisted_state() def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -423,17 +504,30 @@ async def invoke_async( logger.debug("task=<%s> | starting graph execution", task) - # Initialize state - start_time = time.time() - self.state = GraphState( - status=Status.EXECUTING, - task=task, - total_nodes=len(self.nodes), - edges=[(edge.from_node, edge.to_node) for edge in self.edges], - entry_points=list(self.entry_points), - start_time=start_time, - ) - + if not self._resume_from_persisted and self.state.status == Status.PENDING: + # TODO: to check if we need to do something on BeforeGraphInvocationEvent + self._call_hook_safely(MultiAgentInitializationEvent(orchestrator=self)) + + self._call_hook_safely(BeforeMultiAgentInvocationEvent(orchestrator=self)) + + if not self._resume_from_persisted: + start_time = time.time() + # Initialize state + self.state = GraphState( + status=Status.EXECUTING, + task=task, + total_nodes=len(self.nodes), + 
edges=[(edge.from_node, edge.to_node) for edge in self.edges], + entry_points=sorted(self.entry_points, key=lambda node: node.node_id), + start_time=start_time, + ) + else: + if isinstance(self.state.task, (str, list)) and not self.state.task: + self.state.task = task + # Reset failed nodes after resume. + self.state.status = Status.EXECUTING + self.state.failed_nodes.clear() + self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "graph") with trace_api.use_span(span, end_on_exit=True): try: @@ -459,7 +553,10 @@ async def invoke_async( self.state.status = Status.FAILED raise finally: - self.state.execution_time = round((time.time() - start_time) * 1000) + self.state.execution_time = round((time.time() - self.state.start_time) * 1000) + self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) + self._resume_from_persisted = False + self._resume_next_nodes.clear() return self._build_result() def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: @@ -476,7 +573,11 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" - ready_nodes = list(self.entry_points) + ready_nodes = ( + sorted(self._resume_next_nodes, key=lambda n: n.node_id) + if self._resume_from_persisted + else sorted(self.entry_points, key=lambda n: n.node_id) + ) while ready_nodes: # Check execution limits before continuing @@ -506,9 +607,18 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] - for _node_id, node in self.nodes.items(): + for node in self.nodes.values(): + # Skip nodes already completed unless we’re in feedback-loop mode + if ( + node in self.state.completed_nodes or node.execution_status == 
Status.COMPLETED + ) and not self.reset_on_revisit: + continue + if node in self.state.failed_nodes: + continue if self._is_node_ready_with_conditions(node, completed_batch): - newly_ready.append(node) + # Avoid duplicates + if node not in newly_ready: + newly_ready.append(node) return newly_ready def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: @@ -537,7 +647,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() # Remove from completed nodes since we're re-executing it - self.state.completed_nodes.remove(node) + self.state.completed_nodes.discard(node) node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) @@ -547,88 +657,76 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) # Build node input from satisfied dependencies node_input = self._build_node_input(node) - # Execute with timeout protection (only if node_timeout is set) - try: - # Execute based on node type and create unified NodeResult - if isinstance(node.executor, MultiAgentBase): - if self.node_timeout is not None: - multi_agent_result = await asyncio.wait_for( - node.executor.invoke_async(node_input, invocation_state), - timeout=self.node_timeout, - ) - else: - multi_agent_result = await node.executor.invoke_async(node_input, invocation_state) - - # Create NodeResult with MultiAgentResult directly - node_result = NodeResult( - result=multi_agent_result, # type is MultiAgentResult - execution_time=multi_agent_result.execution_time, - status=Status.COMPLETED, - accumulated_usage=multi_agent_result.accumulated_usage, - accumulated_metrics=multi_agent_result.accumulated_metrics, - execution_count=multi_agent_result.execution_count, + # Execute based on node type and create unified NodeResult + if isinstance(node.executor, MultiAgentBase): + 
if self.node_timeout is not None: + multiagent_result = await asyncio.wait_for( + node.executor.invoke_async(node_input, invocation_state), + timeout=self.node_timeout, ) + else: + multiagent_result = await node.executor.invoke_async(node_input, invocation_state) + + # Create NodeResult with MultiAgentResult directly + node_result = NodeResult( + result=multiagent_result, # type is MultiAgentResult + execution_time=multiagent_result.execution_time, + status=Status.COMPLETED, + accumulated_usage=multiagent_result.accumulated_usage, + accumulated_metrics=multiagent_result.accumulated_metrics, + execution_count=multiagent_result.execution_count, + ) - elif isinstance(node.executor, Agent): - if self.node_timeout is not None: - agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), - timeout=self.node_timeout, - ) - else: - agent_response = await node.executor.invoke_async(node_input, **invocation_state) - - # Extract metrics from agent response - usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) - metrics = Metrics(latencyMs=0) - if hasattr(agent_response, "metrics") and agent_response.metrics: - if hasattr(agent_response.metrics, "accumulated_usage"): - usage = agent_response.metrics.accumulated_usage - if hasattr(agent_response.metrics, "accumulated_metrics"): - metrics = agent_response.metrics.accumulated_metrics - - node_result = NodeResult( - result=agent_response, # type is AgentResult - execution_time=round((time.time() - start_time) * 1000), - status=Status.COMPLETED, - accumulated_usage=usage, - accumulated_metrics=metrics, - execution_count=1, + elif isinstance(node.executor, Agent): + if self.node_timeout is not None: + agent_response = await asyncio.wait_for( + node.executor.invoke_async(node_input, **invocation_state), + timeout=self.node_timeout, ) else: - raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - - except asyncio.TimeoutError: - timeout_msg = 
f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" - logger.exception( - "node=<%s>, timeout=<%s>s | node execution timed out after timeout", - node.node_id, - self.node_timeout, + agent_response = await node.executor.invoke_async(node_input, **invocation_state) + + # Extract metrics from agent response + usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) + metrics = Metrics(latencyMs=0) + if hasattr(agent_response, "metrics") and agent_response.metrics: + if hasattr(agent_response.metrics, "accumulated_usage"): + usage = agent_response.metrics.accumulated_usage + if hasattr(agent_response.metrics, "accumulated_metrics"): + metrics = agent_response.metrics.accumulated_metrics + + node_result = NodeResult( + result=agent_response, # type is AgentResult + execution_time=round((time.time() - start_time) * 1000), + status=Status.COMPLETED, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=1, ) - raise Exception(timeout_msg) from None - - # Mark as completed - node.execution_status = Status.COMPLETED - node.result = node_result - node.execution_time = node_result.execution_time - self.state.completed_nodes.add(node) - self.state.results[node.node_id] = node_result - self.state.execution_order.append(node) - - # Accumulate metrics - self._accumulate_metrics(node_result) + else: + raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") + + async with self._lock: + node.execution_status = Status.COMPLETED + node.result = node_result + node.execution_time = node_result.execution_time + self.state.completed_nodes.add(node) + self.state.results[node.node_id] = node_result + self.state.execution_order.append(node) + # Accumulate metrics + self._accumulate_metrics(node_result) logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time ) + self._call_hook_safely(AfterNodeInvocationEvent(orchestrator=self, 
executed_node=node.node_id)) except Exception as e: logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) execution_time = round((time.time() - start_time) * 1000) - # Create a NodeResult for the failed node - node_result = NodeResult( - result=e, # Store exception as result + fail_result = NodeResult( + result=e, execution_time=execution_time, status=Status.FAILED, accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), @@ -636,11 +734,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) execution_count=1, ) - node.execution_status = Status.FAILED - node.result = node_result - node.execution_time = execution_time - self.state.failed_nodes.add(node) - self.state.results[node.node_id] = node_result # Store in results for consistency + async with self._lock: + node.execution_status = Status.FAILED + node.result = fail_result + node.execution_time = execution_time + self.state.failed_nodes.add(node) + self.state.results[node.node_id] = fail_result + + # Need to persist failure multiagent_state too + self._call_hook_safely(AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id)) raise @@ -688,7 +790,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: return self.state.task # Combine task with dependency outputs - node_input = [] + node_input: list[ContentBlock] = [] # Add original task if isinstance(self.state.task, str): @@ -701,15 +803,10 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: # Add dependency outputs node_input.append(ContentBlock(text="\nInputs from previous nodes:")) - for dep_id, node_result in dependency_results.items(): + for dep_id, prev_result in dependency_results.items(): node_input.append(ContentBlock(text=f"\nFrom {dep_id}:")) - # Get all agent results from this node (flattened if nested) - agent_results = node_result.get_agent_results() - for result in agent_results: - agent_name = getattr(result, "agent_name", "Agent") 
- result_text = str(result) + for agent_name, result_text in _iterate_previous_outputs(prev_result): node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}")) - return node_input def _build_result(self) -> GraphResult: @@ -728,3 +825,118 @@ def _build_result(self) -> GraphResult: edges=self.state.edges, entry_points=self.state.entry_points, ) + + # Persistence Helper Functions + + def _from_dict(self, payload: dict[str, Any]) -> None: + status_raw = payload.get("status", "pending") + try: + self.state.status = Status(status_raw) + except Exception: + self.state.status = Status.PENDING + + # Hydrate completed nodes & results + completed_node_ids = payload.get("completed_nodes") or [] + self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} + + self.state.results = dict(payload.get("node_results") or {}) + # Execution order (only nodes that still exist) + order_node_ids = payload.get("execution_order") or [] + self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] + + # Task + self.state.task = payload.get("current_task", self.state.task) + + def _to_dict(self) -> dict[str, Any]: + status_str = getattr(self.state.status, "value", str(self.state.status)) + next_nodes = [n.node_id for n in self._compute_ready_nodes_for_resume()] + return { + "type": "graph", + "status": status_str, + "completed_nodes": [n.node_id for n in self.state.completed_nodes], + "node_results": { + k: self.summarize_node_result_for_persist(v) # CHANGED: normalized text-only outputs + for k, v in (self.state.results or {}).items() + }, + "next_node_to_execute": next_nodes, + "current_task": self.state.task, + "execution_order": [n.node_id for n in self.state.execution_order], + } + + def _load_and_apply_persisted_state(self) -> None: + if self.session_manager is None: + return + try: + json_data = self.session_manager.read_multi_agent_json() + except Exception as e: + 
logger.warning("Skipping resume; failed to load state: %s", e) + return + if not json_data: + return + + try: + self.apply_state_from_dict(json_data) + self._resume_from_persisted = True + + next_node_ids = json_data.get("next_node_to_execute") or [] + mapped = self._map_node_ids(next_node_ids) + valid_ready: list[GraphNode] = [] + completed = set(self.state.completed_nodes) + + for node in mapped: + if node in completed or node.execution_status == Status.COMPLETED: + continue + # only include if it’s dependency-ready + incoming = [edge for edge in self.edges if edge.to_node == node] + if any(edge.from_node in completed and edge.should_traverse(self.state) for edge in incoming): + valid_ready.append(node) + + if not valid_ready: + valid_ready = self._compute_ready_nodes_for_resume() + + self._resume_next_nodes = sorted(valid_ready, key=lambda node: node.node_id) + logger.debug("Resumed from persisted state. Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) + except Exception as e: + logger.exception("Failed to apply multiagent state : %s", e) + + def _map_node_ids(self, node_ids: list[str] | None) -> list[GraphNode]: + if not node_ids: + return [] + mapped_nodes = [] + for node_id in node_ids: + node = self.nodes.get(node_id) + if node: + mapped_nodes.append(node) + return mapped_nodes + + def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: + ready_nodes: list[GraphNode] = [] + completed_nodes = set(self.state.completed_nodes) + + for node in self.nodes.values(): + if node in completed_nodes: + continue + incoming = [e for e in self.edges if e.to_node is node] + if any(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): + ready_nodes.append(node) + + if ready_nodes: + return ready_nodes + + return [node for node in self.entry_points if node not in completed_nodes] + + def get_state_from_orchestrator(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. 
+ + Returns: + Dictionary containing the current graph state + """ + return self._to_dict() + + def apply_state_from_dict(self, payload: dict) -> None: + """Restore orchestrator state from a session dict. + + Args: + payload: Dictionary containing persisted state data + """ + self._from_dict(payload) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 620fa5e24..c24449e6b 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -25,6 +25,16 @@ from ..agent import Agent, AgentResult from ..agent.state import AgentState +from ..experimental.multiagent_hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeInvocationEvent, + MultiAgentInitializationEvent, +) +from ..experimental.multiagent_hooks.persistence_hooks import PersistentHook +from ..hooks import HookRegistry +from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool from ..types.content import ContentBlock, Messages @@ -203,6 +213,8 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, + session_manager: SessionManager | None = None, + hooks: HookRegistry | None = None, ) -> None: """Initialize Swarm with agents and configuration. 
@@ -217,6 +229,8 @@ def __init__( Disabled by default (default: 0) repetitive_handoff_min_unique_agents: Minimum unique agents required in recent sequence Disabled by default (default: 0) + session_manager: Optional session manager for persistence + hooks: Optional hook registry for event handling """ super().__init__() @@ -237,9 +251,20 @@ def __init__( ) self.tracer = get_tracer() + self.session_manager = session_manager + self.hooks = hooks or HookRegistry() + if self.session_manager is not None: + self.hooks.add_hook(PersistentHook(session_manager=self.session_manager)) + self._setup_swarm(nodes) self._inject_swarm_tools() + # We need flags here from Graph since they have different mechanism to determine end of loop. + self._resume_from_persisted = False + self._resume_from_completed = False + + self._load_and_apply_persisted_state() + def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any ) -> SwarmResult: @@ -278,18 +303,31 @@ async def invoke_async( logger.debug("starting swarm execution") - # Initialize swarm state with configuration - if self.entry_point: - initial_node = self.nodes[str(self.entry_point.name)] - else: - initial_node = next(iter(self.nodes.values())) # First SwarmNode + if self._resume_from_persisted and self._resume_from_completed: + logger.debug("Returning persisted COMPLETED result without re-execution.") + return self._build_result() - self.state = SwarmState( - current_node=initial_node, - task=task, - completion_status=Status.EXECUTING, - shared_context=self.shared_context, - ) + if not self._resume_from_persisted and self.state.completion_status == Status.PENDING: + self._call_hook_safely(MultiAgentInitializationEvent(orchestrator=self)) + + # Customizable Before GraphInvocation. 
+ self._call_hook_safely(BeforeMultiAgentInvocationEvent(orchestrator=self)) + + # Fresh run: build the initial state; on resume, reuse the restored state + if not self._resume_from_persisted: + initial_node = self._initial_node() + + self.state = SwarmState( + current_node=initial_node, + task=task, + completion_status=Status.EXECUTING, + shared_context=self.shared_context, + ) + else: + if isinstance(self.state.task, (str, list)) and not self.state.task: + self.state.task = task + self.state.completion_status = Status.EXECUTING + self.state.start_time = time.time() start_time = time.time() span = self.tracer.start_multiagent_span(task, "swarm") @@ -310,6 +348,9 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) + self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) + self._resume_from_persisted = False + self._resume_from_completed = False return self._build_result() @@ -458,6 +499,12 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st previous_agent.node_id, target_node.node_id, ) + # Persist the handoff message in case we lose it. + if self.session_manager is not None: + try: + self.session_manager.write_multi_agent_json(self.get_state_from_orchestrator()) + except Exception as e: + logger.warning("Failed to persist swarm state after handoff: %s", e) def _build_node_input(self, target_node: SwarmNode) -> str: """Build input text for a node based on shared context and handoffs.
@@ -562,6 +609,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: self.state.completion_status = Status.FAILED break + # Notify hooks before the node executes + self._call_hook_safely(BeforeNodeInvocationEvent(self, next_node_to_execute=current_node.node_id)) + logger.debug( "current_node=<%s>, iteration=<%d> | executing node", current_node.node_id, @@ -580,6 +630,8 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) + self._call_hook_safely(AfterNodeInvocationEvent(self, executed_node=current_node.node_id)) + # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete if self.state.current_node == current_node: @@ -606,6 +658,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: self.state.completion_status = Status.FAILED elapsed_time = time.time() - self.state.start_time + self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) logger.debug("status=<%s> | swarm execution completed", self.state.completion_status) logger.debug( "node_history_length=<%d>, time=<%s>s | metrics", @@ -625,15 +678,11 @@ async def _execute_node( context_text = self._build_node_input(node) node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] - # Clear handoff message after it's been included in context - self.state.handoff_message = None - if not isinstance(task, str): # Include additional ContentBlocks in node input node_input = node_input + task # Execute node - result = None node.reset_executor_state() # Unpacking since this is the agent class.
Other executors should not unpack result = await node.executor.invoke_async(node_input, **invocation_state) @@ -664,6 +713,9 @@ async def _execute_node( # Accumulate metrics self._accumulate_metrics(node_result) + # Clear handoff message after it's been included in context + self.state.handoff_message = None + return result except Exception as e: @@ -683,6 +735,9 @@ async def _execute_node( # Store result in state self.state.results[node_name] = node_result + # Persist failure here + self._call_hook_safely(AfterNodeInvocationEvent(self, executed_node=node_name)) + raise def _accumulate_metrics(self, node_result: NodeResult) -> None: @@ -703,3 +758,112 @@ def _build_result(self) -> SwarmResult: execution_time=self.state.execution_time, node_history=self.state.node_history, ) + + # Persistence Helper function + + def _initial_node(self) -> SwarmNode: + if self.entry_point: + return self.nodes[str(self.entry_point.name)] + + return next(iter(self.nodes.values())) # First SwarmNode + + def _load_and_apply_persisted_state(self) -> None: + if self.session_manager is None: + return + try: + saved = self.session_manager.read_multi_agent_json() + except Exception as e: + logger.warning("Skipping resume; failed to load state: %s", e) + return + if not saved: + return + + try: + # saved is a dict + status = saved.get("status") + self.apply_state_from_dict(saved) + + self._resume_from_persisted = True + if status in (Status.COMPLETED.value, "completed"): + self._resume_from_completed = True + # Create a placeholder node to avoid None assignment + placeholder_node = SwarmNode("placeholder", Agent()) + self.state.current_node = placeholder_node + logger.debug("Saved state is COMPLETED; will return persisted result without re-running.") + else: + logger.debug( + "Resumed from persisted state. 
Current node: %s", + self.state.current_node.node_id if self.state.current_node else "None", + ) + except Exception as e: + logger.exception("Failed to hydrate swarm from persisted state: %s", e) + + def get_state_from_orchestrator(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. + + Returns: + Dictionary containing the current swarm state + """ + status_str = getattr(self.state.completion_status, "value", str(self.state.completion_status)) + next_nodes = ( + [self.state.current_node.node_id] + if self.state.completion_status == Status.EXECUTING and self.state.current_node + else [] + ) + return { + "type": "swarm", + "status": status_str, + "completed_nodes": [n.node_id for n in self.state.node_history], + "node_results": { + k: self.summarize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() + }, + "next_node_to_execute": next_nodes, + "current_task": self.state.task, + "execution_order": [n.node_id for n in self.state.node_history], + "context": { + "shared_context": getattr(self.state.shared_context, "context", {}) or {}, + "handoff_message": self.state.handoff_message, + }, + } + + def apply_state_from_dict(self, payload: dict) -> None: + """Restore orchestrator state from a session dict. 
+ + Args: + payload: Dictionary containing persisted state data + """ + try: + status_raw = payload.get("status", "pending") + try: + self.state.completion_status = Status(status_raw) + except Exception: + self.state.completion_status = Status.PENDING + + ctx = payload.get("context") or {} + self.shared_context.context = ctx.get("shared_context") or {} + self.state.handoff_message = ctx.get("handoff_message") + + # node history and results + self.state.node_history = [ + self.nodes[nid] for nid in (payload.get("completed_nodes") or []) if nid in self.nodes + ] + self.state.results = dict(payload.get("node_results") or {}) + self.state.task = payload.get("current_task", self.state.task) + + # Determine current node (if executing) + next_ids = list(payload.get("next_node_to_execute") or []) + if next_ids: + nid = next_ids[0] + found_node = self.nodes.get(nid) + self.state.current_node = found_node if found_node is not None else self._initial_node() + else: + # fallback to last executed or first node + last = (payload.get("execution_order") or [])[-1:] or [] + if last: + found_node = self.nodes.get(last[0]) + self.state.current_node = found_node if found_node is not None else self._initial_node() + else: + self.state.current_node = self._initial_node() + + except Exception as e: + logger.exception("Failed to apply persisted swarm state: %s", e) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 93adeb7f2..b73e3cd6f 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -6,11 +6,12 @@ import os import shutil import tempfile +from datetime import datetime, timezone from typing import Any, Optional, cast from .. 
import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository @@ -38,19 +39,27 @@ class FileSessionManager(RepositorySessionManager, SessionRepository): ``` """ - def __init__(self, session_id: str, storage_dir: Optional[str] = None, **kwargs: Any): + def __init__( + self, + session_id: str, + storage_dir: Optional[str] = None, + *, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize FileSession with filesystem storage. Args: session_id: ID for the session. ID is not allowed to contain path separators (e.g., a/b). storage_dir: Directory for local filesystem storage (defaults to temp dir). + session_type: single agent or multiagent. **kwargs: Additional keyword arguments for future extensibility. """ self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions") os.makedirs(self.storage_dir, exist_ok=True) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_repository=self, session_type=session_type) def _get_session_path(self, session_id: str) -> str: """Get session directory path. 
@@ -108,8 +117,12 @@ def _read_file(self, path: str) -> dict[str, Any]: def _write_file(self, path: str, data: dict[str, Any]) -> None: """Write JSON file.""" os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as f: + tmp = f"{path}.tmp" + with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" @@ -118,8 +131,9 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: raise SessionException(f"Session {session.session_id} already exists") # Create directory structure - os.makedirs(session_dir, exist_ok=True) - os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) + if self.session_type == SessionType.AGENT: + os.makedirs(session_dir, exist_ok=True) + os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) # Write session file session_file = os.path.join(session_dir, "session.json") @@ -249,3 +263,33 @@ async def load_message(filename: str) -> SessionMessage: messages = await asyncio.gather(*tasks) return messages + + def write_multi_agent_json(self, state: dict[str, Any], **kwargs: Any) -> None: + """Write multi-agent state to filesystem. 
+ + Args: + state: Multi-agent state dictionary to persist + **kwargs: Additional keyword arguments for future extensibility + """ + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + self._write_file(state_path, state) + + # Refresh the session's updated_at metadata + session_dir = self._get_session_path(self.session.session_id) + session_file = os.path.join(session_dir, "session.json") + with open(session_file, "r", encoding="utf-8") as f: + metadata = json.load(f) + metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + self._write_file(session_file, metadata) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from filesystem. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") + if not os.path.exists(state_path): + return {} + state_data = self._read_file(state_path) + return state_data diff --git a/src/strands/session/repository_session_manager.py b/src/strands/session/repository_session_manager.py index 75058b251..19cd3b39e 100644 --- a/src/strands/session/repository_session_manager.py +++ b/src/strands/session/repository_session_manager.py @@ -24,7 +24,13 @@ class RepositorySessionManager(SessionManager): """Session manager for persisting agents in a SessionRepository.""" - def __init__(self, session_id: str, session_repository: SessionRepository, **kwargs: Any): + def __init__( + self, + session_id: str, + session_repository: SessionRepository, + session_type: SessionType = SessionType.AGENT, + **kwargs: Any, + ): """Initialize the RepositorySessionManager. If no session with the specified session_id exists yet, it will be created @@ -34,22 +40,27 @@ def __init__(self, session_id: str, session_repository: SessionRepository, **kwa session_id: ID to use for the session.
A new session with this id will be created if it does not exist in the repository yet session_repository: Underlying session repository to use to store the sessions state. + session_type: Type of session (AGENT or MULTI_AGENT) **kwargs: Additional keyword arguments for future extensibility. """ + super().__init__(session_type=session_type) + self.session_repository = session_repository self.session_id = session_id session = session_repository.read_session(session_id) # Create a session if it does not exist yet if session is None: logger.debug("session_id=<%s> | session not found, creating new session", self.session_id) - session = Session(session_id=session_id, session_type=SessionType.AGENT) + session = Session(session_id=session_id, session_type=session_type) session_repository.create_session(session) self.session = session + self.session_type = session.session_type - # Keep track of the latest message of each agent in case we need to redact it. - self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} + if self.session_type == SessionType.AGENT: + # Keep track of the latest message of each agent in case we need to redact it. + self._latest_agent_message: dict[str, Optional[SessionMessage]] = {} def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None: """Append a message to the agent's session. diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 1f6ffe7f1..915769da0 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -3,6 +3,7 @@ import asyncio import json import logging +from datetime import datetime, timezone from typing import Any, Dict, List, Optional, cast import boto3 @@ -11,7 +12,7 @@ from .. 
import _identifier from ..types.exceptions import SessionException -from ..types.session import Session, SessionAgent, SessionMessage +from ..types.session import Session, SessionAgent, SessionMessage, SessionType from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository @@ -47,6 +48,7 @@ def __init__( boto_session: Optional[boto3.Session] = None, boto_client_config: Optional[BotocoreConfig] = None, region_name: Optional[str] = None, + session_type: SessionType = SessionType.AGENT, **kwargs: Any, ): """Initialize S3SessionManager with S3 storage. @@ -59,6 +61,7 @@ def __init__( boto_session: Optional boto3 session boto_client_config: Optional boto3 client configuration region_name: AWS region for S3 storage + session_type: Type of session (AGENT or MULTI_AGENT) **kwargs: Additional keyword arguments for future extensibility. """ self.bucket = bucket @@ -79,7 +82,12 @@ def __init__( client_config = BotocoreConfig(user_agent_extra="strands-agents") self.client = session.client(service_name="s3", config=client_config) - super().__init__(session_id=session_id, session_repository=self) + super().__init__(session_id=session_id, session_type=session_type, session_repository=self) + + # This avoids leading // or / when self.prefix ="" + def _join_key(self, *parts: str) -> str: + cleaned = [part.strip("/ ") for part in parts if part and part.strip("/ ")] + return "/".join(cleaned) def _get_session_path(self, session_id: str) -> str: """Get session S3 prefix. @@ -91,7 +99,7 @@ def _get_session_path(self, session_id: str) -> str: ValueError: If session id contains a path separator. """ session_id = _identifier.validate(session_id, _identifier.Identifier.SESSION) - return f"{self.prefix}/{SESSION_PREFIX}{session_id}/" + return self._join_key(self.prefix, f"{SESSION_PREFIX}{session_id}") def _get_agent_path(self, session_id: str, agent_id: str) -> str: """Get agent S3 prefix. 
@@ -304,3 +312,33 @@ async def load_message(key: str) -> Optional[SessionMessage]: loaded_messages = await asyncio.gather(*tasks) return [msg for msg in loaded_messages if msg is not None] + + def write_multi_agent_json(self, state: dict[str, Any]) -> None: + """Write multi-agent state to S3. + + Args: + state: Multi-agent state dictionary to persist + """ + session_prefix = self._get_session_path(self.session_id) + state_key = self._join_key(session_prefix, "multi_agent_state.json") + self._write_s3_object(state_key, state) + + # Touch updated_at on session.json (best-effort) + session_key = self._join_key(session_prefix, "session.json") + try: + metadata = self._read_s3_object(session_key) or {} + metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + self._write_s3_object(session_key, metadata) + except SessionException: + # If session.json is missing or unreadable, don't fail persistence + logger.warning("Could not update session.json updated_at for session %s", self.session_id) + + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from S3. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + session_prefix = self._get_session_path(self.session_id) + state_key = self._join_key(session_prefix, "multi_agent_state.json") + return self._read_s3_object(state_key) or {} diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 66a07ea43..b14f2d467 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -6,6 +6,7 @@ from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message +from ..types.session import SessionType if TYPE_CHECKING: from ..agent.agent import Agent @@ -20,19 +21,28 @@ class SessionManager(HookProvider, ABC): for an agent, and should be persisted in the session. 
""" + def __init__(self, session_type: SessionType = SessionType.AGENT) -> None: + """Initialize SessionManager with session type. + + Args: + session_type: Type of session (AGENT or MULTI_AGENT) + """ + self.session_type: SessionType = session_type + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for persisting the agent to the session.""" - # After the normal Agent initialization behavior, call the session initialize function to restore the agent - registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) + if self.session_type == SessionType.AGENT: + # After the normal Agent initialization behavior, call the session initialize function to restore the agent + registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) - # For each message appended to the Agents messages, store that message in the session - registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) + # For each message appended to the Agents messages, store that message in the session + registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent)) - # Sync the agent into the session for each message in case the agent state was updated - registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) + # Sync the agent into the session for each message in case the agent state was updated + registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent)) - # After an agent was invoked, sync it with the session to capture any conversation manager state updates - registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) + # After an agent was invoked, sync it with the session to capture any conversation manager state updates + registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) @abstractmethod def 
redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: @@ -71,3 +81,22 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: agent: Agent to initialize **kwargs: Additional keyword arguments for future extensibility. """ + + # Multiagent abstract functions + @abstractmethod + def write_multi_agent_json(self, state: dict[str, Any]) -> None: + """Write multi-agent state to persistent storage. + + Args: + state: Multi-agent state dictionary to persist + """ + raise NotImplementedError + + @abstractmethod + def read_multi_agent_json(self) -> dict[str, Any]: + """Read multi-agent state from persistent storage. + + Returns: + Multi-agent state dictionary or empty dict if not found + """ + raise NotImplementedError diff --git a/src/strands/types/session.py b/src/strands/types/session.py index e51816f74..ffd6ea1f7 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -21,6 +21,7 @@ class SessionType(str, Enum): """ AGENT = "AGENT" + MULTI_AGENT = "MULTI_AGENT" def encode_bytes_values(obj: Any) -> Any: diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index f3923f68b..4d4c93c8e 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -1,3 +1,4 @@ +from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository from strands.types.exceptions import SessionException from strands.types.session import SessionAgent, SessionMessage @@ -95,3 +96,15 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[Sess if limit is not None: return sorted_messages[offset : offset + limit] return sorted_messages[offset:] + + +class TestRepositorySessionManager(RepositorySessionManager): + """Test implementation of RepositorySessionManager with concrete multi-agent methods.""" + + def write_multi_agent_json(self, state: 
dict) -> None: + """Write multi-agent state (no-op for testing).""" + pass + + def read_multi_agent_json(self) -> dict: + """Read multi-agent state (returns empty dict for testing).""" + return {} diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 2cd87c26d..0e6d6dbcd 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -17,13 +17,12 @@ from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel -from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType -from tests.fixtures.mock_session_repository import MockedSessionRepository +from tests.fixtures.mock_session_repository import MockedSessionRepository, TestRepositorySessionManager from tests.fixtures.mocked_model_provider import MockedModelProvider # For unit testing we will use the the us inference @@ -1529,7 +1528,7 @@ def test_agent_state_get_breaks_deep_dict_reference(): def test_agent_session_management(): mock_session_repository = MockedSessionRepository() - session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) agent = Agent(session_manager=session_manager, model=model) agent("Hello!") @@ -1546,7 +1545,7 @@ def test_agent_restored_from_session_management(): 
conversation_manager_state=SlidingWindowConversationManager().get_state(), ), ) - session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager) @@ -1567,7 +1566,7 @@ def test_agent_restored_from_session_management_with_message(): mock_session_repository.create_message( "123", "default", SessionMessage({"role": "user", "content": [{"text": "Hello!"}]}, 0) ) - session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager) @@ -1598,7 +1597,9 @@ def test_agent_restored_from_session_management_with_redacted_input(): test_session_id = str(uuid4()) mocked_session_repository = MockedSessionRepository() - session_manager = RepositorySessionManager(session_id=test_session_id, session_repository=mocked_session_repository) + session_manager = TestRepositorySessionManager( + session_id=test_session_id, session_repository=mocked_session_repository + ) agent = Agent( model=mocked_model, @@ -1618,7 +1619,7 @@ def test_agent_restored_from_session_management_with_redacted_input(): assert user_input_session_message.to_message() == agent.messages[0] # Restore an agent from the session, confirm input is still redacted - session_manager_2 = RepositorySessionManager( + session_manager_2 = TestRepositorySessionManager( session_id=test_session_id, session_repository=mocked_session_repository ) agent_2 = Agent( @@ -1637,13 +1638,13 @@ def test_agent_restored_from_session_management_with_correct_index(): [{"role": "assistant", "content": [{"text": "hello!"}]}, {"role": "assistant", "content": [{"text": "world!"}]}] ) mock_session_repository = MockedSessionRepository() - session_manager = 
RepositorySessionManager(session_id="test", session_repository=mock_session_repository) + session_manager = TestRepositorySessionManager(session_id="test", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager, model=mock_model_provider) agent("Hello!") assert len(mock_session_repository.list_messages("test", agent.agent_id)) == 2 - session_manager_2 = RepositorySessionManager(session_id="test", session_repository=mock_session_repository) + session_manager_2 = TestRepositorySessionManager(session_id="test", session_repository=mock_session_repository) agent_2 = Agent(session_manager=session_manager_2, model=mock_model_provider) assert len(agent_2.messages) == 2 @@ -1661,7 +1662,7 @@ def test_agent_restored_from_session_management_with_correct_index(): def test_agent_with_session_and_conversation_manager(): mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) mock_session_repository = MockedSessionRepository() - session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) conversation_manager = SlidingWindowConversationManager(window_size=1) # Create an agent with a mocked model and session repository agent = Agent( @@ -1683,7 +1684,7 @@ def test_agent_with_session_and_conversation_manager(): assert len(agent.messages) == 1 # Initialize another agent using the same session - session_manager_2 = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager_2 = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) conversation_manager_2 = SlidingWindowConversationManager(window_size=1) agent_2 = Agent( session_manager=session_manager_2, diff --git a/tests/strands/experimental/multiagent_hooks/__init__.py b/tests/strands/experimental/multiagent_hooks/__init__.py new 
file mode 100644 index 000000000..e69de29bb diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py new file mode 100644 index 000000000..f983c0602 --- /dev/null +++ b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py @@ -0,0 +1,122 @@ +"""Tests for multi-agent execution lifecycle events.""" + +from unittest.mock import Mock + +import pytest + +from strands.experimental.multiagent_hooks.multiagent_events import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeInvocationEvent, + MultiAgentInitializationEvent, +) +from strands.hooks.registry import BaseHookEvent + + +@pytest.fixture +def orchestrator(): + """Mock orchestrator for testing.""" + return Mock() + + +def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): + """Test MultiAgentInitializationEvent creation with orchestrator only.""" + event = MultiAgentInitializationEvent(orchestrator=orchestrator) + + assert event.orchestrator is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_multi_agent_initialization_event_with_invocation_state(orchestrator): + """Test MultiAgentInitializationEvent creation with invocation state.""" + invocation_state = {"key": "value"} + event = MultiAgentInitializationEvent(orchestrator=orchestrator, invocation_state=invocation_state) + + assert event.orchestrator is orchestrator + assert event.invocation_state == invocation_state + + +def test_before_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test BeforeMultiAgentInvocationEvent creation with orchestrator only.""" + event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator) + + assert event.orchestrator is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def 
test_before_multi_agent_invocation_event_with_invocation_state(orchestrator): + """Test BeforeMultiAgentInvocationEvent creation with invocation state.""" + invocation_state = {"config": "test"} + event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator, invocation_state=invocation_state) + + assert event.orchestrator is orchestrator + assert event.invocation_state == invocation_state + + +def test_before_node_invocation_event_with_required_fields(orchestrator): + """Test BeforeNodeInvocationEvent creation with required fields.""" + next_node = "node_1" + event = BeforeNodeInvocationEvent(orchestrator=orchestrator, next_node_to_execute=next_node) + + assert event.orchestrator is orchestrator + assert event.next_node_to_execute == next_node + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_before_node_invocation_event_with_invocation_state(orchestrator): + """Test BeforeNodeInvocationEvent creation with invocation state.""" + next_node = "node_2" + invocation_state = {"step": 1} + event = BeforeNodeInvocationEvent( + orchestrator=orchestrator, next_node_to_execute=next_node, invocation_state=invocation_state + ) + + assert event.orchestrator is orchestrator + assert event.next_node_to_execute == next_node + assert event.invocation_state == invocation_state + + +def test_after_node_invocation_event_with_required_fields(orchestrator): + """Test AfterNodeInvocationEvent creation with required fields.""" + executed_node = "node_1" + event = AfterNodeInvocationEvent(orchestrator=orchestrator, executed_node=executed_node) + + assert event.orchestrator is orchestrator + assert event.executed_node == executed_node + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_node_invocation_event_with_invocation_state(orchestrator): + """Test AfterNodeInvocationEvent creation with invocation state.""" + executed_node = "node_2" + invocation_state = {"result": "success"} + event = 
AfterNodeInvocationEvent( + orchestrator=orchestrator, executed_node=executed_node, invocation_state=invocation_state + ) + + assert event.orchestrator is orchestrator + assert event.executed_node == executed_node + assert event.invocation_state == invocation_state + + +def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" + event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator) + + assert event.orchestrator is orchestrator + assert event.invocation_state is None + assert isinstance(event, BaseHookEvent) + + +def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator): + """Test AfterMultiAgentInvocationEvent creation with invocation state.""" + invocation_state = {"final_state": "completed"} + event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator, invocation_state=invocation_state) + + assert event.orchestrator is orchestrator + assert event.invocation_state == invocation_state diff --git a/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py b/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py new file mode 100644 index 000000000..30ca9d214 --- /dev/null +++ b/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py @@ -0,0 +1,131 @@ +"""Tests for multi-agent session persistence hook implementation.""" + +from unittest.mock import Mock, call + +import pytest + +from strands.experimental.multiagent_hooks.multiagent_events import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeInvocationEvent, + MultiAgentInitializationEvent, +) +from strands.experimental.multiagent_hooks.persistence_hooks import PersistentHook +from strands.hooks.registry import HookRegistry + + +@pytest.fixture +def session_manager(): + """Mock session manager.""" + return Mock() + + +@pytest.fixture +def orchestrator(): + """Mock 
orchestrator.""" + mock_orchestrator = Mock() + mock_orchestrator.get_state_from_orchestrator.return_value = {"state": "test"} + return mock_orchestrator + + +@pytest.fixture +def hook(session_manager): + """PersistentHook instance.""" + return PersistentHook(session_manager) + + +def test_initialization(session_manager): + """Test hook initialization.""" + hook = PersistentHook(session_manager) + + assert hook._session_manager is session_manager + assert hasattr(hook, "_lock") + + +def test_register_hooks(hook): + """Test hook registration with registry.""" + registry = Mock(spec=HookRegistry) + + hook.register_hooks(registry) + + expected_calls = [ + call(MultiAgentInitializationEvent, hook._on_initialization), + call(BeforeMultiAgentInvocationEvent, hook._on_before_multiagent), + call(BeforeNodeInvocationEvent, hook._on_before_node), + call(AfterNodeInvocationEvent, hook._on_after_node), + call(AfterMultiAgentInvocationEvent, hook._on_after_multiagent), + ] + registry.add_callback.assert_has_calls(expected_calls) + + +def test_on_initialization_persists_state(hook, orchestrator): + """Test initialization event triggers persistence.""" + event = MultiAgentInitializationEvent(orchestrator=orchestrator) + + hook._on_initialization(event) + + orchestrator.get_state_from_orchestrator.assert_called_once() + hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) + + +def test_on_before_multiagent_does_nothing(hook, orchestrator): + """Test before multiagent event does nothing.""" + event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator) + + hook._on_before_multiagent(event) + + orchestrator.get_state_from_orchestrator.assert_not_called() + hook._session_manager.write_multi_agent_json.assert_not_called() + + +def test_on_before_node_does_nothing(hook, orchestrator): + """Test before node event does nothing.""" + event = BeforeNodeInvocationEvent(orchestrator=orchestrator, next_node_to_execute="node_1") + + 
hook._on_before_node(event) + + orchestrator.get_state_from_orchestrator.assert_not_called() + hook._session_manager.write_multi_agent_json.assert_not_called() + + +def test_on_after_node_persists_state(hook, orchestrator): + """Test after node event triggers persistence.""" + event = AfterNodeInvocationEvent(orchestrator=orchestrator, executed_node="node_1") + + hook._on_after_node(event) + + orchestrator.get_state_from_orchestrator.assert_called_once() + hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) + + +def test_on_after_multiagent_persists_state(hook, orchestrator): + """Test after multiagent event triggers persistence.""" + event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator) + + hook._on_after_multiagent(event) + + orchestrator.get_state_from_orchestrator.assert_called_once() + hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) + + +def test_persist_thread_safety(hook, orchestrator): + """Test that persistence operations are thread-safe.""" + hook._lock = Mock() + hook._lock.__enter__ = Mock(return_value=hook._lock) + hook._lock.__exit__ = Mock(return_value=None) + + hook._persist(orchestrator) + + hook._lock.__enter__.assert_called_once() + hook._lock.__exit__.assert_called_once() + orchestrator.get_state_from_orchestrator.assert_called_once() + hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) + + +def test_persist_gets_state_and_writes(hook, orchestrator): + """Test persist method gets state and writes to session manager.""" + hook._persist(orchestrator) + + orchestrator.get_state_from_orchestrator.assert_called_once() + hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index d21aa6e14..a5a4a10aa 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -28,6 
+28,10 @@ def test_node_result_initialization_and_properties(agent_result): assert node_result.accumulated_metrics == {"latencyMs": 0.0} assert node_result.execution_count == 0 + # Test default status + default_node = NodeResult(result=agent_result) + assert default_node.status == Status.PENDING + # With custom metrics custom_usage = {"inputTokens": 100, "outputTokens": 200, "totalTokens": 300} custom_metrics = {"latencyMs": 250.0} @@ -85,6 +89,16 @@ def test_node_result_get_agent_results(agent_result): assert "Response 1" in response_texts assert "Response 2" in response_texts + # Test with object that has AgentResult class name but isn't AgentResult + from unittest.mock import Mock + + mock_result = Mock() + mock_result.__class__.__name__ = "AgentResult" + node_result = NodeResult(result=mock_result) + agent_results = node_result.get_agent_results() + assert len(agent_results) == 1 + assert agent_results[0] == mock_result + def test_multi_agent_result_initialization(agent_result): """Test MultiAgentResult initialization with defaults and custom values.""" @@ -95,6 +109,7 @@ def test_multi_agent_result_initialization(agent_result): assert result.accumulated_metrics == {"latencyMs": 0.0} assert result.execution_count == 0 assert result.execution_time == 0 + assert result.status == Status.PENDING # Custom values`` node_result = NodeResult(result=agent_result) @@ -141,6 +156,12 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) + def get_state_from_orchestrator(self) -> dict: + return {} + + def apply_state_from_dict(self, payload: dict) -> None: + pass + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) @@ -163,6 +184,12 @@ async def invoke_async(self, task, invocation_state, **kwargs): status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} 
) + def get_state_from_orchestrator(self) -> dict: + return {} + + def apply_state_from_dict(self, payload: dict) -> None: + pass + agent = TestMultiAgent() # Test with string task @@ -173,3 +200,42 @@ async def invoke_async(self, task, invocation_state, **kwargs): assert agent.received_kwargs == {"param1": "value1", "param2": "value2"} assert isinstance(result, MultiAgentResult) assert result.status == Status.COMPLETED + + +def test_summarize_node_result_for_persist(agent_result): + """Test summarize_node_result_for_persist method.""" + from unittest.mock import Mock + + agent = Mock(spec=MultiAgentBase) + + # Test with NodeResult containing AgentResult + node_result = NodeResult(result=agent_result) + summary = MultiAgentBase.summarize_node_result_for_persist(agent, node_result) + assert "agent_outputs" in summary + assert isinstance(summary["agent_outputs"], list) + + # Test with already normalized dict + normalized = {"agent_outputs": ["test1", "test2"]} + summary = MultiAgentBase.summarize_node_result_for_persist(agent, normalized) + assert summary == {"agent_outputs": ["test1", "test2"]} + + # Test fallback case + summary = MultiAgentBase.summarize_node_result_for_persist(agent, "simple string") + assert summary == {"agent_outputs": ["simple string"]} + + +def test_call_hook_safely(): + """Test _call_hook_safely method handles exceptions.""" + from unittest.mock import Mock + + agent = Mock(spec=MultiAgentBase) + agent.hooks = Mock() + event = Mock() + + # Test successful hook call + MultiAgentBase._call_hook_safely(agent, event) + agent.hooks.invoke_callbacks.assert_called_once_with(event) + + # Test hook exception handling + agent.hooks.invoke_callbacks.side_effect = Exception("Hook error") + MultiAgentBase._call_hook_safely(agent, event) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 8097d944e..9e98524fe 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ 
-6,8 +6,7 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.hooks import AgentInitializedEvent -from strands.hooks.registry import HookProvider, HookRegistry +from strands.hooks.registry import HookRegistry from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.session_manager import SessionManager @@ -640,7 +639,7 @@ async def timeout_invoke(*args, **kwargs): graph = builder.set_max_node_executions(50).set_execution_timeout(900.0).set_node_timeout(0.1).build() # Execute the graph - should raise Exception due to timeout - with pytest.raises(Exception, match="Node 'timeout_node' execution timed out after 0.1s"): + with pytest.raises(asyncio.TimeoutError): await graph.invoke_async("Test node timeout") mock_strands_tracer.start_multiagent_span.assert_called() @@ -865,30 +864,16 @@ def test_graph_validate_unsupported_features(): graph = builder.build() assert len(graph.nodes) == 1 - # Test with session manager (should fail in GraphBuilder.add_node) + # Test with session manager (should work now - session persistence is supported) mock_session_manager = Mock(spec=SessionManager) agent_with_session = create_mock_agent("agent_with_session") agent_with_session._session_manager = mock_session_manager agent_with_session.hooks = HookRegistry() builder = GraphBuilder() - with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): - builder.add_node(agent_with_session) - - # Test with callbacks (should fail in GraphBuilder.add_node) - class TestHookProvider(HookProvider): - def register_hooks(self, registry, **kwargs): - registry.add_callback(AgentInitializedEvent, lambda e: None) - - # Test validation in Graph constructor (when nodes are passed directly) - # Test with session manager in Graph constructor - node_with_session = 
GraphNode("node_with_session", agent_with_session) - with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): - Graph( - nodes={"node_with_session": node_with_session}, - edges=set(), - entry_points=set(), - ) + builder.add_node(agent_with_session) + graph = builder.build() + assert len(graph.nodes) == 1 @pytest.mark.asyncio @@ -1139,7 +1124,7 @@ async def test_self_loop_functionality_without_reset(mock_strands_tracer, mock_u result = await graph.invoke_async("Test self loop without reset") assert result.status == Status.COMPLETED - assert len(result.execution_order) == 2 + assert len(result.execution_order) == 1 mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called() @@ -1337,3 +1322,54 @@ def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span): kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state) assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_graph_persisted(mock_strands_tracer, mock_use_span): + """Test graph persistence functionality.""" + # Create mock session manager + session_manager = Mock(spec=SessionManager) + session_manager.read_multi_agent_json.return_value = None + + # Create simple graph with session manager + builder = GraphBuilder() + agent = create_mock_agent("test_agent") + builder.add_node(agent, "test_node") + builder.set_entry_point("test_node") + builder.set_session_manager(session_manager) + + graph = builder.build() + + # Test get_state_from_orchestrator + state = graph.get_state_from_orchestrator() + assert state["type"] == "graph" + assert "status" in state + assert "completed_nodes" in state + assert "node_results" in state + + # Test apply_state_from_dict with persisted state + persisted_state = { + "status": "executing", + "completed_nodes": [], + "node_results": {}, + "current_task": "persisted task", + "execution_order": [], + "next_node_to_execute": 
["test_node"], + } + + graph.apply_state_from_dict(persisted_state) + assert graph.state.task == "persisted task" + + # Execute graph to test persistence integration + result = await graph.invoke_async("Test persistence") + + # Verify execution completed + assert result.status == Status.COMPLETED + assert len(result.results) == 1 + assert "test_node" in result.results + + # Test state serialization after execution + final_state = graph.get_state_from_orchestrator() + assert final_state["status"] == "completed" + assert len(final_state["completed_nodes"]) == 1 + assert "test_node" in final_state["node_results"] diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 7d3e69695..dfb0b19e0 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -574,3 +574,53 @@ def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span): assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs assert result.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_swarm_persistence(mock_strands_tracer, mock_use_span): + """Test swarm persistence functionality.""" + # Create mock session manager + session_manager = Mock(spec=SessionManager) + session_manager.read_multi_agent_json.return_value = None + + # Create simple swarm with session manager + agent = create_mock_agent("test_agent") + swarm = Swarm([agent], session_manager=session_manager) + + # Test get_state_from_orchestrator + state = swarm.get_state_from_orchestrator() + assert state["type"] == "swarm" + assert "status" in state + assert "completed_nodes" in state + assert "node_results" in state + assert "context" in state + + # Test apply_state_from_dict with persisted state + persisted_state = { + "status": "executing", + "completed_nodes": [], + "node_results": {}, + "current_task": "persisted task", + "execution_order": [], + "next_node_to_execute": ["test_agent"], + "context": {"shared_context": 
{"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, + } + + swarm.apply_state_from_dict(persisted_state) + assert swarm.state.task == "persisted task" + assert swarm.state.handoff_message == "test handoff" + assert swarm.shared_context.context["test_agent"]["key"] == "value" + + # Execute swarm to test persistence integration + result = await swarm.invoke_async("Test persistence") + + # Verify execution completed + assert result.status == Status.COMPLETED + assert len(result.results) == 1 + assert "test_agent" in result.results + + # Test state serialization after execution + final_state = swarm.get_state_from_orchestrator() + assert final_state["status"] == "completed" + assert len(final_state["completed_nodes"]) == 1 + assert "test_agent" in final_state["node_results"] diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index f124ddf58..6db7c1e6d 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -408,3 +408,45 @@ def test__get_message_path_invalid_message_id(message_id, file_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): file_manager._get_message_path("session1", "agent1", message_id) + + +def test_write_read_multi_agent_json(file_manager, sample_session): + """Test writing and reading multi-agent state.""" + file_manager.create_session(sample_session) + + # Write multi-agent state + state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} + file_manager.write_multi_agent_json(state) + + # Read multi-agent state + result = file_manager.read_multi_agent_json() + assert result == state + + +def test_read_multi_agent_json_nonexistent(file_manager): + """Test reading multi-agent state when file doesn't exist.""" + result = file_manager.read_multi_agent_json() + assert 
result == {} + + +def test_list_messages_missing_directory(file_manager, sample_session, sample_agent): + """Test listing messages when messages directory is missing.""" + file_manager.create_session(sample_session) + file_manager.create_agent(sample_session.session_id, sample_agent) + + # Remove messages directory + messages_dir = os.path.join( + file_manager._get_agent_path(sample_session.session_id, sample_agent.agent_id), "messages" + ) + os.rmdir(messages_dir) + + with pytest.raises(SessionException, match="Messages directory missing"): + file_manager.list_messages(sample_session.session_id, sample_agent.agent_id) + + +def test_create_existing_session(file_manager, sample_session): + """Test creating a session that already exists.""" + file_manager.create_session(sample_session) + + with pytest.raises(SessionException, match="already exists"): + file_manager.create_session(sample_session) diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 2c25fcc38..77d2e4756 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,11 +5,10 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager -from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType -from tests.fixtures.mock_session_repository import MockedSessionRepository +from tests.fixtures.mock_session_repository import MockedSessionRepository, TestRepositorySessionManager @pytest.fixture @@ -21,7 +20,7 @@ def mock_repository(): 
@pytest.fixture def session_manager(mock_repository): """Create a session manager with mock repository.""" - return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + return TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) @pytest.fixture @@ -36,7 +35,7 @@ def test_init_creates_session_if_not_exists(mock_repository): assert mock_repository.read_session("test-session") is None # Creating manager should create session - RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session created session = mock_repository.read_session("test-session") @@ -52,7 +51,7 @@ def test_init_uses_existing_session(mock_repository): mock_repository.create_session(session) # Creating manager should use existing session - manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository) + manager = TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session used assert manager.session == session diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index c4d6a0154..a6aa0b9b4 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -374,3 +374,20 @@ def test__get_message_path_invalid_message_id(message_id, s3_manager): """Test that message_id that is not an integer raises ValueError.""" with pytest.raises(ValueError, match=r"message_id=<.*> \| message id must be an integer"): s3_manager._get_message_path("session1", "agent1", message_id) + + +def test_write_read_multi_agent_json(s3_manager, sample_session): + """Test multi-agent state persistence.""" + s3_manager.create_session(sample_session) + + state = {"type": "graph", "status": "completed"} + 
s3_manager.write_multi_agent_json(state) + + result = s3_manager.read_multi_agent_json() + assert result == state + + +def test_read_multi_agent_json_nonexistent(s3_manager): + """Test reading multi-agent state when file doesn't exist.""" + result = s3_manager.read_multi_agent_json() + assert result == {} From 41d34a91ca7dc1d79b4590d3353faafa9e85d274 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 29 Sep 2025 10:44:01 -0400 Subject: [PATCH 03/27] fix: add write file fallback for window permission handling --- src/strands/session/file_session_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index b73e3cd6f..c2a668204 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -122,7 +122,12 @@ def _write_file(self, path: str, data: dict[str, Any]) -> None: json.dump(data, f, indent=2, ensure_ascii=False) f.flush() os.fsync(f.fileno()) - os.replace(tmp, path) + try: + os.replace(tmp, path) + except PermissionError: + # Windows fallback: copy+delete if atomic replace fails + shutil.copy2(tmp, path) + os.remove(tmp) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" From 2e854f8742be848046c460e770df1cfd85d29c79 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 30 Sep 2025 19:12:41 -0400 Subject: [PATCH 04/27] fix: remove persistence_hooks, use session_manager to subscribe multiagent_hook events, adjust node_result serialization --- .../experimental/multiagent_hooks/__init__.py | 6 - .../multiagent_hooks/multiagent_events.py | 27 ---- .../multiagent_hooks/persistence_hooks.py | 91 ------------ src/strands/hooks/registry.py | 18 ++- src/strands/multiagent/base.py | 113 +++++++-------- src/strands/multiagent/graph.py | 136 +++++++----------- src/strands/multiagent/swarm.py | 38 +++-- src/strands/session/s3_session_manager.py | 11 +- 
src/strands/session/session_manager.py | 29 ++++ .../test_multiagent_events.py | 44 ------ .../test_persistence_hooks.py | 131 ----------------- tests/strands/multiagent/test_base.py | 78 ++++++---- tests/strands/multiagent/test_graph.py | 6 +- tests/strands/multiagent/test_swarm.py | 6 +- 14 files changed, 226 insertions(+), 508 deletions(-) delete mode 100644 src/strands/experimental/multiagent_hooks/persistence_hooks.py delete mode 100644 tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py diff --git a/src/strands/experimental/multiagent_hooks/__init__.py b/src/strands/experimental/multiagent_hooks/__init__.py index 0567effad..a8825f02e 100644 --- a/src/strands/experimental/multiagent_hooks/__init__.py +++ b/src/strands/experimental/multiagent_hooks/__init__.py @@ -7,17 +7,11 @@ from .multiagent_events import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, - BeforeNodeInvocationEvent, MultiAgentInitializationEvent, ) -from .persistence_hooks import PersistentHook __all__ = [ - "BeforeMultiAgentInvocationEvent", "AfterMultiAgentInvocationEvent", "MultiAgentInitializationEvent", - "BeforeNodeInvocationEvent", "AfterNodeInvocationEvent", - "PersistentHook", ] diff --git a/src/strands/experimental/multiagent_hooks/multiagent_events.py b/src/strands/experimental/multiagent_hooks/multiagent_events.py index 988799a3f..931998cde 100644 --- a/src/strands/experimental/multiagent_hooks/multiagent_events.py +++ b/src/strands/experimental/multiagent_hooks/multiagent_events.py @@ -27,33 +27,6 @@ class MultiAgentInitializationEvent(BaseHookEvent): invocation_state: dict[str, Any] | None = None -@dataclass -class BeforeMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered before orchestrator execution begins. 
- - Attributes: - orchestrator: The multi-agent orchestrator instance - invocation_state: Configuration that user pass in - """ - - orchestrator: "MultiAgentBase" - invocation_state: dict[str, Any] | None = None - - -@dataclass -class BeforeNodeInvocationEvent(BaseHookEvent): - """Event triggered before individual node execution. - - Attributes: - orchestrator: The multi-agent orchestrator instance - invocation_state: Configuration that user pass in - """ - - orchestrator: "MultiAgentBase" - next_node_to_execute: str - invocation_state: dict[str, Any] | None = None - - @dataclass class AfterNodeInvocationEvent(BaseHookEvent): """Event triggered after individual node execution completes. diff --git a/src/strands/experimental/multiagent_hooks/persistence_hooks.py b/src/strands/experimental/multiagent_hooks/persistence_hooks.py deleted file mode 100644 index cbd168e65..000000000 --- a/src/strands/experimental/multiagent_hooks/persistence_hooks.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Multi-agent session persistence hook implementation. - -This module provides automatic session persistence for multi-agent orchestrators -(Graph and Swarm) by hooking into their execution lifecycle events. - -Key Features: -- Automatic state persistence at key execution points -- Thread-safe persistence operations -- Support for both Graph and Swarm orchestrators -- Seamless integration with SessionManager -""" - -import threading -from typing import TYPE_CHECKING - -from ...hooks.registry import HookProvider, HookRegistry -from ...session import SessionManager -from .multiagent_events import ( - AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, - BeforeNodeInvocationEvent, - MultiAgentInitializationEvent, -) - -if TYPE_CHECKING: - from ...multiagent.base import MultiAgentBase - - -class PersistentHook(HookProvider): - """Hook provider for automatic multi-agent session persistence. 
- - This hook automatically persists multi-agent orchestrator state at key - execution points to enable resumable execution after interruptions. - - """ - - def __init__(self, session_manager: SessionManager): - """Initialize the multi-agent persistence hook. - - Args: - session_manager: SessionManager instance for state persistence - """ - self._session_manager = session_manager - self._lock = threading.RLock() - - def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: - """Register persistence callbacks for multi-agent execution events. - - Args: - registry: Hook registry to register callbacks with - **kwargs: Additional keyword arguments (unused) - """ - registry.add_callback(MultiAgentInitializationEvent, self._on_initialization) - registry.add_callback(BeforeMultiAgentInvocationEvent, self._on_before_multiagent) - registry.add_callback(BeforeNodeInvocationEvent, self._on_before_node) - registry.add_callback(AfterNodeInvocationEvent, self._on_after_node) - registry.add_callback(AfterMultiAgentInvocationEvent, self._on_after_multiagent) - - # TODO: We can add **kwarg or invocation_state later if we need to persist - def _on_initialization(self, event: MultiAgentInitializationEvent) -> None: - """Persist state when multi-agent orchestrator initializes.""" - self._persist(event.orchestrator) - - def _on_before_multiagent(self, event: BeforeMultiAgentInvocationEvent) -> None: - """Persist state when multi-agent orchestrator initializes.""" - pass - - def _on_before_node(self, event: BeforeNodeInvocationEvent) -> None: - """Hook called before individual node execution.""" - pass - - def _on_after_node(self, event: AfterNodeInvocationEvent) -> None: - """Persist state after each node completes execution.""" - self._persist(event.orchestrator) - - def _on_after_multiagent(self, event: AfterMultiAgentInvocationEvent) -> None: - """Persist final state after graph execution completes.""" - self._persist(event.orchestrator) - - def _persist(self, 
orchestrator: "MultiAgentBase") -> None: - """Persist the provided MultiAgentState using the configured SessionManager. - - This method is synchronized across threads/tasks to avoid write races. - - Args: - orchestrator: State to persist - """ - current_state = orchestrator.get_state_from_orchestrator() - with self._lock: - self._session_manager.write_multi_agent_json(current_state) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index b8e7f82ab..9d52a63be 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,12 +7,15 @@ via hook provider objects. """ +import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Protocol, Type, TypeVar if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) + @dataclass class BaseHookEvent: @@ -184,7 +187,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: Optional[bool] = False) -> TInvokeEvent: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -194,6 +197,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: Args: event: The event to dispatch to registered callbacks. + supress_exceptions: Except exception or not. Returns: The event dispatched to registered callbacks. 
@@ -203,9 +207,17 @@ def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: event = StartRequestEvent(agent=my_agent) registry.invoke_callbacks(event) ``` + """ for callback in self.get_callbacks_for(event): - callback(event) + if supress_exceptions: + try: + callback(event) + except Exception as e: + logger.exception("Hook invocation failed for %s: %s", type(event).__name__, e) + pass + else: + callback(event) return event diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index dcebbe5da..7a964cd1c 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -37,7 +37,7 @@ class NodeResult: """ # Core result data - single AgentResult, nested MultiAgentResult, or Exception - result: Union["AgentResult", "MultiAgentResult", Exception] + result: Union[AgentResult, "MultiAgentResult", Exception] # Execution metadata execution_time: int = 0 @@ -48,26 +48,18 @@ class NodeResult: accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_count: int = 0 - def get_agent_results(self) -> list["AgentResult"]: + def get_agent_results(self) -> list[AgentResult]: """Get all AgentResult objects from this node, flattened if nested.""" if isinstance(self.result, Exception): return [] # No agent results for exceptions - elif isinstance(self.result, AgentResult): + if isinstance(self.result, AgentResult): return [self.result] - # else: - # # Flatten nested results from MultiAgentResult - # flattened = [] - # for nested_node_result in self.result.results.values(): - # if isinstance(nested_node_result, NodeResult): - # flattened.extend(nested_node_result.get_agent_results()) - # return flattened - if getattr(self.result, "__class__", None) and self.result.__class__.__name__ == "AgentResult": return [self.result] # type: ignore[list-item] # If this is a nested MultiAgentResult, flatten children if hasattr(self.result, "results") and isinstance(self.result.results, dict): - flattened: 
list["AgentResult"] = [] + flattened: list[AgentResult] = [] for nested in self.result.results.values(): if isinstance(nested, NodeResult): flattened.extend(nested.get_agent_results()) @@ -75,6 +67,33 @@ def get_agent_results(self) -> list["AgentResult"]: return [] + def to_dict(self) -> dict[str, Any]: + """Convert NodeResult to JSON-serializable dict, ignoring state field.""" + result_data: Any = None + if isinstance(self.result, Exception): + result_data = {"type": "exception", "message": str(self.result)} + elif isinstance(self.result, AgentResult): + # Serialize AgentResult without state field + result_data = { + "type": "agent_result", + "stop_reason": self.result.stop_reason, + "message": self.result.message, # Message type is JSON serializable + # Skip metrics and state - not JSON serializable + } + elif hasattr(self.result, "to_dict"): + result_data = self.result.to_dict() + else: + result_data = str(self.result) + + return { + "result": result_data, + "execution_time": self.execution_time, + "status": self.status.value, + "accumulated_usage": dict(self.accumulated_usage), + "accumulated_metrics": dict(self.accumulated_metrics), + "execution_count": self.execution_count, + } + @dataclass class MultiAgentResult: @@ -92,6 +111,17 @@ class MultiAgentResult: execution_count: int = 0 execution_time: int = 0 + def to_dict(self) -> dict[str, Any]: + """Convert MultiAgentResult to JSON-serializable dict.""" + return { + "status": self.status.value, + "results": {k: v.to_dict() for k, v in self.results.items()}, + "accumulated_usage": dict(self.accumulated_usage), + "accumulated_metrics": dict(self.accumulated_metrics), + "execution_count": self.execution_count, + "execution_time": self.execution_time, + } + class MultiAgentBase(ABC): """Base class for multi-agent helpers. 
@@ -135,65 +165,30 @@ def execute() -> MultiAgentResult: future = executor.submit(execute) return future.result() - def _call_hook_safely(self, event_object: object) -> None: - """Invoke hook callbacks and swallow hook errors. - - Args: - event_object: The event to dispatch to registered callbacks. - """ - try: - self.hooks.invoke_callbacks(event_object) # type: ignore - except Exception as e: - logger.exception("Hook invocation failed for %s: %s", type(event_object).__name__, e) - @abstractmethod - def get_state_from_orchestrator(self) -> dict: + def serialize_state(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state.""" raise NotImplementedError @abstractmethod - def apply_state_from_dict(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError - def summarize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: - """Summarize node result for efficient persistence. + def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: + """Serialize node result for persistence. 
Args: - raw: Raw node result to summarize + raw: Raw node result to serialize Returns: - Normalized dict with 'agent_outputs' key containing string results + JSON-serializable dict representation """ + if isinstance(raw, dict): + return raw + + if hasattr(raw, "to_dict") and callable(raw.to_dict): + return raw.to_dict() - def _extract_text_from_agent_result(ar: Any) -> str: - try: - msg = getattr(ar, "message", None) - if isinstance(msg, dict): - blocks = msg.get("content") or [] - texts = [] - for b in blocks: - t = b.get("text") - if t: - texts.append(str(t)) - if texts: - return "\n".join(texts) - return str(ar) - except Exception: - return str(ar) - - # If it's a NodeResult with AgentResults, flatten - if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): - try: - ars = raw.get_agent_results() # list[AgentResult] - if ars: - return {"agent_outputs": [_extract_text_from_agent_result(a) for a in ars]} - except Exception: - pass - - # If already normalized - if isinstance(raw, dict) and isinstance(raw.get("agent_outputs"), list): - return {"agent_outputs": [str(x) for x in raw["agent_outputs"]]} - - # Fallback + # Fallback for strings and other types return {"agent_outputs": [str(raw)]} diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 9f66c30ff..a56f3d966 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -14,7 +14,6 @@ - Supports nested graphs (Graph as a node in another Graph) """ -import ast import asyncio import copy import logging @@ -30,10 +29,8 @@ from ..experimental.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, MultiAgentInitializationEvent, ) -from ..experimental.multiagent_hooks.persistence_hooks import PersistentHook from ..hooks import HookRegistry from ..session import SessionManager from ..telemetry import get_tracer @@ -103,23 +100,6 @@ def should_continue( return True, "Continuing" - 
@staticmethod - def _normalize_persisted_like_dict(data: dict[str, Any]) -> dict[str, Any]: - """Normalize dictionary data to standard agent_outputs format. - - Args: - data: Dictionary containing result data - - Returns: - Normalized dict with 'agent_outputs' key - """ - if "agent_outputs" in data and isinstance(data["agent_outputs"], list): - return {"agent_outputs": [str(x) for x in data["agent_outputs"]]} - - if "summary" in data: - return {"agent_outputs": [str(data["summary"])]} - return {"agent_outputs": [str(data)]} - @dataclass class GraphResult(MultiAgentResult): @@ -356,11 +336,6 @@ def build(self) -> "Graph": # Validate entry points and check for cycles self._validate_graph() - # Build hook auto-inject - hooks = HookRegistry() - if self._session_manager is not None: - hooks.add_hook(PersistentHook(session_manager=self._session_manager)) - return Graph( nodes=self.nodes.copy(), edges=self.edges.copy(), @@ -370,7 +345,6 @@ def build(self) -> "Graph": node_timeout=self._node_timeout, reset_on_revisit=self._reset_on_revisit, session_manager=self._session_manager, - hooks=hooks, ) def _validate_graph(self) -> None: @@ -386,30 +360,16 @@ def _validate_graph(self) -> None: logger.warning("Graph without execution limits may run indefinitely if cycles exist") -def _iterate_previous_outputs(raw: Any) -> list[tuple[str, str]]: +def _iterate_previous_outputs(raw: NodeResult | dict[str, Any]) -> list[tuple[str, str]]: """Return a list of (agent_name, text) from NodeResult or persisted dict.""" # Live NodeResult if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): - try: - return [(getattr(r, "agent_name", "Agent"), str(r)) for r in raw.get_agent_results()] - except Exception: - pass + return [(getattr(r, "agent_name", "Agent"), str(r)) for r in raw.get_agent_results()] - # Convert string to dict if possible - data = raw - if isinstance(raw, str): - try: - data = ast.literal_eval(raw) - except Exception: - return [("Agent", raw)] - - # Extract 
from dict - if isinstance(data, dict): - if isinstance(data.get("agent_outputs"), list): - return [("Agent", str(x)) for x in data["agent_outputs"]] - if "summary" in data: - return [("Agent", str(data["summary"]))] + if isinstance(raw, dict) and "agent_outputs" in raw: + return [("Agent", str(x)) for x in raw["agent_outputs"]] + # Fallback return [("Agent", str(raw))] @@ -457,14 +417,18 @@ def __init__( self.tracer = get_tracer() self.session_manager = session_manager self.hooks = hooks or HookRegistry() + if self.session_manager is not None: + self.hooks.add_hook(self.session_manager) - # Concurrncy lock + # Concurrent lock self._lock = asyncio.Lock() # Resume flag self._resume_from_persisted = False self._resume_next_nodes: list[GraphNode] = [] self._load_and_apply_persisted_state() + if not self._resume_from_persisted and self.state.status == Status.PENDING: + self.hooks.invoke_callbacks(MultiAgentInitializationEvent(orchestrator=self), supress_exceptions=True) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -504,12 +468,6 @@ async def invoke_async( logger.debug("task=<%s> | starting graph execution", task) - if not self._resume_from_persisted and self.state.status == Status.PENDING: - # TODO: to check if we need do something on BeforeGraphInvocationEvent - self._call_hook_safely(MultiAgentInitializationEvent(orchestrator=self)) - - self._call_hook_safely(BeforeMultiAgentInvocationEvent(orchestrator=self)) - if not self._resume_from_persisted: start_time = time.time() # Initialize state @@ -554,7 +512,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) self._resume_from_persisted = False self._resume_next_nodes.clear() return 
self._build_result() @@ -643,6 +601,29 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" # Reset the node's state if reset_on_revisit is enabled and it's being revisited + + async def _record_failure(exception: Exception) -> None: + execution_time = round((time.time() - start_time) * 1000) + fail_result = NodeResult( + result=exception, + execution_time=execution_time, + status=Status.FAILED, + accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), + accumulated_metrics=Metrics(latencyMs=execution_time), + execution_count=1, + ) + async with self._lock: + node.execution_status = Status.FAILED + node.result = fail_result + node.execution_time = execution_time + self.state.failed_nodes.add(node) + self.state.results[node.node_id] = fail_result + + self.hooks.invoke_callbacks( + AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id), + supress_exceptions=True, + ) + if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() @@ -719,31 +700,23 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time ) - self._call_hook_safely(AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id)) - - except Exception as e: - logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) - execution_time = round((time.time() - start_time) * 1000) - - fail_result = NodeResult( - result=e, - execution_time=execution_time, - status=Status.FAILED, - accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), - accumulated_metrics=Metrics(latencyMs=execution_time), - 
execution_count=1, + self.hooks.invoke_callbacks( + AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id), supress_exceptions=True ) - async with self._lock: - node.execution_status = Status.FAILED - node.result = fail_result - node.execution_time = execution_time - self.state.failed_nodes.add(node) - self.state.results[node.node_id] = fail_result - - # Need to persist failure multiagent_state too - self._call_hook_safely(AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id)) + except asyncio.TimeoutError: + timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" + logger.exception( + "node=<%s>, timeout=<%s>s | node execution timed out after timeout", + node.node_id, + self.node_timeout, + ) + await _record_failure(asyncio.TimeoutError(timeout_msg)) + raise + except Exception as e: + logger.error("node_id=<%s>, error=<%s> | node failed", node.node_id, e) + await _record_failure(e) raise def _accumulate_metrics(self, node_result: NodeResult) -> None: @@ -855,8 +828,7 @@ def _to_dict(self) -> dict[str, Any]: "status": status_str, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "node_results": { - k: self.summarize_node_result_for_persist(v) # CHANGED: normalized text-only outputs - for k, v in (self.state.results or {}).items() + k: self.serialize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() }, "next_node_to_execute": next_nodes, "current_task": self.state.task, @@ -869,13 +841,13 @@ def _load_and_apply_persisted_state(self) -> None: try: json_data = self.session_manager.read_multi_agent_json() except Exception as e: - logger.warning("Skipping resume; failed to load state: %s", e) - return + logger.exception("Resume failed; failed to load state: %s", e) + raise if not json_data: return try: - self.apply_state_from_dict(json_data) + self.deserialize_state(json_data) self._resume_from_persisted = True next_node_ids = json_data.get("next_node_to_execute") 
or [] @@ -925,7 +897,7 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: return [node for node in self.entry_points if node not in completed_nodes] - def get_state_from_orchestrator(self) -> dict: + def serialize_state(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state. Returns: @@ -933,7 +905,7 @@ def get_state_from_orchestrator(self) -> dict: """ return self._to_dict() - def apply_state_from_dict(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: """Restore orchestrator state from a session dict. Args: diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index c24449e6b..c697d343c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -28,11 +28,8 @@ from ..experimental.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, - BeforeNodeInvocationEvent, MultiAgentInitializationEvent, ) -from ..experimental.multiagent_hooks.persistence_hooks import PersistentHook from ..hooks import HookRegistry from ..session import SessionManager from ..telemetry import get_tracer @@ -254,7 +251,7 @@ def __init__( self.session_manager = session_manager self.hooks = hooks or HookRegistry() if self.session_manager is not None: - self.hooks.add_hook(PersistentHook(session_manager=self.session_manager)) + self.hooks.add_hook(self.session_manager) self._setup_swarm(nodes) self._inject_swarm_tools() @@ -264,6 +261,8 @@ def __init__( self._resume_from_completed = False self._load_and_apply_persisted_state() + if not self._resume_from_persisted and self.state.completion_status == Status.PENDING: + self.hooks.invoke_callbacks(MultiAgentInitializationEvent(orchestrator=self), supress_exceptions=True) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -307,12 +306,6 @@ async def invoke_async( logger.debug("Returning persisted 
COMPLETED result without re-execution.") return self._build_result() - if not self._resume_from_persisted and self.state.completion_status == Status.PENDING: - self._call_hook_safely(MultiAgentInitializationEvent(orchestrator=self)) - - # Customizable Before GraphInvocation. - self._call_hook_safely(BeforeMultiAgentInvocationEvent(orchestrator=self)) - # If resume if not self._resume_from_persisted: initial_node = self._initial_node() @@ -348,7 +341,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) self._resume_from_persisted = False self._resume_from_completed = False @@ -502,7 +495,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st # Persist handoff msg incase we lose it. if self.session_manager is not None: try: - self.session_manager.write_multi_agent_json(self.get_state_from_orchestrator()) + self.session_manager.write_multi_agent_json(self.serialize_state()) except Exception as e: logger.warning("Failed to persist swarm state after handoff: %s", e) @@ -609,9 +602,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: self.state.completion_status = Status.FAILED break - # TODO: BEFORE NODE INVOCATION START HERE - self._call_hook_safely(BeforeNodeInvocationEvent(self, next_node_to_execute=current_node.node_id)) - logger.debug( "current_node=<%s>, iteration=<%d> | executing node", current_node.node_id, @@ -630,7 +620,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) - self._call_hook_safely(AfterNodeInvocationEvent(self, executed_node=current_node.node_id)) + self.hooks.invoke_callbacks( + AfterNodeInvocationEvent(self, executed_node=current_node.node_id), 
supress_exceptions=True + ) # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete @@ -658,7 +650,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: self.state.completion_status = Status.FAILED elapsed_time = time.time() - self.state.start_time - self._call_hook_safely(AfterMultiAgentInvocationEvent(orchestrator=self)) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) logger.debug("status=<%s> | swarm execution completed", self.state.completion_status) logger.debug( "node_history_length=<%d>, time=<%s>s | metrics", @@ -736,7 +728,9 @@ async def _execute_node( self.state.results[node_name] = node_result # Persist failure here - self._call_hook_safely(AfterNodeInvocationEvent(self, executed_node=node_name)) + self.hooks.invoke_callbacks( + AfterNodeInvocationEvent(self, executed_node=node_name), supress_exceptions=True + ) raise @@ -781,7 +775,7 @@ def _load_and_apply_persisted_state(self) -> None: try: # saved is a dict status = saved.get("status") - self.apply_state_from_dict(saved) + self.deserialize_state(saved) self._resume_from_persisted = True if status in (Status.COMPLETED.value, "completed"): @@ -798,7 +792,7 @@ def _load_and_apply_persisted_state(self) -> None: except Exception as e: logger.exception("Failed to hydrate swarm from persisted state: %s", e) - def get_state_from_orchestrator(self) -> dict: + def serialize_state(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state. 
Returns: @@ -815,7 +809,7 @@ def get_state_from_orchestrator(self) -> dict: "status": status_str, "completed_nodes": [n.node_id for n in self.state.node_history], "node_results": { - k: self.summarize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() + k: self.serialize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() }, "next_node_to_execute": next_nodes, "current_task": self.state.task, @@ -826,7 +820,7 @@ def get_state_from_orchestrator(self) -> dict: }, } - def apply_state_from_dict(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: """Restore orchestrator state from a session dict. Args: diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 915769da0..864d85b07 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -323,15 +323,10 @@ def write_multi_agent_json(self, state: dict[str, Any]) -> None: state_key = self._join_key(session_prefix, "multi_agent_state.json") self._write_s3_object(state_key, state) - # Touch updated_at on session.json (best-effort) session_key = self._join_key(session_prefix, "session.json") - try: - metadata = self._read_s3_object(session_key) or {} - metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - self._write_s3_object(session_key, metadata) - except SessionException: - # If session.json is missing or unreadable, don't fail persistence - logger.warning("Could not update session.json updated_at for session %s", self.session_id) + metadata = self._read_s3_object(session_key) or {} + metadata["updated_at"] = datetime.now(timezone.utc).isoformat() + self._write_s3_object(session_key, metadata) def read_multi_agent_json(self) -> dict[str, Any]: """Read multi-agent state from S3. 
diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index b14f2d467..d390ffef2 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,8 +1,14 @@ """Session manager interface for agent session management.""" +import threading from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ..experimental.multiagent_hooks.multiagent_events import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + MultiAgentInitializationEvent, +) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry from ..types.content import Message @@ -10,6 +16,7 @@ if TYPE_CHECKING: from ..agent.agent import Agent + from ..multiagent.base import MultiAgentBase class SessionManager(HookProvider, ABC): @@ -28,6 +35,7 @@ def __init__(self, session_type: SessionType = SessionType.AGENT) -> None: session_type: Type of session (AGENT or MULTI_AGENT) """ self.session_type: SessionType = session_type + self._lock = threading.RLock() def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for persisting the agent to the session.""" @@ -44,6 +52,17 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: # After an agent was invoked, sync it with the session to capture any conversation manager state updates registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) + elif self.session_type == SessionType.MULTI_AGENT: + registry.add_callback( + MultiAgentInitializationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) + ) + registry.add_callback( + AfterNodeInvocationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) + ) + registry.add_callback( + AfterMultiAgentInvocationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) + ) + @abstractmethod def 
redact_latest_message(self, redact_message: Message, agent: "Agent", **kwargs: Any) -> None: """Redact the message most recently appended to the agent in the session. @@ -82,6 +101,16 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: **kwargs: Additional keyword arguments for future extensibility. """ + def _persist_multi_agent_state(self, orchestrator: "MultiAgentBase") -> None: + """Thread-safe persistence of multi-agent state. + + Args: + orchestrator: Multi-agent orchestrator to persist + """ + with self._lock: + state = orchestrator.serialize_state() + self.write_multi_agent_json(state) + # Multiagent abstract functions @abstractmethod def write_multi_agent_json(self, state: dict[str, Any]) -> None: diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py index f983c0602..4ecc92d4d 100644 --- a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py +++ b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py @@ -7,8 +7,6 @@ from strands.experimental.multiagent_hooks.multiagent_events import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, - BeforeNodeInvocationEvent, MultiAgentInitializationEvent, ) from strands.hooks.registry import BaseHookEvent @@ -38,48 +36,6 @@ def test_multi_agent_initialization_event_with_invocation_state(orchestrator): assert event.invocation_state == invocation_state -def test_before_multi_agent_invocation_event_with_orchestrator_only(orchestrator): - """Test BeforeMultiAgentInvocationEvent creation with orchestrator only.""" - event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator) - - assert event.orchestrator is orchestrator - assert event.invocation_state is None - assert isinstance(event, BaseHookEvent) - - -def test_before_multi_agent_invocation_event_with_invocation_state(orchestrator): - """Test BeforeMultiAgentInvocationEvent 
creation with invocation state.""" - invocation_state = {"config": "test"} - event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator, invocation_state=invocation_state) - - assert event.orchestrator is orchestrator - assert event.invocation_state == invocation_state - - -def test_before_node_invocation_event_with_required_fields(orchestrator): - """Test BeforeNodeInvocationEvent creation with required fields.""" - next_node = "node_1" - event = BeforeNodeInvocationEvent(orchestrator=orchestrator, next_node_to_execute=next_node) - - assert event.orchestrator is orchestrator - assert event.next_node_to_execute == next_node - assert event.invocation_state is None - assert isinstance(event, BaseHookEvent) - - -def test_before_node_invocation_event_with_invocation_state(orchestrator): - """Test BeforeNodeInvocationEvent creation with invocation state.""" - next_node = "node_2" - invocation_state = {"step": 1} - event = BeforeNodeInvocationEvent( - orchestrator=orchestrator, next_node_to_execute=next_node, invocation_state=invocation_state - ) - - assert event.orchestrator is orchestrator - assert event.next_node_to_execute == next_node - assert event.invocation_state == invocation_state - - def test_after_node_invocation_event_with_required_fields(orchestrator): """Test AfterNodeInvocationEvent creation with required fields.""" executed_node = "node_1" diff --git a/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py b/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py deleted file mode 100644 index 30ca9d214..000000000 --- a/tests/strands/experimental/multiagent_hooks/test_persistence_hooks.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Tests for multi-agent session persistence hook implementation.""" - -from unittest.mock import Mock, call - -import pytest - -from strands.experimental.multiagent_hooks.multiagent_events import ( - AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeMultiAgentInvocationEvent, - 
BeforeNodeInvocationEvent, - MultiAgentInitializationEvent, -) -from strands.experimental.multiagent_hooks.persistence_hooks import PersistentHook -from strands.hooks.registry import HookRegistry - - -@pytest.fixture -def session_manager(): - """Mock session manager.""" - return Mock() - - -@pytest.fixture -def orchestrator(): - """Mock orchestrator.""" - mock_orchestrator = Mock() - mock_orchestrator.get_state_from_orchestrator.return_value = {"state": "test"} - return mock_orchestrator - - -@pytest.fixture -def hook(session_manager): - """PersistentHook instance.""" - return PersistentHook(session_manager) - - -def test_initialization(session_manager): - """Test hook initialization.""" - hook = PersistentHook(session_manager) - - assert hook._session_manager is session_manager - assert hasattr(hook, "_lock") - - -def test_register_hooks(hook): - """Test hook registration with registry.""" - registry = Mock(spec=HookRegistry) - - hook.register_hooks(registry) - - expected_calls = [ - call(MultiAgentInitializationEvent, hook._on_initialization), - call(BeforeMultiAgentInvocationEvent, hook._on_before_multiagent), - call(BeforeNodeInvocationEvent, hook._on_before_node), - call(AfterNodeInvocationEvent, hook._on_after_node), - call(AfterMultiAgentInvocationEvent, hook._on_after_multiagent), - ] - registry.add_callback.assert_has_calls(expected_calls) - - -def test_on_initialization_persists_state(hook, orchestrator): - """Test initialization event triggers persistence.""" - event = MultiAgentInitializationEvent(orchestrator=orchestrator) - - hook._on_initialization(event) - - orchestrator.get_state_from_orchestrator.assert_called_once() - hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) - - -def test_on_before_multiagent_does_nothing(hook, orchestrator): - """Test before multiagent event does nothing.""" - event = BeforeMultiAgentInvocationEvent(orchestrator=orchestrator) - - hook._on_before_multiagent(event) - - 
orchestrator.get_state_from_orchestrator.assert_not_called() - hook._session_manager.write_multi_agent_json.assert_not_called() - - -def test_on_before_node_does_nothing(hook, orchestrator): - """Test before node event does nothing.""" - event = BeforeNodeInvocationEvent(orchestrator=orchestrator, next_node_to_execute="node_1") - - hook._on_before_node(event) - - orchestrator.get_state_from_orchestrator.assert_not_called() - hook._session_manager.write_multi_agent_json.assert_not_called() - - -def test_on_after_node_persists_state(hook, orchestrator): - """Test after node event triggers persistence.""" - event = AfterNodeInvocationEvent(orchestrator=orchestrator, executed_node="node_1") - - hook._on_after_node(event) - - orchestrator.get_state_from_orchestrator.assert_called_once() - hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) - - -def test_on_after_multiagent_persists_state(hook, orchestrator): - """Test after multiagent event triggers persistence.""" - event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator) - - hook._on_after_multiagent(event) - - orchestrator.get_state_from_orchestrator.assert_called_once() - hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) - - -def test_persist_thread_safety(hook, orchestrator): - """Test that persistence operations are thread-safe.""" - hook._lock = Mock() - hook._lock.__enter__ = Mock(return_value=hook._lock) - hook._lock.__exit__ = Mock(return_value=None) - - hook._persist(orchestrator) - - hook._lock.__enter__.assert_called_once() - hook._lock.__exit__.assert_called_once() - orchestrator.get_state_from_orchestrator.assert_called_once() - hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) - - -def test_persist_gets_state_and_writes(hook, orchestrator): - """Test persist method gets state and writes to session manager.""" - hook._persist(orchestrator) - - 
orchestrator.get_state_from_orchestrator.assert_called_once() - hook._session_manager.write_multi_agent_json.assert_called_once_with({"state": "test"}) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index a5a4a10aa..19c886c1b 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -156,10 +156,10 @@ class CompleteMultiAgent(MultiAgentBase): async def invoke_async(self, task: str) -> MultiAgentResult: return MultiAgentResult(results={}) - def get_state_from_orchestrator(self) -> dict: + def serialize_state(self) -> dict: return {} - def apply_state_from_dict(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: pass # Should not raise an exception - __call__ is provided by base class @@ -184,10 +184,10 @@ async def invoke_async(self, task, invocation_state, **kwargs): status=Status.COMPLETED, results={"test": NodeResult(result=Exception("test"), status=Status.COMPLETED)} ) - def get_state_from_orchestrator(self) -> dict: + def serialize_state(self) -> dict: return {} - def apply_state_from_dict(self, payload: dict) -> None: + def deserialize_state(self, payload: dict) -> None: pass agent = TestMultiAgent() @@ -202,40 +202,60 @@ def apply_state_from_dict(self, payload: dict) -> None: assert result.status == Status.COMPLETED -def test_summarize_node_result_for_persist(agent_result): - """Test summarize_node_result_for_persist method.""" - from unittest.mock import Mock +def test_node_result_to_dict(agent_result): + """Test NodeResult to_dict method.""" + # Test with AgentResult + node_result = NodeResult(result=agent_result, execution_time=100, status=Status.COMPLETED) + result_dict = node_result.to_dict() - agent = Mock(spec=MultiAgentBase) + assert result_dict["execution_time"] == 100 + assert result_dict["status"] == "completed" + assert result_dict["result"]["type"] == "agent_result" + assert result_dict["result"]["stop_reason"] == 
agent_result.stop_reason + assert result_dict["result"]["message"] == agent_result.message - # Test with NodeResult containing AgentResult + # Test with Exception + exception_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + result_dict = exception_result.to_dict() + + assert result_dict["result"]["type"] == "exception" + assert result_dict["result"]["message"] == "Test error" + assert result_dict["status"] == "failed" + + +def test_multi_agent_result_to_dict(agent_result): + """Test MultiAgentResult to_dict method.""" node_result = NodeResult(result=agent_result) - summary = MultiAgentBase.summarize_node_result_for_persist(agent, node_result) - assert "agent_outputs" in summary - assert isinstance(summary["agent_outputs"], list) + multi_result = MultiAgentResult(status=Status.COMPLETED, results={"test_node": node_result}, execution_time=200) - # Test with already normalized dict - normalized = {"agent_outputs": ["test1", "test2"]} - summary = MultiAgentBase.summarize_node_result_for_persist(agent, normalized) - assert summary == {"agent_outputs": ["test1", "test2"]} + result_dict = multi_result.to_dict() - # Test fallback case - summary = MultiAgentBase.summarize_node_result_for_persist(agent, "simple string") - assert summary == {"agent_outputs": ["simple string"]} + assert result_dict["status"] == "completed" + assert result_dict["execution_time"] == 200 + assert "test_node" in result_dict["results"] + assert result_dict["results"]["test_node"]["result"]["type"] == "agent_result" -def test_call_hook_safely(): - """Test _call_hook_safely method handles exceptions.""" +def test_serialize_node_result_for_persist(agent_result): + """Test serialize_node_result_for_persist method.""" from unittest.mock import Mock agent = Mock(spec=MultiAgentBase) - agent.hooks = Mock() - event = Mock() - # Test successful hook call - MultiAgentBase._call_hook_safely(agent, event) - agent.hooks.invoke_callbacks.assert_called_once_with(event) + # Test with 
NodeResult containing AgentResult + node_result = NodeResult(result=agent_result) + serialized = MultiAgentBase.serialize_node_result_for_persist(agent, node_result) + + # Should return the to_dict() result + assert "result" in serialized + assert "execution_time" in serialized + assert "status" in serialized - # Test hook exception handling - agent.hooks.invoke_callbacks.side_effect = Exception("Hook error") - MultiAgentBase._call_hook_safely(agent, event) + # Test with already normalized dict + normalized = {"agent_outputs": ["test1", "test2"]} + serialized = MultiAgentBase.serialize_node_result_for_persist(agent, normalized) + assert serialized == {"agent_outputs": ["test1", "test2"]} + + # Test fallback case + serialized = MultiAgentBase.serialize_node_result_for_persist(agent, "simple string") + assert serialized == {"agent_outputs": ["simple string"]} diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 9e98524fe..fedcf468f 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1341,7 +1341,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): graph = builder.build() # Test get_state_from_orchestrator - state = graph.get_state_from_orchestrator() + state = graph.serialize_state() assert state["type"] == "graph" assert "status" in state assert "completed_nodes" in state @@ -1357,7 +1357,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "next_node_to_execute": ["test_node"], } - graph.apply_state_from_dict(persisted_state) + graph.deserialize_state(persisted_state) assert graph.state.task == "persisted task" # Execute graph to test persistence integration @@ -1369,7 +1369,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): assert "test_node" in result.results # Test state serialization after execution - final_state = graph.get_state_from_orchestrator() + final_state = graph.serialize_state() assert 
final_state["status"] == "completed" assert len(final_state["completed_nodes"]) == 1 assert "test_node" in final_state["node_results"] diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index dfb0b19e0..6fa5edebd 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -588,7 +588,7 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): swarm = Swarm([agent], session_manager=session_manager) # Test get_state_from_orchestrator - state = swarm.get_state_from_orchestrator() + state = swarm.serialize_state() assert state["type"] == "swarm" assert "status" in state assert "completed_nodes" in state @@ -606,7 +606,7 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, } - swarm.apply_state_from_dict(persisted_state) + swarm.deserialize_state(persisted_state) assert swarm.state.task == "persisted task" assert swarm.state.handoff_message == "test handoff" assert swarm.shared_context.context["test_agent"]["key"] == "value" @@ -620,7 +620,7 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert "test_agent" in result.results # Test state serialization after execution - final_state = swarm.get_state_from_orchestrator() + final_state = swarm.serialize_state() assert final_state["status"] == "completed" assert len(final_state["completed_nodes"]) == 1 assert "test_agent" in final_state["node_results"] From 03805677b8c3a3ed23073aca6ba71a5c16fed6cb Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 1 Oct 2025 16:08:46 -0400 Subject: [PATCH 05/27] Update src/strands/multiagent/base.py Co-authored-by: Nick Clegg --- src/strands/multiagent/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 
7a964cd1c..8f0983a95 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -166,7 +166,7 @@ def execute() -> MultiAgentResult: return future.result() @abstractmethod - def serialize_state(self) -> dict: + def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" raise NotImplementedError From 93909d09f0a6511a2d656511259899e8dcf44fd8 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Wed, 1 Oct 2025 16:24:31 -0400 Subject: [PATCH 06/27] Update src/strands/multiagent/base.py Co-authored-by: Nick Clegg --- src/strands/multiagent/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 8f0983a95..4c83d2e76 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -171,7 +171,7 @@ def serialize_state(self) -> dict[str, Any]: raise NotImplementedError @abstractmethod - def deserialize_state(self, payload: dict) -> None: + def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError From 023cfbad4592ad67cae0abda32b263783edc8310 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Thu, 2 Oct 2025 10:40:36 -0400 Subject: [PATCH 07/27] Update src/strands/session/session_manager.py Co-authored-by: Nick Clegg --- src/strands/session/session_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index d390ffef2..95b39ecc9 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -119,7 +119,6 @@ def write_multi_agent_json(self, state: dict[str, Any]) -> None: Args: state: Multi-agent state dictionary to persist """ - raise NotImplementedError @abstractmethod def read_multi_agent_json(self) 
-> dict[str, Any]: From d064038c41fcc7d1e3e1089f2d23080f15acfdcf Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Thu, 2 Oct 2025 10:40:47 -0400 Subject: [PATCH 08/27] Update src/strands/session/session_manager.py Co-authored-by: Nick Clegg --- src/strands/session/session_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 95b39ecc9..03e96d0db 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -127,4 +127,3 @@ def read_multi_agent_json(self) -> dict[str, Any]: Returns: Multi-agent state dictionary or empty dict if not found """ - raise NotImplementedError From b00f58d7ab74cbeb8ff6e68a94157127c775ade9 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 6 Oct 2025 16:34:58 -0400 Subject: [PATCH 09/27] feat: add restricted type check to serialization/deserialization function to multiagent session persistence. 
--- src/strands/multiagent/base.py | 100 ++++++++++++++++++-------- src/strands/multiagent/graph.py | 60 ++++++++-------- src/strands/multiagent/swarm.py | 3 +- tests/strands/multiagent/test_base.py | 22 ++---- 4 files changed, 108 insertions(+), 77 deletions(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 4c83d2e76..1869c61d2 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -9,10 +9,11 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Any, Union +from typing import Any, Literal, Union, cast from ..agent import AgentResult -from ..types.content import ContentBlock +from ..telemetry.metrics import EventLoopMetrics +from ..types.content import ContentBlock, Message from ..types.event_loop import Metrics, Usage logger = logging.getLogger(__name__) @@ -52,48 +53,83 @@ def get_agent_results(self) -> list[AgentResult]: """Get all AgentResult objects from this node, flattened if nested.""" if isinstance(self.result, Exception): return [] # No agent results for exceptions - if isinstance(self.result, AgentResult): + elif isinstance(self.result, AgentResult): return [self.result] - if getattr(self.result, "__class__", None) and self.result.__class__.__name__ == "AgentResult": - return [self.result] # type: ignore[list-item] - # If this is a nested MultiAgentResult, flatten children - if hasattr(self.result, "results") and isinstance(self.result.results, dict): - flattened: list[AgentResult] = [] - for nested in self.result.results.values(): - if isinstance(nested, NodeResult): - flattened.extend(nested.get_agent_results()) + else: + # Flatten nested results from MultiAgentResult + flattened = [] + for nested_node_result in self.result.results.values(): + flattened.extend(nested_node_result.get_agent_results()) return flattened - return [] - def to_dict(self) -> dict[str, Any]: """Convert NodeResult to 
JSON-serializable dict, ignoring state field.""" - result_data: Any = None if isinstance(self.result, Exception): - result_data = {"type": "exception", "message": str(self.result)} + result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} elif isinstance(self.result, AgentResult): # Serialize AgentResult without state field result_data = { "type": "agent_result", "stop_reason": self.result.stop_reason, - "message": self.result.message, # Message type is JSON serializable - # Skip metrics and state - not JSON serializable + "message": self.result.message, } - elif hasattr(self.result, "to_dict"): + elif isinstance(self.result, MultiAgentResult): result_data = self.result.to_dict() else: - result_data = str(self.result) + raise TypeError(f"Unsupported NodeResult.result type for serialization: {type(self.result).__name__}") return { "result": result_data, "execution_time": self.execution_time, "status": self.status.value, - "accumulated_usage": dict(self.accumulated_usage), - "accumulated_metrics": dict(self.accumulated_metrics), - "execution_count": self.execution_count, } + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeResult": + """Rehydrate a NodeResult from persisted JSON.""" + if "result" not in data: + raise TypeError("NodeResult.from_dict: missing 'result'") + raw = data["result"] + + result: Union[AgentResult, "MultiAgentResult", Exception] + if isinstance(raw, dict) and raw.get("type") == "agent_result": + result = _agent_result_from_persisted(raw) + elif isinstance(raw, dict) and raw.get("type") == "exception": + result = Exception(str(raw.get("message", "node failed"))) + elif isinstance(raw, dict) and ("results" in raw): + result = MultiAgentResult.from_dict(raw) + else: + raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + + return cls( + result=result, + execution_time=int(data.get("execution_time", 0)), + status=Status(data.get("status", Status.PENDING.value)), + ) + + +def 
_agent_result_from_persisted(data: dict[str, Any]) -> AgentResult: + """Rehydrate a minimal AgentResult from persisted JSON. + + Expected shape: + {"type": "agent_result", "message": , "stop_reason": } + """ + if data.get("type") != "agent_result": + raise TypeError(f"_agent_result_from_persisted: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast( + Literal["content_filtered", "end_turn", "guardrail_intervened", "max_tokens", "stop_sequence", "tool_use"], + data.get("stop_reason"), + ) + + try: + return AgentResult(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + except Exception: + logger.debug("AgentResult constructor failed during rehydrating") + raise + @dataclass class MultiAgentResult: @@ -122,6 +158,15 @@ def to_dict(self) -> dict[str, Any]: "execution_time": self.execution_time, } + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": + """Rehydrate a MultiAgentResult from persisted JSON.""" + multiagent_result = cls( + status=Status(data.get("status", Status.PENDING.value)), + ) + multiagent_result.results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + return multiagent_result + class MultiAgentBase(ABC): """Base class for multi-agent helpers. 
@@ -184,11 +229,6 @@ def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: Returns: JSON-serializable dict representation """ - if isinstance(raw, dict): - return raw - - if hasattr(raw, "to_dict") and callable(raw.to_dict): - return raw.to_dict() - - # Fallback for strings and other types - return {"agent_outputs": [str(raw)]} + if not isinstance(raw, NodeResult): + raise TypeError(f"serialize_node_result_for_persist expects NodeResult, got {type(raw).__name__}") + return raw.to_dict() diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index a56f3d966..cc4e359f1 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -360,19 +360,6 @@ def _validate_graph(self) -> None: logger.warning("Graph without execution limits may run indefinitely if cycles exist") -def _iterate_previous_outputs(raw: NodeResult | dict[str, Any]) -> list[tuple[str, str]]: - """Return a list of (agent_name, text) from NodeResult or persisted dict.""" - # Live NodeResult - if hasattr(raw, "get_agent_results") and callable(raw.get_agent_results): - return [(getattr(r, "agent_name", "Agent"), str(r)) for r in raw.get_agent_results()] - - if isinstance(raw, dict) and "agent_outputs" in raw: - return [("Agent", str(x)) for x in raw["agent_outputs"]] - - # Fallback - return [("Agent", str(raw))] - - class Graph(MultiAgentBase): """Directed Graph multi-agent orchestration with configurable revisit behavior.""" @@ -580,8 +567,11 @@ def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["G return newly_ready def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: - """Check if a node is ready considering conditional edges.""" - # Get incoming edges to this node + """Check if a node is ready considering conditional edges. 
+ + A node is ready iff ALL incoming deps are completed AND their edge conditions pass, + and at least one of those deps was completed in THIS batch (so it's newly ready). + """ incoming_edges = [edge for edge in self.edges if edge.to_node == node] # Check if at least one incoming edge condition is satisfied @@ -778,8 +768,9 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: for dep_id, prev_result in dependency_results.items(): node_input.append(ContentBlock(text=f"\nFrom {dep_id}:")) - for agent_name, result_text in _iterate_previous_outputs(prev_result): - node_input.append(ContentBlock(text=f" - {agent_name}: {result_text}")) + for agent_result in prev_result.get_agent_results(): + agent_name = getattr(agent_result, "agent_name", "Agent") + node_input.append(ContentBlock(text=f" - {agent_name}: {agent_result}")) return node_input def _build_result(self) -> GraphResult: @@ -803,16 +794,25 @@ def _build_result(self) -> GraphResult: def _from_dict(self, payload: dict[str, Any]) -> None: status_raw = payload.get("status", "pending") - try: - self.state.status = Status(status_raw) - except Exception: - self.state.status = Status.PENDING + self.state.status = Status(status_raw) # Hydrate completed nodes & results + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results + + # Restore completed nodes from persisted data completed_node_ids = payload.get("completed_nodes") or [] self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes} - self.state.results = dict(payload.get("node_results") or {}) # Execution order (only nodes that still exist) order_node_ids = 
payload.get("execution_order") or [] self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes] @@ -860,7 +860,11 @@ def _load_and_apply_persisted_state(self) -> None: continue # only include if it’s dependency-ready incoming = [edge for edge in self.edges if edge.to_node == node] - if any(edge.from_node in completed and edge.should_traverse(self.state) for edge in incoming): + if not incoming: + valid_ready.append(node) + continue + + if all(e.from_node in completed and e.should_traverse(self.state) for e in incoming): valid_ready.append(node) if not valid_ready: @@ -884,18 +888,16 @@ def _map_node_ids(self, node_ids: list[str] | None) -> list[GraphNode]: def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: ready_nodes: list[GraphNode] = [] completed_nodes = set(self.state.completed_nodes) - for node in self.nodes.values(): if node in completed_nodes: continue incoming = [e for e in self.edges if e.to_node is node] - if any(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): + if not incoming: + ready_nodes.append(node) + elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): ready_nodes.append(node) - if ready_nodes: - return ready_nodes - - return [node for node in self.entry_points if node not in completed_nodes] + return ready_nodes def serialize_state(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state. 
diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index c697d343c..40da4417a 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -498,6 +498,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st self.session_manager.write_multi_agent_json(self.serialize_state()) except Exception as e: logger.warning("Failed to persist swarm state after handoff: %s", e) + raise def _build_node_input(self, target_node: SwarmNode) -> str: """Build input text for a node based on shared context and handoffs. @@ -768,7 +769,7 @@ def _load_and_apply_persisted_state(self) -> None: saved = self.session_manager.read_multi_agent_json() except Exception as e: logger.warning("Skipping resume; failed to load state: %s", e) - return + raise if not saved: return diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 19c886c1b..3e6bd48e1 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -89,16 +89,6 @@ def test_node_result_get_agent_results(agent_result): assert "Response 1" in response_texts assert "Response 2" in response_texts - # Test with object that has AgentResult class name but isn't AgentResult - from unittest.mock import Mock - - mock_result = Mock() - mock_result.__class__.__name__ = "AgentResult" - node_result = NodeResult(result=mock_result) - agent_results = node_result.get_agent_results() - assert len(agent_results) == 1 - assert agent_results[0] == mock_result - def test_multi_agent_result_initialization(agent_result): """Test MultiAgentResult initialization with defaults and custom values.""" @@ -251,11 +241,9 @@ def test_serialize_node_result_for_persist(agent_result): assert "execution_time" in serialized assert "status" in serialized - # Test with already normalized dict - normalized = {"agent_outputs": ["test1", "test2"]} - serialized = MultiAgentBase.serialize_node_result_for_persist(agent, 
normalized) - assert serialized == {"agent_outputs": ["test1", "test2"]} + # Test with invalid input type should raise TypeError + with pytest.raises(TypeError, match="serialize_node_result_for_persist expects NodeResult"): + MultiAgentBase.serialize_node_result_for_persist(agent, {"agent_outputs": ["test1", "test2"]}) - # Test fallback case - serialized = MultiAgentBase.serialize_node_result_for_persist(agent, "simple string") - assert serialized == {"agent_outputs": ["simple string"]} + with pytest.raises(TypeError, match="serialize_node_result_for_persist expects NodeResult"): + MultiAgentBase.serialize_node_result_for_persist(agent, "simple string") From d67c8484ed4cc17de98258685e1bf00b16a6f950 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 6 Oct 2025 19:51:59 -0400 Subject: [PATCH 10/27] fix: remove persistence_hooks, use session_manager to subscribe multiagent_hook events, adjust node_result serialization --- .../experimental/multiagent_hooks/__init__.py | 6 +- .../multiagent_hooks/multiagent_events.py | 28 ++++-- src/strands/multiagent/base.py | 11 ++- src/strands/multiagent/graph.py | 86 ++++++++-------- src/strands/multiagent/swarm.py | 99 +++++++++++-------- src/strands/session/session_manager.py | 55 ++++++++--- tests/fixtures/mock_session_repository.py | 13 --- tests/strands/agent/test_agent.py | 23 +++-- .../test_multiagent_events.py | 30 +++--- tests/strands/multiagent/test_base.py | 6 ++ tests/strands/multiagent/test_graph.py | 2 +- tests/strands/multiagent/test_swarm.py | 2 +- .../test_repository_session_manager.py | 9 +- 13 files changed, 210 insertions(+), 160 deletions(-) diff --git a/src/strands/experimental/multiagent_hooks/__init__.py b/src/strands/experimental/multiagent_hooks/__init__.py index a8825f02e..ac8c432a8 100644 --- a/src/strands/experimental/multiagent_hooks/__init__.py +++ b/src/strands/experimental/multiagent_hooks/__init__.py @@ -7,11 +7,13 @@ from .multiagent_events import ( AfterMultiAgentInvocationEvent, 
AfterNodeInvocationEvent, - MultiAgentInitializationEvent, + BeforeNodeInvocationEvent, + MultiagentInitializedEvent, ) __all__ = [ "AfterMultiAgentInvocationEvent", - "MultiAgentInitializationEvent", + "MultiagentInitializedEvent", "AfterNodeInvocationEvent", + "BeforeNodeInvocationEvent", ] diff --git a/src/strands/experimental/multiagent_hooks/multiagent_events.py b/src/strands/experimental/multiagent_hooks/multiagent_events.py index 931998cde..1e6496103 100644 --- a/src/strands/experimental/multiagent_hooks/multiagent_events.py +++ b/src/strands/experimental/multiagent_hooks/multiagent_events.py @@ -15,41 +15,53 @@ @dataclass -class MultiAgentInitializationEvent(BaseHookEvent): - """Event triggered when multi-agent orchestrator initializes. +class MultiagentInitializedEvent(BaseHookEvent): + """Event triggered when multi-agent orchestrator initialized. Attributes: - orchestrator: The multi-agent orchestrator instance + source: The multi-agent orchestrator instance invocation_state: Configuration that user pass in """ - orchestrator: "MultiAgentBase" + source: "MultiAgentBase" invocation_state: dict[str, Any] | None = None +@dataclass +class BeforeNodeInvocationEvent(BaseHookEvent): + """Event triggered before individual node execution completes.""" + + pass + + @dataclass class AfterNodeInvocationEvent(BaseHookEvent): """Event triggered after individual node execution completes. 
Attributes: - orchestrator: The multi-agent orchestrator instance + source: The multi-agent orchestrator instance executed_node: ID of the node that just completed execution invocation_state: Configuration that user pass in """ - orchestrator: "MultiAgentBase" + source: "MultiAgentBase" executed_node: str invocation_state: dict[str, Any] | None = None + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + @dataclass class AfterMultiAgentInvocationEvent(BaseHookEvent): """Event triggered after orchestrator execution completes. Attributes: - orchestrator: The multi-agent orchestrator instance + source: The multi-agent orchestrator instance invocation_state: Configuration that user pass in """ - orchestrator: "MultiAgentBase" + source: "MultiAgentBase" invocation_state: dict[str, Any] | None = None diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 1869c61d2..0c46cfc10 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -105,7 +105,7 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": return cls( result=result, execution_time=int(data.get("execution_time", 0)), - status=Status(data.get("status", Status.PENDING.value)), + status=Status(data.get("status", "pending")), ) @@ -232,3 +232,12 @@ def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: if not isinstance(raw, NodeResult): raise TypeError(f"serialize_node_result_for_persist expects NodeResult, got {type(raw).__name__}") return raw.to_dict() + + @abstractmethod + def attempt_resume(self, payload: dict[str, Any]) -> None: + """Attempt to resume orchestrator state from a session payload. 
+ + Args: + payload: Session data to restore orchestrator state from + """ + pass diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index cc4e359f1..cc0a7a900 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -29,7 +29,7 @@ from ..experimental.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - MultiAgentInitializationEvent, + MultiagentInitializedEvent, ) from ..hooks import HookRegistry from ..session import SessionManager @@ -413,9 +413,7 @@ def __init__( self._resume_from_persisted = False self._resume_next_nodes: list[GraphNode] = [] - self._load_and_apply_persisted_state() - if not self._resume_from_persisted and self.state.status == Status.PENDING: - self.hooks.invoke_callbacks(MultiAgentInitializationEvent(orchestrator=self), supress_exceptions=True) + self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self), supress_exceptions=True) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -499,7 +497,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self), supress_exceptions=True) self._resume_from_persisted = False self._resume_next_nodes.clear() return self._build_result() @@ -610,7 +608,7 @@ async def _record_failure(exception: Exception) -> None: self.state.results[node.node_id] = fail_result self.hooks.invoke_callbacks( - AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id), + AfterNodeInvocationEvent(source=self, executed_node=node.node_id), supress_exceptions=True, ) @@ -691,7 +689,7 @@ async def _record_failure(exception: Exception) -> None: "node_id=<%s>, execution_time=<%dms> | node completed successfully", 
node.node_id, node.execution_time ) self.hooks.invoke_callbacks( - AfterNodeInvocationEvent(orchestrator=self, executed_node=node.node_id), supress_exceptions=True + AfterNodeInvocationEvent(source=self, executed_node=node.node_id), supress_exceptions=True ) except asyncio.TimeoutError: @@ -835,46 +833,6 @@ def _to_dict(self) -> dict[str, Any]: "execution_order": [n.node_id for n in self.state.execution_order], } - def _load_and_apply_persisted_state(self) -> None: - if self.session_manager is None: - return - try: - json_data = self.session_manager.read_multi_agent_json() - except Exception as e: - logger.exception("Resume failed; failed to load state: %s", e) - raise - if not json_data: - return - - try: - self.deserialize_state(json_data) - self._resume_from_persisted = True - - next_node_ids = json_data.get("next_node_to_execute") or [] - mapped = self._map_node_ids(next_node_ids) - valid_ready: list[GraphNode] = [] - completed = set(self.state.completed_nodes) - - for node in mapped: - if node in completed or node.execution_status == Status.COMPLETED: - continue - # only include if it’s dependency-ready - incoming = [edge for edge in self.edges if edge.to_node == node] - if not incoming: - valid_ready.append(node) - continue - - if all(e.from_node in completed and e.should_traverse(self.state) for e in incoming): - valid_ready.append(node) - - if not valid_ready: - valid_ready = self._compute_ready_nodes_for_resume() - - self._resume_next_nodes = sorted(valid_ready, key=lambda node: node.node_id) - logger.debug("Resumed from persisted state. 
Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) - except Exception as e: - logger.exception("Failed to apply multiagent state : %s", e) - def _map_node_ids(self, node_ids: list[str] | None) -> list[GraphNode]: if not node_ids: return [] @@ -914,3 +872,37 @@ def deserialize_state(self, payload: dict) -> None: payload: Dictionary containing persisted state data """ self._from_dict(payload) + + def attempt_resume(self, payload: dict[str, Any]) -> None: + """Apply a persisted graph payload and prepare resume execution.""" + try: + self.deserialize_state(payload) + self._resume_from_persisted = True + + next_node_ids = payload.get("next_node_to_execute") or [] + mapped = self._map_node_ids(next_node_ids) + valid_ready: list[GraphNode] = [] + completed = set(self.state.completed_nodes) + + for node in mapped: + if node in completed or node.execution_status == Status.COMPLETED: + continue + # only include if it’s dependency-ready + incoming = [edge for edge in self.edges if edge.to_node == node] + if not incoming: + valid_ready.append(node) + continue + + if all(e.from_node in completed and e.should_traverse(self.state) for e in incoming): + valid_ready.append(node) + + if not valid_ready: + valid_ready = self._compute_ready_nodes_for_resume() + + self._resume_next_nodes = sorted(valid_ready, key=lambda nodes: nodes.node_id) + logger.debug("Resumed from persisted state. 
Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) + except Exception: + logger.exception("Failed to apply resume payload") + self._resume_from_persisted = False + self._resume_next_nodes.clear() + raise diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 40da4417a..1fca06317 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -28,7 +28,7 @@ from ..experimental.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - MultiAgentInitializationEvent, + MultiagentInitializedEvent, ) from ..hooks import HookRegistry from ..session import SessionManager @@ -260,9 +260,7 @@ def __init__( self._resume_from_persisted = False self._resume_from_completed = False - self._load_and_apply_persisted_state() - if not self._resume_from_persisted and self.state.completion_status == Status.PENDING: - self.hooks.invoke_callbacks(MultiAgentInitializationEvent(orchestrator=self), supress_exceptions=True) + self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self), supress_exceptions=True) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -341,7 +339,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self), supress_exceptions=True) self._resume_from_persisted = False self._resume_from_completed = False @@ -651,7 +649,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: self.state.completion_status = Status.FAILED elapsed_time = time.time() - self.state.start_time - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(orchestrator=self), supress_exceptions=True) logger.debug("status=<%s> | swarm execution completed", 
self.state.completion_status) logger.debug( "node_history_length=<%d>, time=<%s>s | metrics", @@ -762,38 +759,25 @@ def _initial_node(self) -> SwarmNode: return next(iter(self.nodes.values())) # First SwarmNode - def _load_and_apply_persisted_state(self) -> None: - if self.session_manager is None: - return - try: - saved = self.session_manager.read_multi_agent_json() - except Exception as e: - logger.warning("Skipping resume; failed to load state: %s", e) - raise - if not saved: - return - - try: - # saved is a dict - status = saved.get("status") - self.deserialize_state(saved) - - self._resume_from_persisted = True - if status in (Status.COMPLETED.value, "completed"): - self._resume_from_completed = True - # Create a placeholder node to avoid None assignment - placeholder_node = SwarmNode("placeholder", Agent()) - self.state.current_node = placeholder_node - logger.debug("Saved state is COMPLETED; will return persisted result without re-running.") - else: - logger.debug( - "Resumed from persisted state. Current node: %s", - self.state.current_node.node_id if self.state.current_node else "None", - ) - except Exception as e: - logger.exception("Failed to hydrate swarm from persisted state: %s", e) + def attempt_resume(self, payload: dict[str, Any]) -> None: + """Apply persisted state and set resume flags.""" + status = payload.get("status") + self.deserialize_state(payload) + self._resume_from_persisted = True + self._resume_from_completed = status in (Status.COMPLETED.value, "completed") + + if self._resume_from_completed: + # Avoid running again; placeholder so current_node is non-null + placeholder = SwarmNode("placeholder", Agent()) + self.state.current_node = placeholder + logger.debug("Saved state is COMPLETED; will return persisted result without re-running.") + else: + logger.debug( + "Resumed from persisted state. 
Current node: %s", + self.state.current_node.node_id if self.state.current_node else "None", + ) - def serialize_state(self) -> dict: + def _to_dict(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state. Returns: @@ -805,13 +789,18 @@ def serialize_state(self) -> dict: if self.state.completion_status == Status.EXECUTING and self.state.current_node else [] ) + # Handoff: 'dict' object is not callable, that's why we need this. + normalized_results: dict[str, NodeResult] = {} + for node_id, entry in (self.state.results or {}).items(): + if isinstance(entry, NodeResult): + normalized_results[node_id] = entry + else: + raise TypeError(f"Unexpected node_result type for {node_id}: {type(entry).__name__}") return { "type": "swarm", "status": status_str, "completed_nodes": [n.node_id for n in self.state.node_history], - "node_results": { - k: self.serialize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() - }, + "node_results": {k: self.serialize_node_result_for_persist(v) for k, v in normalized_results.items()}, "next_node_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.node_history], @@ -821,7 +810,7 @@ def serialize_state(self) -> dict: }, } - def deserialize_state(self, payload: dict) -> None: + def _from_dict(self, payload: dict) -> None: """Restore orchestrator state from a session dict. 
Args: @@ -842,7 +831,17 @@ def deserialize_state(self, payload: dict) -> None: self.state.node_history = [ self.nodes[nid] for nid in (payload.get("completed_nodes") or []) if nid in self.nodes ] - self.state.results = dict(payload.get("node_results") or {}) + raw_results = payload.get("node_results") or {} + results: dict[str, NodeResult] = {} + for node_id, entry in raw_results.items(): + if node_id not in self.nodes: + continue + try: + results[node_id] = NodeResult.from_dict(entry) + except Exception: + logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id) + raise + self.state.results = results self.state.task = payload.get("current_task", self.state.task) # Determine current node (if executing) @@ -862,3 +861,19 @@ def deserialize_state(self, payload: dict) -> None: except Exception as e: logger.exception("Failed to apply persisted swarm state: %s", e) + + def serialize_state(self) -> dict: + """Return a JSON-serializable snapshot of the orchestrator state. + + Returns: + Dictionary containing the current graph state + """ + return self._to_dict() + + def deserialize_state(self, payload: dict) -> None: + """Restore orchestrator state from a session dict. 
+ + Args: + payload: Dictionary containing persisted state data + """ + self._from_dict(payload) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 03e96d0db..6b1ca0021 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,5 +1,6 @@ """Session manager interface for agent session management.""" +import logging import threading from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any @@ -7,7 +8,7 @@ from ..experimental.multiagent_hooks.multiagent_events import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - MultiAgentInitializationEvent, + MultiagentInitializedEvent, ) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry @@ -18,6 +19,8 @@ from ..agent.agent import Agent from ..multiagent.base import MultiAgentBase +logger = logging.getLogger(__name__) + class SessionManager(HookProvider, ABC): """Abstract interface for managing sessions. 
@@ -53,14 +56,10 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) elif self.session_type == SessionType.MULTI_AGENT: + registry.add_callback(MultiagentInitializedEvent, self._on_multiagent_initialized) + registry.add_callback(AfterNodeInvocationEvent, lambda event: self._persist_multi_agent_state(event.source)) registry.add_callback( - MultiAgentInitializationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) - ) - registry.add_callback( - AfterNodeInvocationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) - ) - registry.add_callback( - AfterMultiAgentInvocationEvent, lambda event: self._persist_multi_agent_state(event.orchestrator) + AfterMultiAgentInvocationEvent, lambda event: self._persist_multi_agent_state(event.source) ) @abstractmethod @@ -101,29 +100,57 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: **kwargs: Additional keyword arguments for future extensibility. """ - def _persist_multi_agent_state(self, orchestrator: "MultiAgentBase") -> None: + def _persist_multi_agent_state(self, source: "MultiAgentBase") -> None: """Thread-safe persistence of multi-agent state. Args: - orchestrator: Multi-agent orchestrator to persist + source: Multi-agent orchestrator to persist """ with self._lock: - state = orchestrator.serialize_state() + state = source.serialize_state() self.write_multi_agent_json(state) - # Multiagent abstract functions - @abstractmethod def write_multi_agent_json(self, state: dict[str, Any]) -> None: """Write multi-agent state to persistent storage. Args: state: Multi-agent state dictionary to persist """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(write_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." 
+ ) - @abstractmethod def read_multi_agent_json(self) -> dict[str, Any]: """Read multi-agent state from persistent storage. Returns: Multi-agent state dictionary or empty dict if not found """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-agent persistence " + "(read_multi_agent_json). Provide an implementation or use a " + "SessionManager with session_type=SessionType.MULTI_AGENT." + ) + + def _on_multiagent_initialized(self, event: MultiagentInitializedEvent) -> None: + """Initialization path: attempt to resume and then persist a fresh snapshot.""" + source: MultiAgentBase = event.source + try: + payload = self.read_multi_agent_json() + except NotImplementedError: + logger.debug("Multi-agent persistence not implemented; starting fresh") + return + # payload can be {} or Graph/Swarm state json + if payload: + try: + source.attempt_resume(payload) + except Exception: + logger.exception("Failed to apply resume payload; starting fresh") + raise + else: + try: + self._persist_multi_agent_state(source) + except NotImplementedError: + pass diff --git a/tests/fixtures/mock_session_repository.py b/tests/fixtures/mock_session_repository.py index 4d4c93c8e..f3923f68b 100644 --- a/tests/fixtures/mock_session_repository.py +++ b/tests/fixtures/mock_session_repository.py @@ -1,4 +1,3 @@ -from strands.session.repository_session_manager import RepositorySessionManager from strands.session.session_repository import SessionRepository from strands.types.exceptions import SessionException from strands.types.session import SessionAgent, SessionMessage @@ -96,15 +95,3 @@ def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[Sess if limit is not None: return sorted_messages[offset : offset + limit] return sorted_messages[offset:] - - -class TestRepositorySessionManager(RepositorySessionManager): - """Test implementation of RepositorySessionManager with concrete multi-agent methods.""" - - def write_multi_agent_json(self, state: 
dict) -> None: - """Write multi-agent state (no-op for testing).""" - pass - - def read_multi_agent_json(self) -> dict: - """Read multi-agent state (returns empty dict for testing).""" - return {} diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0e6d6dbcd..2cd87c26d 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -17,12 +17,13 @@ from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel +from strands.session.repository_session_manager import RepositorySessionManager from strands.telemetry.tracer import serialize from strands.types._events import EventLoopStopEvent, ModelStreamEvent from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException, EventLoopException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType -from tests.fixtures.mock_session_repository import MockedSessionRepository, TestRepositorySessionManager +from tests.fixtures.mock_session_repository import MockedSessionRepository from tests.fixtures.mocked_model_provider import MockedModelProvider # For unit testing we will use the the us inference @@ -1528,7 +1529,7 @@ def test_agent_state_get_breaks_deep_dict_reference(): def test_agent_session_management(): mock_session_repository = MockedSessionRepository() - session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) agent = Agent(session_manager=session_manager, model=model) agent("Hello!") @@ -1545,7 +1546,7 @@ def test_agent_restored_from_session_management(): 
conversation_manager_state=SlidingWindowConversationManager().get_state(), ), ) - session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager) @@ -1566,7 +1567,7 @@ def test_agent_restored_from_session_management_with_message(): mock_session_repository.create_message( "123", "default", SessionMessage({"role": "user", "content": [{"text": "Hello!"}]}, 0) ) - session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager) @@ -1597,9 +1598,7 @@ def test_agent_restored_from_session_management_with_redacted_input(): test_session_id = str(uuid4()) mocked_session_repository = MockedSessionRepository() - session_manager = TestRepositorySessionManager( - session_id=test_session_id, session_repository=mocked_session_repository - ) + session_manager = RepositorySessionManager(session_id=test_session_id, session_repository=mocked_session_repository) agent = Agent( model=mocked_model, @@ -1619,7 +1618,7 @@ def test_agent_restored_from_session_management_with_redacted_input(): assert user_input_session_message.to_message() == agent.messages[0] # Restore an agent from the session, confirm input is still redacted - session_manager_2 = TestRepositorySessionManager( + session_manager_2 = RepositorySessionManager( session_id=test_session_id, session_repository=mocked_session_repository ) agent_2 = Agent( @@ -1638,13 +1637,13 @@ def test_agent_restored_from_session_management_with_correct_index(): [{"role": "assistant", "content": [{"text": "hello!"}]}, {"role": "assistant", "content": [{"text": "world!"}]}] ) mock_session_repository = MockedSessionRepository() - session_manager = 
TestRepositorySessionManager(session_id="test", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="test", session_repository=mock_session_repository) agent = Agent(session_manager=session_manager, model=mock_model_provider) agent("Hello!") assert len(mock_session_repository.list_messages("test", agent.agent_id)) == 2 - session_manager_2 = TestRepositorySessionManager(session_id="test", session_repository=mock_session_repository) + session_manager_2 = RepositorySessionManager(session_id="test", session_repository=mock_session_repository) agent_2 = Agent(session_manager=session_manager_2, model=mock_model_provider) assert len(agent_2.messages) == 2 @@ -1662,7 +1661,7 @@ def test_agent_restored_from_session_management_with_correct_index(): def test_agent_with_session_and_conversation_manager(): mock_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello!"}]}]) mock_session_repository = MockedSessionRepository() - session_manager = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) conversation_manager = SlidingWindowConversationManager(window_size=1) # Create an agent with a mocked model and session repository agent = Agent( @@ -1684,7 +1683,7 @@ def test_agent_with_session_and_conversation_manager(): assert len(agent.messages) == 1 # Initialize another agent using the same session - session_manager_2 = TestRepositorySessionManager(session_id="123", session_repository=mock_session_repository) + session_manager_2 = RepositorySessionManager(session_id="123", session_repository=mock_session_repository) conversation_manager_2 = SlidingWindowConversationManager(window_size=1) agent_2 = Agent( session_manager=session_manager_2, diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py 
b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py index 4ecc92d4d..be5e1d5dd 100644 --- a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py +++ b/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py @@ -7,7 +7,7 @@ from strands.experimental.multiagent_hooks.multiagent_events import ( AfterMultiAgentInvocationEvent, AfterNodeInvocationEvent, - MultiAgentInitializationEvent, + MultiagentInitializedEvent, ) from strands.hooks.registry import BaseHookEvent @@ -19,29 +19,29 @@ def orchestrator(): def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): - """Test MultiAgentInitializationEvent creation with orchestrator only.""" - event = MultiAgentInitializationEvent(orchestrator=orchestrator) + """Test MultiagentInitializedEvent creation with orchestrator only.""" + event = MultiagentInitializedEvent(source=orchestrator) - assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.invocation_state is None assert isinstance(event, BaseHookEvent) def test_multi_agent_initialization_event_with_invocation_state(orchestrator): - """Test MultiAgentInitializationEvent creation with invocation state.""" + """Test MultiagentInitializedEvent creation with invocation state.""" invocation_state = {"key": "value"} - event = MultiAgentInitializationEvent(orchestrator=orchestrator, invocation_state=invocation_state) + event = MultiagentInitializedEvent(source=orchestrator, invocation_state=invocation_state) - assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.invocation_state == invocation_state def test_after_node_invocation_event_with_required_fields(orchestrator): """Test AfterNodeInvocationEvent creation with required fields.""" executed_node = "node_1" - event = AfterNodeInvocationEvent(orchestrator=orchestrator, executed_node=executed_node) + event = AfterNodeInvocationEvent(source=orchestrator, executed_node=executed_node) 
- assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.executed_node == executed_node assert event.invocation_state is None assert isinstance(event, BaseHookEvent) @@ -52,19 +52,19 @@ def test_after_node_invocation_event_with_invocation_state(orchestrator): executed_node = "node_2" invocation_state = {"result": "success"} event = AfterNodeInvocationEvent( - orchestrator=orchestrator, executed_node=executed_node, invocation_state=invocation_state + source=orchestrator, executed_node=executed_node, invocation_state=invocation_state ) - assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.executed_node == executed_node assert event.invocation_state == invocation_state def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator): """Test AfterMultiAgentInvocationEvent creation with orchestrator only.""" - event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator) + event = AfterMultiAgentInvocationEvent(source=orchestrator) - assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.invocation_state is None assert isinstance(event, BaseHookEvent) @@ -72,7 +72,7 @@ def test_after_multi_agent_invocation_event_with_orchestrator_only(orchestrator) def test_after_multi_agent_invocation_event_with_invocation_state(orchestrator): """Test AfterMultiAgentInvocationEvent creation with invocation state.""" invocation_state = {"final_state": "completed"} - event = AfterMultiAgentInvocationEvent(orchestrator=orchestrator, invocation_state=invocation_state) + event = AfterMultiAgentInvocationEvent(source=orchestrator, invocation_state=invocation_state) - assert event.orchestrator is orchestrator + assert event.source is orchestrator assert event.invocation_state == invocation_state diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index 3e6bd48e1..18f8860a8 100644 --- 
a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -152,6 +152,9 @@ def serialize_state(self) -> dict: def deserialize_state(self, payload: dict) -> None: pass + def attempt_resume(self, payload: dict) -> None: + pass + # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) @@ -180,6 +183,9 @@ def serialize_state(self) -> dict: def deserialize_state(self, payload: dict) -> None: pass + def attempt_resume(self, payload: dict) -> None: + pass + agent = TestMultiAgent() # Test with string task diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index fedcf468f..6fc334b7a 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1357,7 +1357,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "next_node_to_execute": ["test_node"], } - graph.deserialize_state(persisted_state) + graph._from_dict(persisted_state) assert graph.state.task == "persisted task" # Execute graph to test persistence integration diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 6fa5edebd..059aff4fb 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -606,7 +606,7 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, } - swarm.deserialize_state(persisted_state) + swarm._from_dict(persisted_state) assert swarm.state.task == "persisted task" assert swarm.state.handoff_message == "test handoff" assert swarm.shared_context.context["test_agent"]["key"] == "value" diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index 77d2e4756..2c25fcc38 100644 --- 
a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -5,10 +5,11 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager +from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException from strands.types.session import Session, SessionAgent, SessionMessage, SessionType -from tests.fixtures.mock_session_repository import MockedSessionRepository, TestRepositorySessionManager +from tests.fixtures.mock_session_repository import MockedSessionRepository @pytest.fixture @@ -20,7 +21,7 @@ def mock_repository(): @pytest.fixture def session_manager(mock_repository): """Create a session manager with mock repository.""" - return TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) + return RepositorySessionManager(session_id="test-session", session_repository=mock_repository) @pytest.fixture @@ -35,7 +36,7 @@ def test_init_creates_session_if_not_exists(mock_repository): assert mock_repository.read_session("test-session") is None # Creating manager should create session - TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) + RepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session created session = mock_repository.read_session("test-session") @@ -51,7 +52,7 @@ def test_init_uses_existing_session(mock_repository): mock_repository.create_session(session) # Creating manager should use existing session - manager = TestRepositorySessionManager(session_id="test-session", session_repository=mock_repository) + manager = 
RepositorySessionManager(session_id="test-session", session_repository=mock_repository) # Verify session used assert manager.session == session From e6b4af2c2f8fd4bdf8cbfe14967dd70832145826 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 7 Oct 2025 10:33:38 -0400 Subject: [PATCH 11/27] fix: remove optional from invoke_callbacks --- src/strands/hooks/registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 9d52a63be..a62d7846c 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,7 +9,7 @@ import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Optional, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar if TYPE_CHECKING: from ..agent import Agent @@ -187,7 +187,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: Optional[bool] = False) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: bool = False) -> TInvokeEvent: """Invoke all registered callbacks for the given event. 
This method finds all callbacks registered for the event's type and From 1fcfdc062d90896f64cd1f6478a7afe148ca0986 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 7 Oct 2025 10:55:25 -0400 Subject: [PATCH 12/27] fix: fix from_dict consistency --- src/strands/multiagent/base.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 0c46cfc10..8e7a8fffd 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -161,10 +161,30 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": """Rehydrate a MultiAgentResult from persisted JSON.""" + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} + usage_data = data.get("accumulated_usage", {}) + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + + # Create Metrics with required field + metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) + multiagent_result = cls( status=Status(data.get("status", Status.PENDING.value)), + results=results, + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), + execution_time=int(data.get("execution_time", 0)), ) - multiagent_result.results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} return multiagent_result From 3fe3978ff2a55e57d4ef68d96d1481c820b3b875 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 7 Oct 2025 11:07:22 -0400 Subject: [PATCH 13/27] fix: fix from_dic 
consistency --- src/strands/multiagent/base.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 8e7a8fffd..859cbfa81 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -83,6 +83,9 @@ def to_dict(self) -> dict[str, Any]: "result": result_data, "execution_time": self.execution_time, "status": self.status.value, + "accumulated_usage": self.accumulated_usage, + "accumulated_metrics": self.accumulated_metrics, + "execution_count": self.execution_count, } @classmethod @@ -102,10 +105,27 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": else: raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") + usage_data = data.get("accumulated_usage", {}) + usage = Usage( + inputTokens=usage_data.get("inputTokens", 0), + outputTokens=usage_data.get("outputTokens", 0), + totalTokens=usage_data.get("totalTokens", 0), + ) + # Add optional fields if they exist + if "cacheReadInputTokens" in usage_data: + usage["cacheReadInputTokens"] = usage_data["cacheReadInputTokens"] + if "cacheWriteInputTokens" in usage_data: + usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] + + metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) + return cls( result=result, execution_time=int(data.get("execution_time", 0)), status=Status(data.get("status", "pending")), + accumulated_usage=usage, + accumulated_metrics=metrics, + execution_count=int(data.get("execution_count", 0)), ) @@ -174,7 +194,6 @@ def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": if "cacheWriteInputTokens" in usage_data: usage["cacheWriteInputTokens"] = usage_data["cacheWriteInputTokens"] - # Create Metrics with required field metrics = Metrics(latencyMs=data.get("accumulated_metrics", {}).get("latencyMs", 0)) multiagent_result = cls( From 683a14fb141dda08815c781ad7b23265b865d8b8 Mon Sep 17 00:00:00 2001 
From: Jack Yuan Date: Tue, 7 Oct 2025 12:06:07 -0400 Subject: [PATCH 14/27] fix: fix file session creation issue --- src/strands/session/file_session_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index c2a668204..6affaa844 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -136,8 +136,8 @@ def create_session(self, session: Session, **kwargs: Any) -> Session: raise SessionException(f"Session {session.session_id} already exists") # Create directory structure + os.makedirs(session_dir, exist_ok=True) if self.session_type == SessionType.AGENT: - os.makedirs(session_dir, exist_ok=True) os.makedirs(os.path.join(session_dir, "agents"), exist_ok=True) # Write session file From 7411f2cd9a86d762e2d7e2fa90fda7e065b1b8f4 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 7 Oct 2025 14:44:13 -0400 Subject: [PATCH 15/27] fix: remove completed_nodes, rename execution_order to node_history in Swarm --- src/strands/multiagent/swarm.py | 7 +++---- tests/strands/multiagent/test_swarm.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 1fca06317..56a29958c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -799,11 +799,10 @@ def _to_dict(self) -> dict: return { "type": "swarm", "status": status_str, - "completed_nodes": [n.node_id for n in self.state.node_history], + "node_history": [n.node_id for n in self.state.node_history], "node_results": {k: self.serialize_node_result_for_persist(v) for k, v in normalized_results.items()}, "next_node_to_execute": next_nodes, "current_task": self.state.task, - "execution_order": [n.node_id for n in self.state.node_history], "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, "handoff_message": 
self.state.handoff_message, @@ -829,7 +828,7 @@ def _from_dict(self, payload: dict) -> None: # node history and results self.state.node_history = [ - self.nodes[nid] for nid in (payload.get("completed_nodes") or []) if nid in self.nodes + self.nodes[nid] for nid in (payload.get("node_history") or []) if nid in self.nodes ] raw_results = payload.get("node_results") or {} results: dict[str, NodeResult] = {} @@ -852,7 +851,7 @@ def _from_dict(self, payload: dict) -> None: self.state.current_node = found_node if found_node is not None else self._initial_node() else: # fallback to last executed or first node - last = (payload.get("execution_order") or [])[-1:] or [] + last = (payload.get("node_history") or [])[-1:] or [] if last: found_node = self.nodes.get(last[0]) self.state.current_node = found_node if found_node is not None else self._initial_node() diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index 059aff4fb..d2c898ee4 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -591,17 +591,16 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): state = swarm.serialize_state() assert state["type"] == "swarm" assert "status" in state - assert "completed_nodes" in state + assert "node_history" in state assert "node_results" in state assert "context" in state # Test apply_state_from_dict with persisted state persisted_state = { "status": "executing", - "completed_nodes": [], + "node_history": [], "node_results": {}, "current_task": "persisted task", - "execution_order": [], "next_node_to_execute": ["test_agent"], "context": {"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, } @@ -622,5 +621,5 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): # Test state serialization after execution final_state = swarm.serialize_state() assert final_state["status"] == "completed" - assert 
len(final_state["completed_nodes"]) == 1 + assert len(final_state["node_history"]) == 1 assert "test_agent" in final_state["node_results"] From e9c2d57b23041d60578e345e80b07376b0f98316 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 10 Oct 2025 16:24:32 -0400 Subject: [PATCH 16/27] fix: address comments, adding more tests and integration tests in next revision --- .../multiagent_hooks/multiagent_events.py | 18 ++++++++++++++++-- src/strands/hooks/registry.py | 13 ++----------- src/strands/multiagent/base.py | 5 +---- src/strands/multiagent/graph.py | 17 +++++------------ src/strands/multiagent/swarm.py | 12 ++++-------- src/strands/session/file_session_manager.py | 7 +------ 6 files changed, 29 insertions(+), 43 deletions(-) diff --git a/src/strands/experimental/multiagent_hooks/multiagent_events.py b/src/strands/experimental/multiagent_hooks/multiagent_events.py index 1e6496103..e000c7688 100644 --- a/src/strands/experimental/multiagent_hooks/multiagent_events.py +++ b/src/strands/experimental/multiagent_hooks/multiagent_events.py @@ -29,9 +29,10 @@ class MultiagentInitializedEvent(BaseHookEvent): @dataclass class BeforeNodeInvocationEvent(BaseHookEvent): - """Event triggered before individual node execution completes.""" + """Event triggered before individual node execution completes. This event corresponds to the After event.""" - pass + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None @dataclass @@ -54,6 +55,19 @@ def should_reverse_callbacks(self) -> bool: return True +@dataclass +class BeforeMultiAgentInvocationEvent(BaseHookEvent): + """Event triggered after orchestrator execution completes. This event corresponds to the After event. 
+ + Attributes: + source: The multi-agent orchestrator instance + invocation_state: Configuration that user pass in + """ + + source: "MultiAgentBase" + invocation_state: dict[str, Any] | None = None + + @dataclass class AfterMultiAgentInvocationEvent(BaseHookEvent): """Event triggered after orchestrator execution completes. diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index a62d7846c..9181f8f38 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -187,7 +187,7 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) - def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: bool = False) -> TInvokeEvent: + def invoke_callbacks(self, event: TInvokeEvent) -> TInvokeEvent: """Invoke all registered callbacks for the given event. This method finds all callbacks registered for the event's type and @@ -197,7 +197,6 @@ def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: bool = False Args: event: The event to dispatch to registered callbacks. - supress_exceptions: Except exception or not. Returns: The event dispatched to registered callbacks. 
@@ -207,17 +206,9 @@ def invoke_callbacks(self, event: TInvokeEvent, supress_exceptions: bool = False event = StartRequestEvent(agent=my_agent) registry.invoke_callbacks(event) ``` - """ for callback in self.get_callbacks_for(event): - if supress_exceptions: - try: - callback(event) - except Exception as e: - logger.exception("Hook invocation failed for %s: %s", type(event).__name__, e) - pass - else: - callback(event) + callback(event) return event diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 859cbfa81..7fdc5bc71 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -249,12 +249,10 @@ def execute() -> MultiAgentResult: future = executor.submit(execute) return future.result() - @abstractmethod def serialize_state(self) -> dict[str, Any]: """Return a JSON-serializable snapshot of the orchestrator state.""" raise NotImplementedError - @abstractmethod def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError @@ -272,11 +270,10 @@ def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: raise TypeError(f"serialize_node_result_for_persist expects NodeResult, got {type(raw).__name__}") return raw.to_dict() - @abstractmethod def attempt_resume(self, payload: dict[str, Any]) -> None: """Attempt to resume orchestrator state from a session payload. 
Args: payload: Session data to restore orchestrator state from """ - pass + raise NotImplementedError diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index cc0a7a900..572a881d2 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -413,7 +413,7 @@ def __init__( self._resume_from_persisted = False self._resume_next_nodes: list[GraphNode] = [] - self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self), supress_exceptions=True) + self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -467,9 +467,7 @@ async def invoke_async( else: if isinstance(self.state.task, (str, list)) and not self.state.task: self.state.task = task - # Reset failed nodes after resume. self.state.status = Status.EXECUTING - self.state.failed_nodes.clear() self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "graph") with trace_api.use_span(span, end_on_exit=True): @@ -497,7 +495,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self), supress_exceptions=True) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self)) self._resume_from_persisted = False self._resume_next_nodes.clear() return self._build_result() @@ -607,16 +605,13 @@ async def _record_failure(exception: Exception) -> None: self.state.failed_nodes.add(node) self.state.results[node.node_id] = fail_result - self.hooks.invoke_callbacks( - AfterNodeInvocationEvent(source=self, executed_node=node.node_id), - supress_exceptions=True, - ) + self.hooks.invoke_callbacks(AfterNodeInvocationEvent(source=self, executed_node=node.node_id)) if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for 
revisit", node.node_id) node.reset_executor_state() # Remove from completed nodes since we're re-executing it - self.state.completed_nodes.discard(node) + self.state.completed_nodes.remove(node) node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) @@ -688,9 +683,7 @@ async def _record_failure(exception: Exception) -> None: logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time ) - self.hooks.invoke_callbacks( - AfterNodeInvocationEvent(source=self, executed_node=node.node_id), supress_exceptions=True - ) + self.hooks.invoke_callbacks(AfterNodeInvocationEvent(source=self, executed_node=node.node_id)) except asyncio.TimeoutError: timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 56a29958c..3701485f7 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -260,7 +260,7 @@ def __init__( self._resume_from_persisted = False self._resume_from_completed = False - self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self), supress_exceptions=True) + self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -339,7 +339,7 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self), supress_exceptions=True) + self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self)) self._resume_from_persisted = False self._resume_from_completed = False @@ -619,9 +619,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) - self.hooks.invoke_callbacks( - 
AfterNodeInvocationEvent(self, executed_node=current_node.node_id), supress_exceptions=True - ) + self.hooks.invoke_callbacks(AfterNodeInvocationEvent(self, executed_node=current_node.node_id)) # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete @@ -726,9 +724,7 @@ async def _execute_node( self.state.results[node_name] = node_result # Persist failure here - self.hooks.invoke_callbacks( - AfterNodeInvocationEvent(self, executed_node=node_name), supress_exceptions=True - ) + self.hooks.invoke_callbacks(AfterNodeInvocationEvent(self, executed_node=node_name)) raise diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 6affaa844..6a65dcf62 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -122,12 +122,7 @@ def _write_file(self, path: str, data: dict[str, Any]) -> None: json.dump(data, f, indent=2, ensure_ascii=False) f.flush() os.fsync(f.fileno()) - try: - os.replace(tmp, path) - except PermissionError: - # Windows fallback: copy+delete if atomic replace fails - shutil.copy2(tmp, path) - os.remove(tmp) + os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: """Create a new session.""" From 8233c6274b9d461f53988b390d26089fb1efab71 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 13 Oct 2025 15:11:04 -0400 Subject: [PATCH 17/27] feat: add more unit tests and integration tests to validate Graph/Swarm/Multiagent hooks behavior --- .../fixtures/mock_multiagent_hook_provider.py | 41 ++++++++ .../multiagent_hooks/test_multiagent_hooks.py | 99 +++++++++++++++++++ tests_integ/test_multiagent_graph.py | 85 ++++++++++++++++ tests_integ/test_multiagent_swarm.py | 58 +++++++++++ 4 files changed, 283 insertions(+) create mode 100644 tests/fixtures/mock_multiagent_hook_provider.py create mode 100644 
tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py new file mode 100644 index 000000000..cb6c1ceb4 --- /dev/null +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -0,0 +1,41 @@ +from typing import Iterator, Literal, Tuple, Type + +from strands.experimental.multiagent_hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeNodeInvocationEvent, + MultiagentInitializedEvent, +) +from strands.hooks import ( + HookEvent, + HookProvider, + HookRegistry, +) + + +class MockMultiAgentHookProvider(HookProvider): + def __init__(self, event_types: list[Type] | Literal["all"]): + if event_types == "all": + event_types = [ + MultiagentInitializedEvent, + BeforeNodeInvocationEvent, + AfterNodeInvocationEvent, + AfterMultiAgentInvocationEvent, + ] + + self.events_received = [] + self.events_types = event_types + + @property + def event_types_received(self): + return [type(event) for event in self.events_received] + + def get_events(self) -> Tuple[int, Iterator[HookEvent]]: + return len(self.events_received), iter(self.events_received) + + def register_hooks(self, registry: HookRegistry) -> None: + for event_type in self.events_types: + registry.add_callback(event_type, self.add_event) + + def add_event(self, event: HookEvent) -> None: + self.events_received.append(event) diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py b/tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py new file mode 100644 index 000000000..84e1f5a8a --- /dev/null +++ b/tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py @@ -0,0 +1,99 @@ +import pytest + +from strands import Agent +from strands.experimental.multiagent_hooks import ( + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeNodeInvocationEvent, + MultiagentInitializedEvent, +) +from strands.hooks import 
HookRegistry +from strands.multiagent.graph import Graph, GraphBuilder +from strands.multiagent.swarm import Swarm +from tests.fixtures.mock_multiagent_hook_provider import MockMultiAgentHookProvider +from tests.fixtures.mocked_model_provider import MockedModelProvider + + +@pytest.fixture +def hook_provider(): + return MockMultiAgentHookProvider( + [ + AfterMultiAgentInvocationEvent, + AfterNodeInvocationEvent, + BeforeNodeInvocationEvent, + MultiagentInitializedEvent, + ] + ) + + +@pytest.fixture +def mock_model(): + agent_messages = [ + {"role": "assistant", "content": [{"text": "Task completed"}]}, + {"role": "assistant", "content": [{"text": "Task completed by agent 2"}]}, + {"role": "assistant", "content": [{"text": "Additional response"}]}, + ] + return MockedModelProvider(agent_messages) + + +@pytest.fixture +def agent0(mock_model): + return Agent(model=mock_model, system_prompt="You are a helpful assistant.", name="agent0") + + +@pytest.fixture +def agent1(mock_model): + return Agent(model=mock_model, system_prompt="You are agent 1.", name="agent1") + + +@pytest.fixture +def agent2(mock_model): + return Agent(model=mock_model, system_prompt="You are agent 2.", name="agent2") + + +@pytest.fixture +def swarm(agent1, agent2, hook_provider): + hooks = HookRegistry() + hooks.add_hook(hook_provider) + swarm = Swarm(nodes=[agent1, agent2], hooks=hooks) + return swarm + + +@pytest.fixture +def graph(agent1, agent2, hook_provider): + hooks = HookRegistry() + hooks.add_hook(hook_provider) + builder = GraphBuilder() + builder.add_node(agent1, "agent1") + builder.add_node(agent2, "agent2") + builder.add_edge("agent1", "agent2") + builder.set_entry_point("agent1") + graph = Graph(nodes=builder.nodes, edges=builder.edges, entry_points=builder.entry_points, hooks=hooks) + return graph + + +def test_swarm_complete_hook_lifecycle(swarm, hook_provider): + """E2E test verifying complete hook lifecycle for Swarm.""" + result = swarm("test task") + + length, events = 
hook_provider.get_events() + assert length == 3 + assert result.status.value == "completed" + + assert next(events) == MultiagentInitializedEvent(source=swarm) + assert next(events) == AfterNodeInvocationEvent(source=swarm, executed_node="agent1") + assert next(events) == AfterMultiAgentInvocationEvent(source=swarm) + + +def test_graph_complete_hook_lifecycle(graph, hook_provider): + """E2E test verifying complete hook lifecycle for Graph.""" + result = graph("test task") + + length, events = hook_provider.get_events() + assert length == 4 + assert result.status.value == "completed" + + assert next(events) == MultiagentInitializedEvent(source=graph) + assert next(events) == AfterNodeInvocationEvent(source=graph, executed_node="agent1") + assert next(events) == AfterNodeInvocationEvent(source=graph, executed_node="agent2") + assert next(events) == AfterMultiAgentInvocationEvent(source=graph) diff --git a/tests_integ/test_multiagent_graph.py b/tests_integ/test_multiagent_graph.py index c2c13c443..b8329b49a 100644 --- a/tests_integ/test_multiagent_graph.py +++ b/tests_integ/test_multiagent_graph.py @@ -1,3 +1,7 @@ +import tempfile +from unittest.mock import patch +from uuid import uuid4 + import pytest from strands import Agent, tool @@ -9,8 +13,11 @@ BeforeModelCallEvent, MessageAddedEvent, ) +from strands.multiagent.base import Status from strands.multiagent.graph import GraphBuilder +from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock +from strands.types.session import SessionType from tests.fixtures.mock_hook_provider import MockHookProvider @@ -83,6 +90,13 @@ def image_analysis_agent(hook_provider): ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield temp_dir + + @pytest.fixture def nested_computation_graph(math_agent, analysis_agent): """Create a nested graph for mathematical computation and analysis.""" @@ 
-218,3 +232,74 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events + + +@pytest.mark.asyncio +async def test_graph_interrupt_and_resume(): + """Test graph interruption and resume functionality with FileSessionManager.""" + + with tempfile.TemporaryDirectory() as temp_dir: + session_id = str(uuid4()) + + # Create real agents + agent1 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 1", name="agent1") + agent2 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 2", name="agent2") + agent3 = Agent(model="us.amazon.nova-pro-v1:0", system_prompt="You are agent 3", name="agent3") + + session_manager = FileSessionManager( + session_id=session_id, storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT + ) + + builder = GraphBuilder() + builder.add_node(agent1, "node1") + builder.add_node(agent2, "node2") + builder.add_node(agent3, "node3") + builder.add_edge("node1", "node2") + builder.add_edge("node2", "node3") + builder.set_entry_point("node1") + builder.set_session_manager(session_manager) + + graph = builder.build() + + # Mock agent2 to fail on first execution + async def failing_invoke(*args, **kwargs): + raise Exception("Simulated failure in agent2") + + with patch.object(agent2, "invoke_async", side_effect=failing_invoke): + # First execution - should fail at agent2 + try: + await graph.invoke_async("Test task") + except Exception as e: + assert "Simulated failure in agent2" in str(e) + + # Verify partial execution was persisted + persisted_state = session_manager.read_multi_agent_json() + assert persisted_state is not None + assert persisted_state["type"] == "graph" + assert persisted_state["status"] == "failed" + assert len(persisted_state["completed_nodes"]) == 1 # Only node1 completed + assert 
"node1" in persisted_state["completed_nodes"] + assert "node2" in persisted_state["next_node_to_execute"] + + # Track execution count before resume + initial_execution_count = graph.state.execution_count + + # Execute graph again + result = await graph.invoke_async("Test task") + + # Verify successful completion + assert result.status == Status.COMPLETED + assert len(result.results) == 3 + + execution_order_ids = [node.node_id for node in result.execution_order] + assert execution_order_ids == ["node1", "node2", "node3"] + + # Verify only 2 additional nodes were executed + assert result.execution_count == initial_execution_count + 2 + + final_state = session_manager.read_multi_agent_json() + assert final_state["status"] == "completed" + assert len(final_state["completed_nodes"]) == 3 + + # Clean up + session_manager.delete_session(session_id) diff --git a/tests_integ/test_multiagent_swarm.py b/tests_integ/test_multiagent_swarm.py index 9a8c79bf8..a80b97f40 100644 --- a/tests_integ/test_multiagent_swarm.py +++ b/tests_integ/test_multiagent_swarm.py @@ -1,3 +1,7 @@ +import tempfile +from unittest.mock import patch +from uuid import uuid4 + import pytest from strands import Agent, tool @@ -10,8 +14,11 @@ BeforeToolCallEvent, MessageAddedEvent, ) +from strands.multiagent.base import Status from strands.multiagent.swarm import Swarm +from strands.session.file_session_manager import FileSessionManager from strands.types.content import ContentBlock +from strands.types.session import SessionType from tests.fixtures.mock_hook_provider import MockHookProvider @@ -134,3 +141,54 @@ async def test_swarm_execution_with_image(researcher_agent, analyst_agent, write # Verify agent history - at least one agent should have been used assert len(result.node_history) > 0 + + +@pytest.mark.asyncio +async def test_swarm_interrupt_and_resume(researcher_agent, analyst_agent, writer_agent): + """Test swarm interruption after analyst_agent and resume functionality.""" + with 
tempfile.TemporaryDirectory() as temp_dir: + session_id = str(uuid4()) + + # Create session manager + session_manager = FileSessionManager( + session_id=session_id, storage_dir=temp_dir, session_type=SessionType.MULTI_AGENT + ) + + # Create swarm with session manager + swarm = Swarm([researcher_agent, analyst_agent, writer_agent], session_manager=session_manager) + + # Mock analyst_agent to fail + async def failing_invoke(*args, **kwargs): + raise Exception("Simulated failure in analyst") + + with patch.object(analyst_agent, "invoke_async", side_effect=failing_invoke): + # First execution - should fail at analyst + result = await swarm.invoke_async("Research AI trends and create a brief report") + assert result.status == Status.FAILED + + # Verify partial execution was persisted + persisted_state = session_manager.read_multi_agent_json() + assert persisted_state is not None + assert persisted_state["type"] == "swarm" + assert persisted_state["status"] == "failed" + assert len(persisted_state["node_history"]) == 1 # At least researcher executed + + # Track execution count before resume + initial_execution_count = len(persisted_state["node_history"]) + + # Execute swarm again - should automatically resume from saved state + result = await swarm.invoke_async("Research AI trends and create a brief report") + + # Verify successful completion + assert result.status == Status.COMPLETED + assert len(result.results) > 0 + + assert len(result.node_history) >= initial_execution_count + 1 + + node_names = [node.node_id for node in result.node_history] + assert "researcher" in node_names + # Either analyst or writer (or both) should have executed to complete the task + assert "analyst" in node_names or "writer" in node_names + + # Clean up + session_manager.delete_session(session_id) From f1aac16414cadcc61064df65b0c64350a43e374c Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 13 Oct 2025 15:56:31 -0400 Subject: [PATCH 18/27] fix: fix bad rebase --- 
src/strands/hooks/registry.py | 3 --- src/strands/multiagent/graph.py | 4 ++-- src/strands/session/s3_session_manager.py | 15 --------------- 3 files changed, 2 insertions(+), 20 deletions(-) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 9181f8f38..b8e7f82ab 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,15 +7,12 @@ via hook provider objects. """ -import logging from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar if TYPE_CHECKING: from ..agent import Agent -logger = logging.getLogger(__name__) - @dataclass class BaseHookEvent: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 572a881d2..accfe95ab 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -644,11 +644,11 @@ async def _record_failure(exception: Exception) -> None: elif isinstance(node.executor, Agent): if self.node_timeout is not None: agent_response = await asyncio.wait_for( - node.executor.invoke_async(node_input, **invocation_state), + node.executor.invoke_async(node_input, invocation_state=invocation_state), timeout=self.node_timeout, ) else: - agent_response = await node.executor.invoke_async(node_input, **invocation_state) + agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index 25d9c9f06..b6e18f2a5 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -303,21 +303,6 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - async def _load_messages_concurrently(self, message_keys: List[str]) -> List[SessionMessage]: - """Load multiple message objects 
concurrently using async.""" - if not message_keys: - return [] - - async def load_message(key: str) -> Optional[SessionMessage]: - loop = asyncio.get_event_loop() - message_data = await loop.run_in_executor(None, self._read_s3_object, key) - return SessionMessage.from_dict(message_data) if message_data else None - - tasks = [load_message(key) for key in message_keys] - loaded_messages = await asyncio.gather(*tasks) - - return [msg for msg in loaded_messages if msg is not None] - def write_multi_agent_json(self, state: dict[str, Any]) -> None: """Write multi-agent state to S3. From 7735ed3bb9d0945614ee2357caad7b8646ca5371 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 15 Oct 2025 16:16:00 -0400 Subject: [PATCH 19/27] fix: revert single agent session_manager validator --- src/strands/multiagent/graph.py | 6 ++++++ tests/strands/multiagent/test_graph.py | 25 ++++++++++++++++++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index accfe95ab..1228b0bcc 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -202,6 +202,12 @@ def _validate_node_executor( if id(executor) in seen_instances: raise ValueError("Duplicate node instance detected. 
Each node must have a unique object instance.") + # Validate Agent-specific constraints + if isinstance(executor, Agent): + # Check for session persistence + if executor._session_manager is not None: + raise ValueError("Session persistence is not supported for Graph agents yet.") + class GraphBuilder: """Builder pattern for constructing graphs.""" diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index da90a5dfd..a31990b4f 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -6,7 +6,8 @@ from strands.agent import Agent, AgentResult from strands.agent.state import AgentState -from strands.hooks.registry import HookRegistry +from strands.hooks import AgentInitializedEvent +from strands.hooks.registry import HookProvider, HookRegistry from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult from strands.multiagent.graph import Graph, GraphBuilder, GraphEdge, GraphNode, GraphResult, GraphState, Status from strands.session.session_manager import SessionManager @@ -864,16 +865,30 @@ def test_graph_validate_unsupported_features(): graph = builder.build() assert len(graph.nodes) == 1 - # Test with session manager (should work now - session persistence is supported) + # Test with session manager (should fail in GraphBuilder.add_node) mock_session_manager = Mock(spec=SessionManager) agent_with_session = create_mock_agent("agent_with_session") agent_with_session._session_manager = mock_session_manager agent_with_session.hooks = HookRegistry() builder = GraphBuilder() - builder.add_node(agent_with_session) - graph = builder.build() - assert len(graph.nodes) == 1 + with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): + builder.add_node(agent_with_session) + + # Test with callbacks (should fail in GraphBuilder.add_node) + class TestHookProvider(HookProvider): + def register_hooks(self, registry, **kwargs): + 
registry.add_callback(AgentInitializedEvent, lambda e: None) + + # Test validation in Graph constructor (when nodes are passed directly) + # Test with session manager in Graph constructor + node_with_session = GraphNode("node_with_session", agent_with_session) + with pytest.raises(ValueError, match="Session persistence is not supported for Graph agents yet"): + Graph( + nodes={"node_with_session": node_with_session}, + edges=set(), + entry_points=set(), + ) @pytest.mark.asyncio From 7b3aabb83724c3a6e27861c19ba1442ea7758fa5 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 17 Oct 2025 15:25:42 -0400 Subject: [PATCH 20/27] fix: refine code structures, address related comments --- .../{ => hooks}/multiagent_hooks/__init__.py | 14 +-- .../multiagent_hooks/multiagent_events.py | 23 ++-- src/strands/multiagent/base.py | 54 ++++------ src/strands/multiagent/graph.py | 102 ++++++++++-------- src/strands/multiagent/swarm.py | 69 ++++++------ src/strands/session/session_manager.py | 24 ++--- .../fixtures/mock_multiagent_hook_provider.py | 14 +-- .../{ => hooks}/multiagent_hooks/__init__.py | 0 .../test_multiagent_events.py | 32 +++--- .../multiagent_hooks/test_multiagent_hooks.py | 33 +++--- tests/strands/multiagent/test_base.py | 12 +-- 11 files changed, 187 insertions(+), 190 deletions(-) rename src/strands/experimental/{ => hooks}/multiagent_hooks/__init__.py (58%) rename src/strands/experimental/{ => hooks}/multiagent_hooks/multiagent_events.py (78%) rename tests/strands/experimental/{ => hooks}/multiagent_hooks/__init__.py (100%) rename tests/strands/experimental/{ => hooks}/multiagent_hooks/test_multiagent_events.py (68%) rename tests/strands/experimental/{ => hooks}/multiagent_hooks/test_multiagent_hooks.py (71%) diff --git a/src/strands/experimental/multiagent_hooks/__init__.py b/src/strands/experimental/hooks/multiagent_hooks/__init__.py similarity index 58% rename from src/strands/experimental/multiagent_hooks/__init__.py rename to 
src/strands/experimental/hooks/multiagent_hooks/__init__.py index ac8c432a8..c3dc793e7 100644 --- a/src/strands/experimental/multiagent_hooks/__init__.py +++ b/src/strands/experimental/hooks/multiagent_hooks/__init__.py @@ -6,14 +6,16 @@ from .multiagent_events import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, ) __all__ = [ "AfterMultiAgentInvocationEvent", - "MultiagentInitializedEvent", - "AfterNodeInvocationEvent", - "BeforeNodeInvocationEvent", + "AfterNodeCallEvent", + "BeforeMultiAgentInvocationEvent", + "BeforeNodeCallEvent", + "MultiAgentInitializedEvent", ] diff --git a/src/strands/experimental/multiagent_hooks/multiagent_events.py b/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py similarity index 78% rename from src/strands/experimental/multiagent_hooks/multiagent_events.py rename to src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py index e000c7688..ce0ddbd95 100644 --- a/src/strands/experimental/multiagent_hooks/multiagent_events.py +++ b/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py @@ -8,14 +8,14 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from ...hooks.registry import BaseHookEvent +from ....hooks import BaseHookEvent if TYPE_CHECKING: - from ...multiagent.base import MultiAgentBase + from ....multiagent.base import MultiAgentBase @dataclass -class MultiagentInitializedEvent(BaseHookEvent): +class MultiAgentInitializedEvent(BaseHookEvent): """Event triggered when multi-agent orchestrator initialized. Attributes: @@ -28,25 +28,32 @@ class MultiagentInitializedEvent(BaseHookEvent): @dataclass -class BeforeNodeInvocationEvent(BaseHookEvent): - """Event triggered before individual node execution completes. 
This event corresponds to the After event.""" +class BeforeNodeCallEvent(BaseHookEvent): + """Event triggered before individual node execution completes. This event corresponds to the After event. + + Attributes: + source: The multi-agent orchestrator instance + node_id: ID of the node that just completed execution + invocation_state: Configuration that user pass in + """ source: "MultiAgentBase" + node_id: str invocation_state: dict[str, Any] | None = None @dataclass -class AfterNodeInvocationEvent(BaseHookEvent): +class AfterNodeCallEvent(BaseHookEvent): """Event triggered after individual node execution completes. Attributes: source: The multi-agent orchestrator instance - executed_node: ID of the node that just completed execution + node_id: ID of the node that just completed execution invocation_state: Configuration that user pass in """ source: "MultiAgentBase" - executed_node: str + node_id: str invocation_state: dict[str, Any] | None = None @property diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 07abbea1d..fcc92dc44 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -10,12 +10,12 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Any, Literal, Union, cast +from typing import Any, Union, cast from ..agent import AgentResult from ..telemetry.metrics import EventLoopMetrics from ..types.content import ContentBlock, Message -from ..types.event_loop import Metrics, Usage +from ..types.event_loop import Metrics, StopReason, Usage logger = logging.getLogger(__name__) @@ -56,7 +56,6 @@ def get_agent_results(self) -> list[AgentResult]: return [] # No agent results for exceptions elif isinstance(self.result, AgentResult): return [self.result] - # If this is a nested MultiAgentResult, flatten children else: # Flatten nested results from MultiAgentResult flattened = [] @@ -98,7 +97,7 @@ def from_dict(cls, data: 
dict[str, Any]) -> "NodeResult": result: Union[AgentResult, "MultiAgentResult", Exception] if isinstance(raw, dict) and raw.get("type") == "agent_result": - result = _agent_result_from_persisted(raw) + result = NodeResult.agent_result_from_persisted(raw) elif isinstance(raw, dict) and raw.get("type") == "exception": result = Exception(str(raw.get("message", "node failed"))) elif isinstance(raw, dict) and ("results" in raw): @@ -129,27 +128,27 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": execution_count=int(data.get("execution_count", 0)), ) + @classmethod + def agent_result_from_persisted(cls, data: dict[str, Any]) -> AgentResult: + """Rehydrate a minimal AgentResult from persisted JSON. -def _agent_result_from_persisted(data: dict[str, Any]) -> AgentResult: - """Rehydrate a minimal AgentResult from persisted JSON. - - Expected shape: - {"type": "agent_result", "message": , "stop_reason": } - """ - if data.get("type") != "agent_result": - raise TypeError(f"_agent_result_from_persisted: unexpected type {data.get('type')!r}") + Expected shape: + {"type": "agent_result", "message": , "stop_reason": } + """ + if data.get("type") != "agent_result": + raise TypeError(f"agent_result_from_persisted: unexpected type {data.get('type')!r}") - message = cast(Message, data.get("message")) - stop_reason = cast( - Literal["content_filtered", "end_turn", "guardrail_intervened", "max_tokens", "stop_sequence", "tool_use"], - data.get("stop_reason"), - ) + message = cast(Message, data.get("message")) + stop_reason = cast( + StopReason, + data.get("stop_reason"), + ) - try: - return AgentResult(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) - except Exception: - logger.debug("AgentResult constructor failed during rehydrating") - raise + try: + return AgentResult(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + except Exception: + logger.debug("AgentResult constructor failed during rehydrating") + raise 
@dataclass @@ -171,6 +170,7 @@ class MultiAgentResult: def to_dict(self) -> dict[str, Any]: """Convert MultiAgentResult to JSON-serializable dict.""" return { + "type": "mutiagent_result", "status": self.status.value, "results": {k: v.to_dict() for k, v in self.results.items()}, "accumulated_usage": dict(self.accumulated_usage), @@ -271,14 +271,4 @@ def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: Returns: JSON-serializable dict representation """ - if not isinstance(raw, NodeResult): - raise TypeError(f"serialize_node_result_for_persist expects NodeResult, got {type(raw).__name__}") return raw.to_dict() - - def attempt_resume(self, payload: dict[str, Any]) -> None: - """Attempt to resume orchestrator state from a session payload. - - Args: - payload: Session data to restore orchestrator state from - """ - raise NotImplementedError diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1228b0bcc..1387621fe 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -26,12 +26,12 @@ from ..agent import Agent from ..agent.state import AgentState -from ..experimental.multiagent_hooks import ( +from ..experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, ) -from ..hooks import HookRegistry +from ..hooks import HookProvider, HookRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..types.content import ContentBlock, Messages @@ -226,6 +226,7 @@ def __init__(self) -> None: # session manager self._session_manager: Optional[SessionManager] = None + self._hooks: Optional[list[HookProvider]] = None def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode: """Add an Agent or MultiAgentBase instance as a node to the graph.""" @@ -325,6 +326,15 @@ def set_session_manager(self, 
session_manager: SessionManager) -> "GraphBuilder" self._session_manager = session_manager return self + def set_hook_provider(self, hook_providers: list[HookProvider]) -> "GraphBuilder": + """Set hook provider for the graph. + + Args: + hook_providers: SessionManager instance + """ + self._hooks = hook_providers + return self + def build(self) -> "Graph": """Build and validate the graph with configured settings.""" if not self.nodes: @@ -351,6 +361,7 @@ def build(self) -> "Graph": node_timeout=self._node_timeout, reset_on_revisit=self._reset_on_revisit, session_manager=self._session_manager, + hooks=self._hooks, ) def _validate_graph(self) -> None: @@ -379,7 +390,7 @@ def __init__( node_timeout: Optional[float] = None, reset_on_revisit: bool = False, session_manager: Optional[SessionManager] = None, - hooks: Optional[HookRegistry] = None, + hooks: Optional[list[HookProvider]] = None, ) -> None: """Initialize Graph with execution limits and reset behavior. @@ -409,17 +420,17 @@ def __init__( self.state = GraphState() self.tracer = get_tracer() self.session_manager = session_manager - self.hooks = hooks or HookRegistry() + self.hooks = HookRegistry() if self.session_manager is not None: self.hooks.add_hook(self.session_manager) - - # Concurrent lock - self._lock = asyncio.Lock() - # Resume flag + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + # Resume flags self._resume_from_persisted = False self._resume_next_nodes: list[GraphNode] = [] - self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self)) + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -467,7 +478,7 @@ async def invoke_async( task=task, total_nodes=len(self.nodes), edges=[(edge.from_node, edge.to_node) for edge in self.edges], - entry_points=sorted(self.entry_points, key=lambda node: node.node_id), + entry_points=list(self.entry_points), 
start_time=start_time, ) else: @@ -520,11 +531,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" - ready_nodes = ( - sorted(self._resume_next_nodes, key=lambda n: n.node_id) - if self._resume_from_persisted - else sorted(self.entry_points, key=lambda n: n.node_id) - ) + ready_nodes = self._resume_next_nodes if self._resume_from_persisted else list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -604,14 +611,13 @@ async def _record_failure(exception: Exception) -> None: accumulated_metrics=Metrics(latencyMs=execution_time), execution_count=1, ) - async with self._lock: - node.execution_status = Status.FAILED - node.result = fail_result - node.execution_time = execution_time - self.state.failed_nodes.add(node) - self.state.results[node.node_id] = fail_result - self.hooks.invoke_callbacks(AfterNodeInvocationEvent(source=self, executed_node=node.node_id)) + node.execution_status = Status.FAILED + node.result = fail_result + node.execution_time = execution_time + self.state.failed_nodes.add(node) + self.state.results[node.node_id] = fail_result + self.hooks.invoke_callbacks(AfterNodeCallEvent(source=self, node_id=node.node_id)) if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) @@ -676,20 +682,19 @@ async def _record_failure(exception: Exception) -> None: else: raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported") - async with self._lock: - node.execution_status = Status.COMPLETED - node.result = node_result - node.execution_time = node_result.execution_time - self.state.completed_nodes.add(node) - self.state.results[node.node_id] = node_result - self.state.execution_order.append(node) - # Accumulate metrics - self._accumulate_metrics(node_result) + 
node.execution_status = Status.COMPLETED + node.result = node_result + node.execution_time = node_result.execution_time + self.state.completed_nodes.add(node) + self.state.results[node.node_id] = node_result + self.state.execution_order.append(node) + # Accumulate metrics + self._accumulate_metrics(node_result) + self.hooks.invoke_callbacks(AfterNodeCallEvent(source=self, node_id=node.node_id)) logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time ) - self.hooks.invoke_callbacks(AfterNodeInvocationEvent(source=self, executed_node=node.node_id)) except asyncio.TimeoutError: timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s" @@ -865,17 +870,29 @@ def serialize_state(self) -> dict: return self._to_dict() def deserialize_state(self, payload: dict) -> None: - """Restore orchestrator state from a session dict. + """Restore orchestrator state from a session dict and prepare for execution. + + This method handles two scenarios: + 1. If the persisted status is COMPLETED, resets all nodes and graph state + to allow re-execution from the beginning. + 2. Otherwise, restores the persisted state and prepares to resume execution + from the next ready nodes. Args: - payload: Dictionary containing persisted state data + payload: Dictionary containing persisted state data including status, + completed nodes, results, and next nodes to execute. 
""" - self._from_dict(payload) + if payload.get("status") in (Status.COMPLETED.value, "completed"): + # Reset all nodes + for node in self.nodes.values(): + node.reset_executor_state() + # Reset graph state + self.state = GraphState() + self._resume_from_persisted = False + return - def attempt_resume(self, payload: dict[str, Any]) -> None: - """Apply a persisted graph payload and prepare resume execution.""" try: - self.deserialize_state(payload) + self._from_dict(payload) self._resume_from_persisted = True next_node_ids = payload.get("next_node_to_execute") or [] @@ -886,7 +903,7 @@ def attempt_resume(self, payload: dict[str, Any]) -> None: for node in mapped: if node in completed or node.execution_status == Status.COMPLETED: continue - # only include if it’s dependency-ready + # only include if it's dependency-ready incoming = [edge for edge in self.edges if edge.to_node == node] if not incoming: valid_ready.append(node) @@ -897,8 +914,7 @@ def attempt_resume(self, payload: dict[str, Any]) -> None: if not valid_ready: valid_ready = self._compute_ready_nodes_for_resume() - - self._resume_next_nodes = sorted(valid_ready, key=lambda nodes: nodes.node_id) + self._resume_next_nodes = valid_ready logger.debug("Resumed from persisted state. 
Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) except Exception: logger.exception("Failed to apply resume payload") diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 1d09cf1f6..46648cddf 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -19,18 +19,18 @@ import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field -from typing import Any, Callable, Tuple +from typing import Any, Callable, Optional, Tuple from opentelemetry import trace as trace_api from ..agent import Agent, AgentResult from ..agent.state import AgentState -from ..experimental.multiagent_hooks import ( +from ..experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, ) -from ..hooks import HookRegistry +from ..hooks import HookProvider, HookRegistry from ..session import SessionManager from ..telemetry import get_tracer from ..tools.decorator import tool @@ -210,8 +210,8 @@ def __init__( node_timeout: float = 300.0, repetitive_handoff_detection_window: int = 0, repetitive_handoff_min_unique_agents: int = 0, - session_manager: SessionManager | None = None, - hooks: HookRegistry | None = None, + session_manager: Optional[SessionManager] = None, + hooks: Optional[list[HookProvider]] = None, ) -> None: """Initialize Swarm with agents and configuration. @@ -249,18 +249,19 @@ def __init__( self.tracer = get_tracer() self.session_manager = session_manager - self.hooks = hooks or HookRegistry() + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) if self.session_manager is not None: self.hooks.add_hook(self.session_manager) self._setup_swarm(nodes) self._inject_swarm_tools() - # We need flags here from Graph since they have different mechanism to determine end of loop. 
self._resume_from_persisted = False - self._resume_from_completed = False - self.hooks.invoke_callbacks(MultiagentInitializedEvent(source=self)) + self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -300,10 +301,6 @@ async def invoke_async( logger.debug("starting swarm execution") - if self._resume_from_persisted and self._resume_from_completed: - logger.debug("Returning persisted COMPLETED result without re-execution.") - return self._build_result() - # If resume if not self._resume_from_persisted: initial_node = self._initial_node() @@ -619,7 +616,7 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) - self.hooks.invoke_callbacks(AfterNodeInvocationEvent(self, executed_node=current_node.node_id)) + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=current_node.node_id)) # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete @@ -723,7 +720,7 @@ async def _execute_node( self.state.results[node_name] = node_result # Persist failure here - self.hooks.invoke_callbacks(AfterNodeInvocationEvent(self, executed_node=node_name)) + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_name)) raise @@ -754,24 +751,6 @@ def _initial_node(self) -> SwarmNode: return next(iter(self.nodes.values())) # First SwarmNode - def attempt_resume(self, payload: dict[str, Any]) -> None: - """Apply persisted state and set resume flags.""" - status = payload.get("status") - self.deserialize_state(payload) - self._resume_from_persisted = True - self._resume_from_completed = status in (Status.COMPLETED.value, "completed") - - if self._resume_from_completed: - # Avoid running again; placeholder so current_node is non-null - placeholder = SwarmNode("placeholder", 
Agent()) - self.state.current_node = placeholder - logger.debug("Saved state is COMPLETED; will return persisted result without re-running.") - else: - logger.debug( - "Resumed from persisted state. Current node: %s", - self.state.current_node.node_id if self.state.current_node else "None", - ) - def _to_dict(self) -> dict: """Return a JSON-serializable snapshot of the orchestrator state. @@ -870,4 +849,22 @@ def deserialize_state(self, payload: dict) -> None: Args: payload: Dictionary containing persisted state data """ - self._from_dict(payload) + if payload.get("status") in (Status.COMPLETED.value, "completed"): + # Reset all nodes + for node in self.nodes.values(): + node.reset_executor_state() + # Reset graph state + self.state = SwarmState( + current_node=SwarmNode("", Agent()), # Placeholder, will be set properly + task="", + completion_status=Status.PENDING, + ) + self._resume_from_persisted = False + return + else: + self._from_dict(payload) + self._resume_from_persisted = True + logger.debug( + "Resumed from persisted state. 
Current node: %s", + self.state.current_node.node_id if self.state.current_node else "None", + ) diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 6b1ca0021..c54175062 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -5,10 +5,10 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from ..experimental.multiagent_hooks.multiagent_events import ( +from ..experimental.hooks.multiagent_hooks.multiagent_events import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, ) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent from ..hooks.registry import HookProvider, HookRegistry @@ -56,8 +56,8 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) elif self.session_type == SessionType.MULTI_AGENT: - registry.add_callback(MultiagentInitializedEvent, self._on_multiagent_initialized) - registry.add_callback(AfterNodeInvocationEvent, lambda event: self._persist_multi_agent_state(event.source)) + registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) + registry.add_callback(AfterNodeCallEvent, lambda event: self._persist_multi_agent_state(event.source)) registry.add_callback( AfterMultiAgentInvocationEvent, lambda event: self._persist_multi_agent_state(event.source) ) @@ -134,21 +134,13 @@ def read_multi_agent_json(self) -> dict[str, Any]: "SessionManager with session_type=SessionType.MULTI_AGENT." 
) - def _on_multiagent_initialized(self, event: MultiagentInitializedEvent) -> None: + def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: """Initialization path: attempt to resume and then persist a fresh snapshot.""" source: MultiAgentBase = event.source - try: - payload = self.read_multi_agent_json() - except NotImplementedError: - logger.debug("Multi-agent persistence not implemented; starting fresh") - return + payload = self.read_multi_agent_json() # payload can be {} or Graph/Swarm state json if payload: - try: - source.attempt_resume(payload) - except Exception: - logger.exception("Failed to apply resume payload; starting fresh") - raise + source.deserialize_state(payload) else: try: self._persist_multi_agent_state(source) diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index cb6c1ceb4..12679e91d 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,10 +1,10 @@ from typing import Iterator, Literal, Tuple, Type -from strands.experimental.multiagent_hooks import ( +from strands.experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, ) from strands.hooks import ( HookEvent, @@ -17,9 +17,9 @@ class MockMultiAgentHookProvider(HookProvider): def __init__(self, event_types: list[Type] | Literal["all"]): if event_types == "all": event_types = [ - MultiagentInitializedEvent, - BeforeNodeInvocationEvent, - AfterNodeInvocationEvent, + MultiAgentInitializedEvent, + BeforeNodeCallEvent, + AfterNodeCallEvent, AfterMultiAgentInvocationEvent, ] diff --git a/tests/strands/experimental/multiagent_hooks/__init__.py b/tests/strands/experimental/hooks/multiagent_hooks/__init__.py similarity index 100% rename from 
tests/strands/experimental/multiagent_hooks/__init__.py rename to tests/strands/experimental/hooks/multiagent_hooks/__init__.py diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py similarity index 68% rename from tests/strands/experimental/multiagent_hooks/test_multiagent_events.py rename to tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py index be5e1d5dd..d7f91d198 100644 --- a/tests/strands/experimental/multiagent_hooks/test_multiagent_events.py +++ b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py @@ -4,10 +4,10 @@ import pytest -from strands.experimental.multiagent_hooks.multiagent_events import ( +from strands.experimental.hooks.multiagent_hooks.multiagent_events import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + MultiAgentInitializedEvent, ) from strands.hooks.registry import BaseHookEvent @@ -19,8 +19,8 @@ def orchestrator(): def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): - """Test MultiagentInitializedEvent creation with orchestrator only.""" - event = MultiagentInitializedEvent(source=orchestrator) + """Test MultiAgentInitializedEvent creation with orchestrator only.""" + event = MultiAgentInitializedEvent(source=orchestrator) assert event.source is orchestrator assert event.invocation_state is None @@ -28,35 +28,33 @@ def test_multi_agent_initialization_event_with_orchestrator_only(orchestrator): def test_multi_agent_initialization_event_with_invocation_state(orchestrator): - """Test MultiagentInitializedEvent creation with invocation state.""" + """Test MultiAgentInitializedEvent creation with invocation state.""" invocation_state = {"key": "value"} - event = MultiagentInitializedEvent(source=orchestrator, invocation_state=invocation_state) + event = 
MultiAgentInitializedEvent(source=orchestrator, invocation_state=invocation_state) assert event.source is orchestrator assert event.invocation_state == invocation_state def test_after_node_invocation_event_with_required_fields(orchestrator): - """Test AfterNodeInvocationEvent creation with required fields.""" - executed_node = "node_1" - event = AfterNodeInvocationEvent(source=orchestrator, executed_node=executed_node) + """Test AfterNodeCallEvent creation with required fields.""" + node_id = "node_1" + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id) assert event.source is orchestrator - assert event.executed_node == executed_node + assert event.node_id == node_id assert event.invocation_state is None assert isinstance(event, BaseHookEvent) def test_after_node_invocation_event_with_invocation_state(orchestrator): - """Test AfterNodeInvocationEvent creation with invocation state.""" - executed_node = "node_2" + """Test AfterNodeCallEvent creation with invocation state.""" + node_id = "node_2" invocation_state = {"result": "success"} - event = AfterNodeInvocationEvent( - source=orchestrator, executed_node=executed_node, invocation_state=invocation_state - ) + event = AfterNodeCallEvent(source=orchestrator, node_id=node_id, invocation_state=invocation_state) assert event.source is orchestrator - assert event.executed_node == executed_node + assert event.node_id == node_id assert event.invocation_state == invocation_state diff --git a/tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py similarity index 71% rename from tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py rename to tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py index 84e1f5a8a..49653296b 100644 --- a/tests/strands/experimental/multiagent_hooks/test_multiagent_hooks.py +++ b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py @@ 
-1,13 +1,12 @@ import pytest from strands import Agent -from strands.experimental.multiagent_hooks import ( +from strands.experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, ) -from strands.hooks import HookRegistry from strands.multiagent.graph import Graph, GraphBuilder from strands.multiagent.swarm import Swarm from tests.fixtures.mock_multiagent_hook_provider import MockMultiAgentHookProvider @@ -19,9 +18,9 @@ def hook_provider(): return MockMultiAgentHookProvider( [ AfterMultiAgentInvocationEvent, - AfterNodeInvocationEvent, - BeforeNodeInvocationEvent, - MultiagentInitializedEvent, + AfterNodeCallEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, ] ) @@ -53,22 +52,18 @@ def agent2(mock_model): @pytest.fixture def swarm(agent1, agent2, hook_provider): - hooks = HookRegistry() - hooks.add_hook(hook_provider) - swarm = Swarm(nodes=[agent1, agent2], hooks=hooks) + swarm = Swarm(nodes=[agent1, agent2], hooks=[hook_provider]) return swarm @pytest.fixture def graph(agent1, agent2, hook_provider): - hooks = HookRegistry() - hooks.add_hook(hook_provider) builder = GraphBuilder() builder.add_node(agent1, "agent1") builder.add_node(agent2, "agent2") builder.add_edge("agent1", "agent2") builder.set_entry_point("agent1") - graph = Graph(nodes=builder.nodes, edges=builder.edges, entry_points=builder.entry_points, hooks=hooks) + graph = Graph(nodes=builder.nodes, edges=builder.edges, entry_points=builder.entry_points, hooks=[hook_provider]) return graph @@ -80,8 +75,8 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider): assert length == 3 assert result.status.value == "completed" - assert next(events) == MultiagentInitializedEvent(source=swarm) - assert next(events) == AfterNodeInvocationEvent(source=swarm, executed_node="agent1") + assert next(events) == 
MultiAgentInitializedEvent(source=swarm) + assert next(events) == AfterNodeCallEvent(source=swarm, node_id="agent1") assert next(events) == AfterMultiAgentInvocationEvent(source=swarm) @@ -93,7 +88,7 @@ def test_graph_complete_hook_lifecycle(graph, hook_provider): assert length == 4 assert result.status.value == "completed" - assert next(events) == MultiagentInitializedEvent(source=graph) - assert next(events) == AfterNodeInvocationEvent(source=graph, executed_node="agent1") - assert next(events) == AfterNodeInvocationEvent(source=graph, executed_node="agent2") + assert next(events) == MultiAgentInitializedEvent(source=graph) + assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent1") + assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent2") assert next(events) == AfterMultiAgentInvocationEvent(source=graph) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index e2db5b07a..b6c9be27c 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -248,9 +248,9 @@ def test_serialize_node_result_for_persist(agent_result): assert "execution_time" in serialized assert "status" in serialized - # Test with invalid input type should raise TypeError - with pytest.raises(TypeError, match="serialize_node_result_for_persist expects NodeResult"): - MultiAgentBase.serialize_node_result_for_persist(agent, {"agent_outputs": ["test1", "test2"]}) - - with pytest.raises(TypeError, match="serialize_node_result_for_persist expects NodeResult"): - MultiAgentBase.serialize_node_result_for_persist(agent, "simple string") + # Test with NodeResult containing Exception + exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) + serialized_exception = MultiAgentBase.serialize_node_result_for_persist(agent, exception_node_result) + assert "result" in serialized_exception + assert serialized_exception["result"]["type"] == "exception" + assert 
serialized_exception["result"]["message"] == "Test error" From 8ed2e21eccc64387cc4b18ad931e93148916481f Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 17 Oct 2025 15:41:09 -0400 Subject: [PATCH 21/27] feat: add BeforeNodeCallEvent to swarm & graph --- src/strands/multiagent/graph.py | 5 ++++- src/strands/multiagent/swarm.py | 2 ++ .../hooks/multiagent_hooks/test_multiagent_hooks.py | 7 +++++-- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 1387621fe..c3b65f736 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -29,6 +29,7 @@ from ..experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + BeforeNodeCallEvent, MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry @@ -619,12 +620,14 @@ async def _record_failure(exception: Exception) -> None: self.state.results[node.node_id] = fail_result self.hooks.invoke_callbacks(AfterNodeCallEvent(source=self, node_id=node.node_id)) + # This is a placeholder for firing BeforeNodeCallEvent. 
+ self.hooks.invoke_callbacks(BeforeNodeCallEvent(source=self, node_id=node.node_id)) + if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) node.reset_executor_state() # Remove from completed nodes since we're re-executing it self.state.completed_nodes.remove(node) - node.execution_status = Status.EXECUTING logger.debug("node_id=<%s> | executing node", node.node_id) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index abda86f3c..ba9a548f7 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -28,6 +28,7 @@ from ..experimental.hooks.multiagent_hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + BeforeNodeCallEvent, MultiAgentInitializedEvent, ) from ..hooks import HookProvider, HookRegistry @@ -655,6 +656,7 @@ async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] ) -> AgentResult: """Execute swarm node.""" + self.hooks.invoke_callbacks(BeforeNodeCallEvent(source=self, node_id=node.node_id)) start_time = time.time() node_name = node.node_id diff --git a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py index 49653296b..61c0a18ec 100644 --- a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py +++ b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py @@ -72,10 +72,11 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider): result = swarm("test task") length, events = hook_provider.get_events() - assert length == 3 + assert length == 4 assert result.status.value == "completed" assert next(events) == MultiAgentInitializedEvent(source=swarm) + assert next(events) == BeforeNodeCallEvent(source=swarm, node_id="agent1") assert next(events) == AfterNodeCallEvent(source=swarm, node_id="agent1") 
assert next(events) == AfterMultiAgentInvocationEvent(source=swarm) @@ -85,10 +86,12 @@ def test_graph_complete_hook_lifecycle(graph, hook_provider): result = graph("test task") length, events = hook_provider.get_events() - assert length == 4 + assert length == 6 assert result.status.value == "completed" assert next(events) == MultiAgentInitializedEvent(source=graph) + assert next(events) == BeforeNodeCallEvent(source=graph, node_id="agent1") assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent1") + assert next(events) == BeforeNodeCallEvent(source=graph, node_id="agent2") assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent2") assert next(events) == AfterMultiAgentInvocationEvent(source=graph) From d3adef3b109ad056da37e9313ba492dbacef0d4c Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Fri, 17 Oct 2025 15:56:26 -0400 Subject: [PATCH 22/27] fix: fix bad rebase --- src/strands/multiagent/graph.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index c3b65f736..753dc142f 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -665,6 +665,12 @@ async def _record_failure(exception: Exception) -> None: else: agent_response = await node.executor.invoke_async(node_input, invocation_state=invocation_state) + if agent_response.stop_reason == "interrupt": + node.executor.messages.pop() # remove interrupted tool use message + node.executor._interrupt_state.deactivate() + + raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs") + # Extract metrics from agent response usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0) metrics = Metrics(latencyMs=0) From 80c8169f2790c570b51204508500eac9a3f1c5de Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 20 Oct 2025 13:49:12 -0400 Subject: [PATCH 23/27] fix: address comments, move from_dict() to AgentResult, fix docstrings, refactor interface --- 
src/strands/agent/agent_result.py | 23 ++++++++- .../multiagent_hooks/multiagent_events.py | 12 ++--- src/strands/multiagent/base.py | 44 ++-------------- src/strands/multiagent/graph.py | 43 ++++++++-------- src/strands/multiagent/swarm.py | 21 +++++--- tests/strands/agent/test_agent_result.py | 28 +++++++++++ .../multiagent_hooks/test_multiagent_hooks.py | 50 +++++++++++++++---- tests/strands/multiagent/test_base.py | 7 +-- tests/strands/multiagent/test_graph.py | 2 +- 9 files changed, 139 insertions(+), 91 deletions(-) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index eb9bc4dd9..e408b062e 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -4,7 +4,7 @@ """ from dataclasses import dataclass -from typing import Any, Sequence +from typing import Any, Sequence, cast from ..interrupt import Interrupt from ..telemetry.metrics import EventLoopMetrics @@ -46,3 +46,24 @@ def __str__(self) -> str: if isinstance(item, dict) and "text" in item: result += item.get("text", "") + "\n" return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "AgentResult": + """Rehydrate an AgentResult from persisted JSON. 
+ + Args: + data: Dictionary containing the serialized AgentResult data + + Returns: + AgentResult instance + + Raises: + TypeError: If the data format is invalid + """ + if data.get("type") != "agent_result": + raise TypeError(f"AgentResult.from_dict: unexpected type {data.get('type')!r}") + + message = cast(Message, data.get("message")) + stop_reason = cast(StopReason, data.get("stop_reason")) + + return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) diff --git a/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py b/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py index ce0ddbd95..08b5791ec 100644 --- a/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py +++ b/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py @@ -20,7 +20,7 @@ class MultiAgentInitializedEvent(BaseHookEvent): Attributes: source: The multi-agent orchestrator instance - invocation_state: Configuration that user pass in + invocation_state: Configuration that user passes in """ source: "MultiAgentBase" @@ -34,7 +34,7 @@ class BeforeNodeCallEvent(BaseHookEvent): Attributes: source: The multi-agent orchestrator instance node_id: ID of the node that just completed execution - invocation_state: Configuration that user pass in + invocation_state: Configuration that user passes in """ source: "MultiAgentBase" @@ -49,7 +49,7 @@ class AfterNodeCallEvent(BaseHookEvent): Attributes: source: The multi-agent orchestrator instance node_id: ID of the node that just completed execution - invocation_state: Configuration that user pass in + invocation_state: Configuration that user passes in """ source: "MultiAgentBase" @@ -64,11 +64,11 @@ def should_reverse_callbacks(self) -> bool: @dataclass class BeforeMultiAgentInvocationEvent(BaseHookEvent): - """Event triggered after orchestrator execution completes. This event corresponds to the After event. + """Event triggered before orchestrator execution begins.
This event corresponds to the After event. Attributes: source: The multi-agent orchestrator instance - invocation_state: Configuration that user pass in + invocation_state: Configuration that user passes in """ source: "MultiAgentBase" @@ -81,7 +81,7 @@ class AfterMultiAgentInvocationEvent(BaseHookEvent): Attributes: source: The multi-agent orchestrator instance - invocation_state: Configuration that user pass in + invocation_state: Configuration that user passes in """ source: "MultiAgentBase" diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index fcc92dc44..20997488a 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -10,12 +10,11 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from enum import Enum -from typing import Any, Union, cast +from typing import Any, Union from ..agent import AgentResult -from ..telemetry.metrics import EventLoopMetrics -from ..types.content import ContentBlock, Message -from ..types.event_loop import Metrics, StopReason, Usage +from ..types.content import ContentBlock +from ..types.event_loop import Metrics, Usage logger = logging.getLogger(__name__) @@ -97,10 +96,10 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": result: Union[AgentResult, "MultiAgentResult", Exception] if isinstance(raw, dict) and raw.get("type") == "agent_result": - result = NodeResult.agent_result_from_persisted(raw) + result = AgentResult.from_dict(raw) elif isinstance(raw, dict) and raw.get("type") == "exception": result = Exception(str(raw.get("message", "node failed"))) - elif isinstance(raw, dict) and ("results" in raw): + elif isinstance(raw, dict) and raw.get("type") == "multiagent_result": result = MultiAgentResult.from_dict(raw) else: raise TypeError(f"NodeResult.from_dict: unsupported result payload: {raw!r}") @@ -128,28 +127,6 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeResult": execution_count=int(data.get("execution_count", 
0)), ) - @classmethod - def agent_result_from_persisted(cls, data: dict[str, Any]) -> AgentResult: - """Rehydrate a minimal AgentResult from persisted JSON. - - Expected shape: - {"type": "agent_result", "message": , "stop_reason": } - """ - if data.get("type") != "agent_result": - raise TypeError(f"agent_result_from_persisted: unexpected type {data.get('type')!r}") - - message = cast(Message, data.get("message")) - stop_reason = cast( - StopReason, - data.get("stop_reason"), - ) - - try: - return AgentResult(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) - except Exception: - logger.debug("AgentResult constructor failed during rehydrating") - raise - @dataclass class MultiAgentResult: @@ -261,14 +238,3 @@ def serialize_state(self) -> dict[str, Any]: def deserialize_state(self, payload: dict[str, Any]) -> None: """Restore orchestrator state from a session dict.""" raise NotImplementedError - - def serialize_node_result_for_persist(self, raw: NodeResult) -> dict[str, Any]: - """Serialize node result for persistence. - - Args: - raw: Raw node result to serialize - - Returns: - JSON-serializable dict representation - """ - return raw.to_dict() diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 753dc142f..720db5c59 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -328,10 +328,10 @@ def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder" return self def set_hook_provider(self, hook_providers: list[HookProvider]) -> "GraphBuilder": - """Set hook provider for the graph. + """Set hook providers for the graph. 
Args: - hook_providers: SessionManager instance + hook_providers: Customer hooks user passes in """ self._hooks = hook_providers return self @@ -404,7 +404,7 @@ def __init__( node_timeout: Individual node timeout in seconds (default: None - no limit) reset_on_revisit: Whether to reset node state when revisited (default: False) session_manager: Optional session manager for persistence - hooks: Optional hook registry for event handling + hooks: Optional custom hook providers for registry. """ super().__init__() @@ -513,7 +513,9 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self)) + self.hooks.invoke_callbacks( + AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) + ) self._resume_from_persisted = False self._resume_next_nodes.clear() return self._build_result() @@ -562,18 +564,9 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: def _find_newly_ready_nodes(self, completed_batch: list["GraphNode"]) -> list["GraphNode"]: """Find nodes that became ready after the last execution.""" newly_ready = [] - for node in self.nodes.values(): - # Skip nodes already completed unless we’re in feedback-loop mode - if ( - node in self.state.completed_nodes or node.execution_status == Status.COMPLETED - ) and not self.reset_on_revisit: - continue - if node in self.state.failed_nodes: - continue + for _node_id, node in self.nodes.items(): if self._is_node_ready_with_conditions(node, completed_batch): - # Avoid duplicates - if node not in newly_ready: - newly_ready.append(node) + newly_ready.append(node) return newly_ready def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list["GraphNode"]) -> bool: @@ -618,10 +611,14 @@ async def _record_failure(exception: Exception) -> None: node.execution_time = execution_time self.state.failed_nodes.add(node) 
self.state.results[node.node_id] = fail_result - self.hooks.invoke_callbacks(AfterNodeCallEvent(source=self, node_id=node.node_id)) + self.hooks.invoke_callbacks( + AfterNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) # This is a placeholder for firing BeforeNodeCallEvent. - self.hooks.invoke_callbacks(BeforeNodeCallEvent(source=self, node_id=node.node_id)) + self.hooks.invoke_callbacks( + BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) if self.reset_on_revisit and node in self.state.completed_nodes: logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id) @@ -699,7 +696,9 @@ async def _record_failure(exception: Exception) -> None: self.state.execution_order.append(node) # Accumulate metrics self._accumulate_metrics(node_result) - self.hooks.invoke_callbacks(AfterNodeCallEvent(source=self, node_id=node.node_id)) + self.hooks.invoke_callbacks( + AfterNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) logger.debug( "node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time @@ -838,9 +837,7 @@ def _to_dict(self) -> dict[str, Any]: "type": "graph", "status": status_str, "completed_nodes": [n.node_id for n in self.state.completed_nodes], - "node_results": { - k: self.serialize_node_result_for_persist(v) for k, v in (self.state.results or {}).items() - }, + "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, "next_node_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.execution_order], @@ -891,7 +888,9 @@ def deserialize_state(self, payload: dict) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. 
""" - if payload.get("status") in (Status.COMPLETED.value, "completed"): + if payload.get("status") in (Status.COMPLETED.value, "completed") or ( + payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_node_to_execute") + ): # Reset all nodes for node in self.nodes.values(): node.reset_executor_state() diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index ba9a548f7..8a579987c 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -337,9 +337,10 @@ async def invoke_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(source=self)) + self.hooks.invoke_callbacks( + AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) + ) self._resume_from_persisted = False - self._resume_from_completed = False return self._build_result() @@ -617,7 +618,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None: logger.debug("node=<%s> | node execution completed", current_node.node_id) - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=current_node.node_id)) + self.hooks.invoke_callbacks( + AfterNodeCallEvent(self, node_id=current_node.node_id, invocation_state=invocation_state) + ) # Check if the current node is still the same after execution # If it is, then no handoff occurred and we consider the swarm complete @@ -656,7 +659,9 @@ async def _execute_node( self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any] ) -> AgentResult: """Execute swarm node.""" - self.hooks.invoke_callbacks(BeforeNodeCallEvent(source=self, node_id=node.node_id)) + self.hooks.invoke_callbacks( + BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) + ) start_time = time.time() node_name = node.node_id @@ -728,7 +733,7 @@ async def _execute_node( self.state.results[node_name] = node_result # Persist 
failure here - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_name)) + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_name, invocation_state=invocation_state)) raise @@ -782,7 +787,7 @@ def _to_dict(self) -> dict: "type": "swarm", "status": status_str, "node_history": [n.node_id for n in self.state.node_history], - "node_results": {k: self.serialize_node_result_for_persist(v) for k, v in normalized_results.items()}, + "node_results": {k: v.to_dict() for k, v in normalized_results.items()}, "next_node_to_execute": next_nodes, "current_task": self.state.task, "context": { @@ -857,7 +862,9 @@ def deserialize_state(self, payload: dict) -> None: Args: payload: Dictionary containing persisted state data """ - if payload.get("status") in (Status.COMPLETED.value, "completed"): + if payload.get("status") in (Status.COMPLETED.value, "completed") or ( + payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_node_to_execute") + ): # Reset all nodes for node in self.nodes.values(): node.reset_executor_state() diff --git a/tests/strands/agent/test_agent_result.py b/tests/strands/agent/test_agent_result.py index 409b08a2d..8806baae0 100644 --- a/tests/strands/agent/test_agent_result.py +++ b/tests/strands/agent/test_agent_result.py @@ -95,3 +95,31 @@ def test__str__non_dict_content(mock_metrics): message_string = str(result) assert message_string == "Valid text\nMore valid text\n" + + +def test_from_dict_valid_data(): + """Test that from_dict works with valid data.""" + data = { + "type": "agent_result", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + result = AgentResult.from_dict(data) + + assert result.message == data["message"] + assert result.stop_reason == data["stop_reason"] + assert isinstance(result.metrics, EventLoopMetrics) + assert result.state == {} + + +def test_from_dict_invalid_type(): + """Test that from_dict raises TypeError for 
invalid type.""" + data = { + "type": "invalid_type", + "message": {"role": "assistant", "content": [{"text": "Test response"}]}, + "stop_reason": "end_turn", + } + + with pytest.raises(TypeError, match="AgentResult.from_dict: unexpected type 'invalid_type'"): + AgentResult.from_dict(data) diff --git a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py index 61c0a18ec..7b10dd3b8 100644 --- a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py +++ b/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py @@ -75,10 +75,22 @@ def test_swarm_complete_hook_lifecycle(swarm, hook_provider): assert length == 4 assert result.status.value == "completed" - assert next(events) == MultiAgentInitializedEvent(source=swarm) - assert next(events) == BeforeNodeCallEvent(source=swarm, node_id="agent1") - assert next(events) == AfterNodeCallEvent(source=swarm, node_id="agent1") - assert next(events) == AfterMultiAgentInvocationEvent(source=swarm) + events_list = list(events) + + # Check event types and basic properties, ignoring invocation_state + assert isinstance(events_list[0], MultiAgentInitializedEvent) + assert events_list[0].source == swarm + + assert isinstance(events_list[1], BeforeNodeCallEvent) + assert events_list[1].source == swarm + assert events_list[1].node_id == "agent1" + + assert isinstance(events_list[2], AfterNodeCallEvent) + assert events_list[2].source == swarm + assert events_list[2].node_id == "agent1" + + assert isinstance(events_list[3], AfterMultiAgentInvocationEvent) + assert events_list[3].source == swarm def test_graph_complete_hook_lifecycle(graph, hook_provider): @@ -89,9 +101,27 @@ def test_graph_complete_hook_lifecycle(graph, hook_provider): assert length == 6 assert result.status.value == "completed" - assert next(events) == MultiAgentInitializedEvent(source=graph) - assert next(events) == 
BeforeNodeCallEvent(source=graph, node_id="agent1") - assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent1") - assert next(events) == BeforeNodeCallEvent(source=graph, node_id="agent2") - assert next(events) == AfterNodeCallEvent(source=graph, node_id="agent2") - assert next(events) == AfterMultiAgentInvocationEvent(source=graph) + events_list = list(events) + + # Check event types and basic properties, ignoring invocation_state + assert isinstance(events_list[0], MultiAgentInitializedEvent) + assert events_list[0].source == graph + + assert isinstance(events_list[1], BeforeNodeCallEvent) + assert events_list[1].source == graph + assert events_list[1].node_id == "agent1" + + assert isinstance(events_list[2], AfterNodeCallEvent) + assert events_list[2].source == graph + assert events_list[2].node_id == "agent1" + + assert isinstance(events_list[3], BeforeNodeCallEvent) + assert events_list[3].source == graph + assert events_list[3].node_id == "agent2" + + assert isinstance(events_list[4], AfterNodeCallEvent) + assert events_list[4].source == graph + assert events_list[4].node_id == "agent2" + + assert isinstance(events_list[5], AfterMultiAgentInvocationEvent) + assert events_list[5].source == graph diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index b6c9be27c..f26796ac9 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -235,13 +235,10 @@ def test_multi_agent_result_to_dict(agent_result): def test_serialize_node_result_for_persist(agent_result): """Test serialize_node_result_for_persist method.""" - from unittest.mock import Mock - - agent = Mock(spec=MultiAgentBase) # Test with NodeResult containing AgentResult node_result = NodeResult(result=agent_result) - serialized = MultiAgentBase.serialize_node_result_for_persist(agent, node_result) + serialized = node_result.to_dict() # Should return the to_dict() result assert "result" in serialized @@ -250,7 
+247,7 @@ def test_serialize_node_result_for_persist(agent_result): # Test with NodeResult containing Exception exception_node_result = NodeResult(result=Exception("Test error"), status=Status.FAILED) - serialized_exception = MultiAgentBase.serialize_node_result_for_persist(agent, exception_node_result) + serialized_exception = exception_node_result.to_dict() assert "result" in serialized_exception assert serialized_exception["result"]["type"] == "exception" assert serialized_exception["result"]["message"] == "Test error" diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index a31990b4f..6d8e492b5 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1139,7 +1139,7 @@ async def test_self_loop_functionality_without_reset(mock_strands_tracer, mock_u result = await graph.invoke_async("Test self loop without reset") assert result.status == Status.COMPLETED - assert len(result.execution_order) == 1 + assert len(result.execution_order) == 2 mock_strands_tracer.start_multiagent_span.assert_called() mock_use_span.assert_called() From 2ff40353870a5d9a62a85d152ed7893124276c66 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 20 Oct 2025 14:00:18 -0400 Subject: [PATCH 24/27] fix: fix typo and pattern --- src/strands/multiagent/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index 20997488a..f94b914ad 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -147,7 +147,7 @@ class MultiAgentResult: def to_dict(self) -> dict[str, Any]: """Convert MultiAgentResult to JSON-serializable dict.""" return { - "type": "mutiagent_result", + "type": "multiagent_result", "status": self.status.value, "results": {k: v.to_dict() for k, v in self.results.items()}, "accumulated_usage": dict(self.accumulated_usage), @@ -159,6 +159,9 @@ def to_dict(self) -> dict[str, Any]: @classmethod 
def from_dict(cls, data: dict[str, Any]) -> "MultiAgentResult": """Rehydrate a MultiAgentResult from persisted JSON.""" + if data.get("type") != "multiagent_result": + raise TypeError(f"MultiAgentResult.from_dict: unexpected type {data.get('type')!r}") + results = {k: NodeResult.from_dict(v) for k, v in data.get("results", {}).items()} usage_data = data.get("accumulated_usage", {}) usage = Usage( From 191f5e0913d343a6b6b57744ef5468dd29d91f05 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 20 Oct 2025 14:54:26 -0400 Subject: [PATCH 25/27] fix: rename multiagent directory --- .../hooks/{multiagent_hooks => multiagent}/__init__.py | 2 +- .../multiagent_events.py => multiagent/events.py} | 0 src/strands/multiagent/graph.py | 2 +- src/strands/multiagent/swarm.py | 2 +- src/strands/session/session_manager.py | 2 +- tests/fixtures/mock_multiagent_hook_provider.py | 2 +- .../hooks/{multiagent_hooks => multiagent}/__init__.py | 0 .../{multiagent_hooks => multiagent}/test_multiagent_events.py | 2 +- .../{multiagent_hooks => multiagent}/test_multiagent_hooks.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) rename src/strands/experimental/hooks/{multiagent_hooks => multiagent}/__init__.py (94%) rename src/strands/experimental/hooks/{multiagent_hooks/multiagent_events.py => multiagent/events.py} (100%) rename tests/strands/experimental/hooks/{multiagent_hooks => multiagent}/__init__.py (100%) rename tests/strands/experimental/hooks/{multiagent_hooks => multiagent}/test_multiagent_events.py (97%) rename tests/strands/experimental/hooks/{multiagent_hooks => multiagent}/test_multiagent_hooks.py (98%) diff --git a/src/strands/experimental/hooks/multiagent_hooks/__init__.py b/src/strands/experimental/hooks/multiagent/__init__.py similarity index 94% rename from src/strands/experimental/hooks/multiagent_hooks/__init__.py rename to src/strands/experimental/hooks/multiagent/__init__.py index c3dc793e7..83a62398b 100644 ---
a/src/strands/experimental/hooks/multiagent_hooks/__init__.py +++ b/src/strands/experimental/hooks/multiagent/__init__.py @@ -4,7 +4,7 @@ enabling resumable execution after interruptions or failures. """ -from .multiagent_events import ( +from .events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeMultiAgentInvocationEvent, diff --git a/src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py b/src/strands/experimental/hooks/multiagent/events.py similarity index 100% rename from src/strands/experimental/hooks/multiagent_hooks/multiagent_events.py rename to src/strands/experimental/hooks/multiagent/events.py diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 720db5c59..e735f3d88 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -26,7 +26,7 @@ from ..agent import Agent from ..agent.state import AgentState -from ..experimental.hooks.multiagent_hooks import ( +from ..experimental.hooks.multiagent import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 8a579987c..5f6ea6878 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -25,7 +25,7 @@ from ..agent import Agent, AgentResult from ..agent.state import AgentState -from ..experimental.hooks.multiagent_hooks import ( +from ..experimental.hooks.multiagent import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index c54175062..21b6cec40 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from ..experimental.hooks.multiagent_hooks.multiagent_events import ( +from ..experimental.hooks.multiagent.events import ( 
AfterMultiAgentInvocationEvent, AfterNodeCallEvent, MultiAgentInitializedEvent, diff --git a/tests/fixtures/mock_multiagent_hook_provider.py b/tests/fixtures/mock_multiagent_hook_provider.py index 12679e91d..727d28a48 100644 --- a/tests/fixtures/mock_multiagent_hook_provider.py +++ b/tests/fixtures/mock_multiagent_hook_provider.py @@ -1,6 +1,6 @@ from typing import Iterator, Literal, Tuple, Type -from strands.experimental.hooks.multiagent_hooks import ( +from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, diff --git a/tests/strands/experimental/hooks/multiagent_hooks/__init__.py b/tests/strands/experimental/hooks/multiagent/__init__.py similarity index 100% rename from tests/strands/experimental/hooks/multiagent_hooks/__init__.py rename to tests/strands/experimental/hooks/multiagent/__init__.py diff --git a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py b/tests/strands/experimental/hooks/multiagent/test_multiagent_events.py similarity index 97% rename from tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py rename to tests/strands/experimental/hooks/multiagent/test_multiagent_events.py index d7f91d198..6d2b22955 100644 --- a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_events.py +++ b/tests/strands/experimental/hooks/multiagent/test_multiagent_events.py @@ -4,7 +4,7 @@ import pytest -from strands.experimental.hooks.multiagent_hooks.multiagent_events import ( +from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, MultiAgentInitializedEvent, diff --git a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py b/tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py similarity index 98% rename from tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py rename to 
tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py index 7b10dd3b8..cfdbfe30f 100644 --- a/tests/strands/experimental/hooks/multiagent_hooks/test_multiagent_hooks.py +++ b/tests/strands/experimental/hooks/multiagent/test_multiagent_hooks.py @@ -1,7 +1,7 @@ import pytest from strands import Agent -from strands.experimental.hooks.multiagent_hooks import ( +from strands.experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, BeforeNodeCallEvent, From a1e10ed2ce92754474aa6cfc4345f47c74ba3a37 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Mon, 20 Oct 2025 18:15:31 -0400 Subject: [PATCH 26/27] fix: address PR comments --- src/strands/multiagent/graph.py | 19 +++++------- src/strands/multiagent/swarm.py | 18 ++++++------ src/strands/session/file_session_manager.py | 13 +++++---- src/strands/session/s3_session_manager.py | 16 +++++----- src/strands/session/session_manager.py | 29 ++++++------------- tests/strands/multiagent/test_base.py | 3 -- tests/strands/multiagent/test_graph.py | 2 +- tests/strands/multiagent/test_swarm.py | 2 +- .../session/test_file_session_manager.py | 13 +++++++-- .../session/test_s3_session_manager.py | 13 +++++++-- 10 files changed, 61 insertions(+), 67 deletions(-) diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index e735f3d88..4d336b73c 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -327,13 +327,13 @@ def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder" self._session_manager = session_manager return self - def set_hook_provider(self, hook_providers: list[HookProvider]) -> "GraphBuilder": + def set_hook_provider(self, hooks: list[HookProvider]) -> "GraphBuilder": """Set hook providers for the graph. 
Args: - hook_providers: Customer hooks user passes in + hooks: Customer hooks user passes in """ - self._hooks = hook_providers + self._hooks = hooks return self def build(self) -> "Graph": @@ -483,8 +483,6 @@ async def invoke_async( start_time=start_time, ) else: - if isinstance(self.state.task, (str, list)) and not self.state.task: - self.state.task = task self.state.status = Status.EXECUTING self.state.start_time = time.time() span = self.tracer.start_multiagent_span(task, "graph") @@ -593,7 +591,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None: """Execute a single node with error handling and timeout protection.""" - # Reset the node's state if reset_on_revisit is enabled and it's being revisited + # Reset the node's state if reset_on_revisit is enabled, and it's being revisited async def _record_failure(exception: Exception) -> None: execution_time = round((time.time() - start_time) * 1000) @@ -615,7 +613,6 @@ async def _record_failure(exception: Exception) -> None: AfterNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) ) - # This is a placeholder for firing BeforeNodeCallEvent. self.hooks.invoke_callbacks( BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) ) @@ -838,7 +835,7 @@ def _to_dict(self) -> dict[str, Any]: "status": status_str, "completed_nodes": [n.node_id for n in self.state.completed_nodes], "node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()}, - "next_node_to_execute": next_nodes, + "next_nodes_to_execute": next_nodes, "current_task": self.state.task, "execution_order": [n.node_id for n in self.state.execution_order], } @@ -888,9 +885,7 @@ def deserialize_state(self, payload: dict) -> None: payload: Dictionary containing persisted state data including status, completed nodes, results, and next nodes to execute. 
""" - if payload.get("status") in (Status.COMPLETED.value, "completed") or ( - payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_node_to_execute") - ): + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): node.reset_executor_state() @@ -903,7 +898,7 @@ def deserialize_state(self, payload: dict) -> None: self._from_dict(payload) self._resume_from_persisted = True - next_node_ids = payload.get("next_node_to_execute") or [] + next_node_ids = payload.get("next_nodes_to_execute") or [] mapped = self._map_node_ids(next_node_ids) valid_ready: list[GraphNode] = [] completed = set(self.state.completed_nodes) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index 5f6ea6878..b74051229 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -492,7 +492,7 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st # Persist handoff msg incase we lose it. 
if self.session_manager is not None: try: - self.session_manager.write_multi_agent_json(self.serialize_state()) + self.session_manager.write_multi_agent_json(self) except Exception as e: logger.warning("Failed to persist swarm state after handoff: %s", e) raise @@ -663,7 +663,7 @@ async def _execute_node( BeforeNodeCallEvent(source=self, node_id=node.node_id, invocation_state=invocation_state) ) start_time = time.time() - node_name = node.node_id + node_id = node.node_id try: # Prepare context for node @@ -705,7 +705,7 @@ async def _execute_node( ) # Store result in state - self.state.results[node_name] = node_result + self.state.results[node_id] = node_result # Accumulate metrics self._accumulate_metrics(node_result) @@ -717,7 +717,7 @@ async def _execute_node( except Exception as e: execution_time = round((time.time() - start_time) * 1000) - logger.exception("node=<%s> | node execution failed", node_name) + logger.exception("node=<%s> | node execution failed", node_id) # Create a NodeResult for the failed node node_result = NodeResult( @@ -730,10 +730,10 @@ async def _execute_node( ) # Store result in state - self.state.results[node_name] = node_result + self.state.results[node_id] = node_result # Persist failure here - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_name, invocation_state=invocation_state)) + self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node_id=node_id, invocation_state=invocation_state)) raise @@ -788,7 +788,7 @@ def _to_dict(self) -> dict: "status": status_str, "node_history": [n.node_id for n in self.state.node_history], "node_results": {k: v.to_dict() for k, v in normalized_results.items()}, - "next_node_to_execute": next_nodes, + "next_nodes_to_execute": next_nodes, "current_task": self.state.task, "context": { "shared_context": getattr(self.state.shared_context, "context", {}) or {}, @@ -831,7 +831,7 @@ def _from_dict(self, payload: dict) -> None: self.state.task = payload.get("current_task", self.state.task) # 
Determine current node (if executing) - next_ids = list(payload.get("next_node_to_execute") or []) + next_ids = list(payload.get("next_nodes_to_execute") or []) if next_ids: nid = next_ids[0] found_node = self.nodes.get(nid) @@ -863,7 +863,7 @@ def deserialize_state(self, payload: dict) -> None: payload: Dictionary containing persisted state data """ if payload.get("status") in (Status.COMPLETED.value, "completed") or ( - payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_node_to_execute") + payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_nodes_to_execute") ): # Reset all nodes for node in self.nodes.values(): diff --git a/src/strands/session/file_session_manager.py b/src/strands/session/file_session_manager.py index 0089c47c4..954ccc416 100644 --- a/src/strands/session/file_session_manager.py +++ b/src/strands/session/file_session_manager.py @@ -6,7 +6,7 @@ import shutil import tempfile from datetime import datetime, timezone -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from .. 
import _identifier from ..types.exceptions import SessionException @@ -14,6 +14,9 @@ from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) SESSION_PREFIX = "session_" @@ -119,8 +122,6 @@ def _write_file(self, path: str, data: dict[str, Any]) -> None: tmp = f"{path}.tmp" with open(tmp, "w", encoding="utf-8", newline="\n") as f: json.dump(data, f, indent=2, ensure_ascii=False) - f.flush() - os.fsync(f.fileno()) os.replace(tmp, path) def create_session(self, session: Session, **kwargs: Any) -> Session: @@ -254,13 +255,13 @@ def list_messages( return messages - def write_multi_agent_json(self, state: dict[str, Any], **kwargs: Any) -> None: + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: """Write multi-agent state to filesystem. Args: - state: Multi-agent state dictionary to persist - **kwargs: Additional keyword arguments for future extensibility + source: Multi-agent source object to persist """ + state = source.serialize_state() state_path = os.path.join(self._get_session_path(self.session_id), "multi_agent_state.json") self._write_file(state_path, state) diff --git a/src/strands/session/s3_session_manager.py b/src/strands/session/s3_session_manager.py index b6e18f2a5..a28d1b2d7 100644 --- a/src/strands/session/s3_session_manager.py +++ b/src/strands/session/s3_session_manager.py @@ -2,8 +2,7 @@ import json import logging -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig @@ -15,6 +14,9 @@ from .repository_session_manager import RepositorySessionManager from .session_repository import SessionRepository +if TYPE_CHECKING: + from ..multiagent.base import MultiAgentBase + logger = logging.getLogger(__name__) 
SESSION_PREFIX = "session_" @@ -303,21 +305,17 @@ def list_messages( except ClientError as e: raise SessionException(f"S3 error reading messages: {e}") from e - def write_multi_agent_json(self, state: dict[str, Any]) -> None: + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: """Write multi-agent state to S3. Args: - state: Multi-agent state dictionary to persist + source: Multi-agent source object to persist """ session_prefix = self._get_session_path(self.session_id) state_key = self._join_key(session_prefix, "multi_agent_state.json") + state = source.serialize_state() self._write_s3_object(state_key, state) - session_key = self._join_key(session_prefix, "session.json") - metadata = self._read_s3_object(session_key) or {} - metadata["updated_at"] = datetime.now(timezone.utc).isoformat() - self._write_s3_object(session_key, metadata) - def read_multi_agent_json(self) -> dict[str, Any]: """Read multi-agent state from S3. diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index 21b6cec40..8b84e0790 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -1,7 +1,6 @@ """Session manager interface for agent session management.""" import logging -import threading from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any @@ -38,10 +37,13 @@ def __init__(self, session_type: SessionType = SessionType.AGENT) -> None: session_type: Type of session (AGENT or MULTI_AGENT) """ self.session_type: SessionType = session_type - self._lock = threading.RLock() def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: """Register hooks for persisting the agent to the session.""" + if not hasattr(self, "session_type"): + self.session_type = SessionType.AGENT + logger.debug("Session type not set, defaulting to AGENT") + if self.session_type == SessionType.AGENT: # After the normal Agent initialization behavior, call the session initialize function to 
restore the agent registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent)) @@ -57,9 +59,9 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: elif self.session_type == SessionType.MULTI_AGENT: registry.add_callback(MultiAgentInitializedEvent, self._on_multiagent_initialized) - registry.add_callback(AfterNodeCallEvent, lambda event: self._persist_multi_agent_state(event.source)) + registry.add_callback(AfterNodeCallEvent, lambda event: self.write_multi_agent_json(event.source)) registry.add_callback( - AfterMultiAgentInvocationEvent, lambda event: self._persist_multi_agent_state(event.source) + AfterMultiAgentInvocationEvent, lambda event: self.write_multi_agent_json(event.source) ) @abstractmethod @@ -100,21 +102,11 @@ def initialize(self, agent: "Agent", **kwargs: Any) -> None: **kwargs: Additional keyword arguments for future extensibility. """ - def _persist_multi_agent_state(self, source: "MultiAgentBase") -> None: - """Thread-safe persistence of multi-agent state. - - Args: - source: Multi-agent orchestrator to persist - """ - with self._lock: - state = source.serialize_state() - self.write_multi_agent_json(state) - - def write_multi_agent_json(self, state: dict[str, Any]) -> None: + def write_multi_agent_json(self, source: "MultiAgentBase") -> None: """Write multi-agent state to persistent storage. 
Args: - state: Multi-agent state dictionary to persist + source: Multi-agent source object to persist """ raise NotImplementedError( f"{self.__class__.__name__} does not support multi-agent persistence " @@ -142,7 +134,4 @@ def _on_multiagent_initialized(self, event: MultiAgentInitializedEvent) -> None: if payload: source.deserialize_state(payload) else: - try: - self._persist_multi_agent_state(source) - except NotImplementedError: - pass + self.write_multi_agent_json(source) diff --git a/tests/strands/multiagent/test_base.py b/tests/strands/multiagent/test_base.py index f26796ac9..0da071afa 100644 --- a/tests/strands/multiagent/test_base.py +++ b/tests/strands/multiagent/test_base.py @@ -152,9 +152,6 @@ def serialize_state(self) -> dict: def deserialize_state(self, payload: dict) -> None: pass - def attempt_resume(self, payload: dict) -> None: - pass - # Should not raise an exception - __call__ is provided by base class agent = CompleteMultiAgent() assert isinstance(agent, MultiAgentBase) diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 6d8e492b5..1e44d0022 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -1373,7 +1373,7 @@ async def test_graph_persisted(mock_strands_tracer, mock_use_span): "node_results": {}, "current_task": "persisted task", "execution_order": [], - "next_node_to_execute": ["test_node"], + "next_nodes_to_execute": ["test_node"], } graph._from_dict(persisted_state) diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index f574c4ef2..2960d01fc 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -601,7 +601,7 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): "node_history": [], "node_results": {}, "current_task": "persisted task", - "next_node_to_execute": ["test_agent"], + "next_nodes_to_execute": ["test_agent"], "context": 
{"shared_context": {"test_agent": {"key": "value"}}, "handoff_message": "test handoff"}, } diff --git a/tests/strands/session/test_file_session_manager.py b/tests/strands/session/test_file_session_manager.py index 6db7c1e6d..61b20aba4 100644 --- a/tests/strands/session/test_file_session_manager.py +++ b/tests/strands/session/test_file_session_manager.py @@ -414,13 +414,20 @@ def test_write_read_multi_agent_json(file_manager, sample_session): """Test writing and reading multi-agent state.""" file_manager.create_session(sample_session) + # Create mock MultiAgentBase object + class MockMultiAgent: + def serialize_state(self): + return {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} + + mock_agent = MockMultiAgent() + expected_state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} + # Write multi-agent state - state = {"type": "graph", "status": "completed", "nodes": ["node1", "node2"]} - file_manager.write_multi_agent_json(state) + file_manager.write_multi_agent_json(mock_agent) # Read multi-agent state result = file_manager.read_multi_agent_json() - assert result == state + assert result == expected_state def test_read_multi_agent_json_nonexistent(file_manager): diff --git a/tests/strands/session/test_s3_session_manager.py b/tests/strands/session/test_s3_session_manager.py index a6aa0b9b4..d5f2b7d97 100644 --- a/tests/strands/session/test_s3_session_manager.py +++ b/tests/strands/session/test_s3_session_manager.py @@ -380,11 +380,18 @@ def test_write_read_multi_agent_json(s3_manager, sample_session): """Test multi-agent state persistence.""" s3_manager.create_session(sample_session) - state = {"type": "graph", "status": "completed"} - s3_manager.write_multi_agent_json(state) + # Create mock MultiAgentBase object + class MockMultiAgent: + def serialize_state(self): + return {"type": "graph", "status": "completed"} + + mock_agent = MockMultiAgent() + expected_state = {"type": "graph", "status": "completed"} + + 
s3_manager.write_multi_agent_json(mock_agent) result = s3_manager.read_multi_agent_json() - assert result == state + assert result == expected_state def test_read_multi_agent_json_nonexistent(s3_manager): From 734c59bcd3af9ad8c2d4f5df268fb8d5d85137b3 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Tue, 21 Oct 2025 17:42:53 -0400 Subject: [PATCH 27/27] fix: address comment --- src/strands/agent/agent_result.py | 16 +++++++++++++ .../experimental/hooks/multiagent/events.py | 5 ++++ src/strands/multiagent/base.py | 6 +---- src/strands/multiagent/graph.py | 16 ++++++------- src/strands/multiagent/swarm.py | 23 ++++++++----------- 5 files changed, 39 insertions(+), 27 deletions(-) diff --git a/src/strands/agent/agent_result.py b/src/strands/agent/agent_result.py index e408b062e..b63b4329e 100644 --- a/src/strands/agent/agent_result.py +++ b/src/strands/agent/agent_result.py @@ -67,3 +67,19 @@ def from_dict(cls, data: dict[str, Any]) -> "AgentResult": stop_reason = cast(StopReason, data.get("stop_reason")) return cls(message=message, stop_reason=stop_reason, metrics=EventLoopMetrics(), state={}) + + @classmethod + def to_dict(cls, data: "AgentResult") -> dict[str, Any]: + """Convert an AgentResult to JSON-serializable dictionary. 
+ + Args: + data: AgentResult instance to serialize + + Returns: + Dictionary containing serialized AgentResult data + """ + return { + "type": "agent_result", + "message": data.message, + "stop_reason": data.stop_reason, + } diff --git a/src/strands/experimental/hooks/multiagent/events.py b/src/strands/experimental/hooks/multiagent/events.py index 08b5791ec..257cf57f0 100644 --- a/src/strands/experimental/hooks/multiagent/events.py +++ b/src/strands/experimental/hooks/multiagent/events.py @@ -86,3 +86,8 @@ class AfterMultiAgentInvocationEvent(BaseHookEvent): source: "MultiAgentBase" invocation_state: dict[str, Any] | None = None + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True diff --git a/src/strands/multiagent/base.py b/src/strands/multiagent/base.py index f94b914ad..bfe4b3b3f 100644 --- a/src/strands/multiagent/base.py +++ b/src/strands/multiagent/base.py @@ -68,11 +68,7 @@ def to_dict(self) -> dict[str, Any]: result_data: dict[str, Any] = {"type": "exception", "message": str(self.result)} elif isinstance(self.result, AgentResult): # Serialize AgentResult without state field - result_data = { - "type": "agent_result", - "stop_reason": self.result.stop_reason, - "message": self.result.message, - } + result_data = AgentResult.to_dict(self.result) elif isinstance(self.result, MultiAgentResult): result_data = self.result.to_dict() else: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index 4d336b73c..278b354df 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -327,7 +327,7 @@ def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder" self._session_manager = session_manager return self - def set_hook_provider(self, hooks: list[HookProvider]) -> "GraphBuilder": + def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder": """Set hook providers for the graph. 
Args: @@ -428,7 +428,7 @@ def __init__( for hook in hooks: self.hooks.add_hook(hook) # Resume flags - self._resume_from_persisted = False + self._resume_from_session = False self._resume_next_nodes: list[GraphNode] = [] self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) @@ -471,7 +471,7 @@ async def invoke_async( logger.debug("task=<%s> | starting graph execution", task) - if not self._resume_from_persisted: + if not self._resume_from_session: start_time = time.time() # Initialize state self.state = GraphState( @@ -514,7 +514,7 @@ async def invoke_async( self.hooks.invoke_callbacks( AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) ) - self._resume_from_persisted = False + self._resume_from_session = False self._resume_next_nodes.clear() return self._build_result() @@ -532,7 +532,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None: async def _execute_graph(self, invocation_state: dict[str, Any]) -> None: """Unified execution flow with conditional routing.""" - ready_nodes = self._resume_next_nodes if self._resume_from_persisted else list(self.entry_points) + ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points) while ready_nodes: # Check execution limits before continuing @@ -891,12 +891,12 @@ def deserialize_state(self, payload: dict) -> None: node.reset_executor_state() # Reset graph state self.state = GraphState() - self._resume_from_persisted = False + self._resume_from_session = False return try: self._from_dict(payload) - self._resume_from_persisted = True + self._resume_from_session = True next_node_ids = payload.get("next_nodes_to_execute") or [] mapped = self._map_node_ids(next_node_ids) @@ -921,6 +921,6 @@ def deserialize_state(self, payload: dict) -> None: logger.debug("Resumed from persisted state. 
Next nodes: %s", [n.node_id for n in self._resume_next_nodes]) except Exception: logger.exception("Failed to apply resume payload") - self._resume_from_persisted = False + self._resume_from_session = False self._resume_next_nodes.clear() raise diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index b74051229..c538cf8f3 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -260,7 +260,7 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - self._resume_from_persisted = False + self._resume_from_session = False self.hooks.invoke_callbacks(MultiAgentInitializedEvent(source=self)) @@ -302,8 +302,7 @@ async def invoke_async( logger.debug("starting swarm execution") - # If resume - if not self._resume_from_persisted: + if not self._resume_from_session: initial_node = self._initial_node() self.state = SwarmState( @@ -313,8 +312,6 @@ async def invoke_async( shared_context=self.shared_context, ) else: - if isinstance(self.state.task, (str, list)) and not self.state.task: - self.state.task = task self.state.completion_status = Status.EXECUTING self.state.start_time = time.time() @@ -340,7 +337,7 @@ async def invoke_async( self.hooks.invoke_callbacks( AfterMultiAgentInvocationEvent(source=self, invocation_state=invocation_state) ) - self._resume_from_persisted = False + self._resume_from_session = False return self._build_result() @@ -670,6 +667,9 @@ async def _execute_node( context_text = self._build_node_input(node) node_input = [ContentBlock(text=f"Context:\n{context_text}\n\n")] + # Clear handoff message after it's been included in context + self.state.handoff_message = None + if not isinstance(task, str): # Include additional ContentBlocks in node input node_input = node_input + task @@ -710,9 +710,6 @@ async def _execute_node( # Accumulate metrics self._accumulate_metrics(node_result) - # Clear handoff message after it's been included in context - self.state.handoff_message = None - return 
result except Exception as e: @@ -862,9 +859,7 @@ def deserialize_state(self, payload: dict) -> None: Args: payload: Dictionary containing persisted state data """ - if payload.get("status") in (Status.COMPLETED.value, "completed") or ( - payload.get("status") in (Status.FAILED.value, "failed") and not payload.get("next_nodes_to_execute") - ): + if not payload.get("next_nodes_to_execute"): # Reset all nodes for node in self.nodes.values(): node.reset_executor_state() @@ -874,11 +869,11 @@ def deserialize_state(self, payload: dict) -> None: task="", completion_status=Status.PENDING, ) - self._resume_from_persisted = False + self._resume_from_session = False return else: self._from_dict(payload) - self._resume_from_persisted = True + self._resume_from_session = True logger.debug( "Resumed from persisted state. Current node: %s", self.state.current_node.node_id if self.state.current_node else "None",