Skip to content

Commit bf92e79

Browse files
committed
Add some pr feedback
1 parent 3b81f1b commit bf92e79

File tree

14 files changed

+126
-236
lines changed

14 files changed

+126
-236
lines changed

src/strands/agent/agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def __init__(
322322

323323
# Setup session callback handler if session is enabled
324324
if self.session_manager:
325-
self.session_manager.initialize_agent(self)
325+
self.session_manager.initialize(self)
326326

327327
@property
328328
def tool(self) -> ToolCaller:
@@ -500,7 +500,7 @@ def _run_loop(self, prompt: str, kwargs: dict[str, Any]) -> Generator[dict[str,
500500

501501
# Save message if session manager is available
502502
if self.session_manager:
503-
self.session_manager.append_message_to_agent_session(self, new_message)
503+
self.session_manager.append_message(self, new_message)
504504

505505
# Execute the event loop cycle with retry logic for context limits
506506
yield from self._execute_event_loop_cycle(kwargs)
@@ -597,10 +597,10 @@ def _record_tool_execution(
597597

598598
# Save to conversation manager if available
599599
if self.session_manager:
600-
self.session_manager.append_message_to_agent_session(self, user_msg)
601-
self.session_manager.append_message_to_agent_session(self, tool_use_msg)
602-
self.session_manager.append_message_to_agent_session(self, tool_result_msg)
603-
self.session_manager.append_message_to_agent_session(self, assistant_msg)
600+
self.session_manager.append_message(self, user_msg)
601+
self.session_manager.append_message(self, tool_use_msg)
602+
self.session_manager.append_message(self, tool_result_msg)
603+
self.session_manager.append_message(self, assistant_msg)
604604

605605
def _start_agent_trace_span(self, prompt: str) -> None:
606606
"""Starts a trace span for the agent.

src/strands/session/agent_session_manager.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""File-based implementation of session manager."""
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, Optional
4+
from typing import TYPE_CHECKING, Any, List
5+
from uuid import uuid4
56

67
from ..agent.state import AgentState
78
from ..handlers.callback_handler import CompositeCallbackHandler
89
from ..types.content import Message
9-
from .exceptions import SessionException
10-
from .file_session_dao import FileSessionDAO
10+
from ..types.exceptions import SessionException
1111
from .session_dao import SessionDAO
1212
from .session_manager import SessionManager
1313
from .session_models import Session, SessionAgent, SessionMessage, SessionType
@@ -28,13 +28,13 @@ class AgentSessionManager(SessionManager):
2828
def __init__(
2929
self,
3030
session_id: str,
31-
session_dao: Optional[SessionDAO] = None,
31+
session_dao: SessionDAO,
3232
):
3333
"""Initialize the FileSessionManager."""
34-
self.session_dao = session_dao or FileSessionDAO()
34+
self.session_dao = session_dao
3535
self.session_id = session_id
3636

37-
def append_message_to_agent_session(self, agent: "Agent", message: Message) -> None:
37+
def append_message(self, agent: "Agent", message: Message) -> None:
3838
"""Append a message to the agent's session.
3939
4040
Args:
@@ -56,7 +56,7 @@ def append_message_to_agent_session(self, agent: "Agent", message: Message) -> N
5656
),
5757
)
5858

59-
def initialize_agent(self, agent: "Agent") -> None:
59+
def initialize(self, agent: "Agent") -> None:
6060
"""Restore agent data from the current session.
6161
6262
Args:
@@ -65,20 +65,33 @@ def initialize_agent(self, agent: "Agent") -> None:
6565
Raises:
6666
SessionException: If restore operation fails
6767
"""
68-
if agent.id is None:
69-
raise ValueError("`agent.id` must be set before initializing session.")
70-
7168
try:
7269
# Try to read existing session
7370
session = self.session_dao.read_session(self.session_id)
7471

72+
if agent.id is None:
73+
agents: List[SessionAgent] = self.session_dao.list_agents(self.session_id)
74+
if len(agents) == 0:
75+
agent_id = str(uuid4())
76+
if len(agents) == 1:
77+
agent_id = agents[0].agent_id
78+
logger.debug(
79+
"session_id=<%s> | agent_id=<%s> | Restoring agent data from session", self.session_id, agent_id
80+
)
81+
else:
82+
raise ValueError(
83+
"If there is more than one agent in a session, agent.agent_id must be set manually."
84+
)
85+
else:
86+
if agent.id not in [agent.agent_id for agent in self.session_dao.list_agents(self.session_id)]:
87+
raise ValueError(f"Agent {agent.id} not found in session {self.session_id}")
88+
agent_id = agent.id
89+
7590
if session.session_type != SessionType.AGENT:
7691
raise ValueError(f"Invalid session type: {session.session_type}")
7792

78-
if agent.id not in [agent.agent_id for agent in self.session_dao.list_agents(self.session_id)]:
79-
raise ValueError(f"Agent {agent.id} not found in session {self.session_id}")
80-
8193
# Initialize agent
94+
agent.id = agent_id
8295
agent.messages = [
8396
session_message.to_message()
8497
for session_message in self.session_dao.list_messages(self.session_id, agent.id)
@@ -87,8 +100,12 @@ def initialize_agent(self, agent: "Agent") -> None:
87100

88101
except SessionException:
89102
# Session doesn't exist, create new one
90-
logger.debug("Session not found, creating new session")
103+
logger.debug("session_id=<%s> | Session not found, creating new session")
91104
# Session doesn't exist, create new one
105+
if agent.id is None:
106+
agent_id = str(uuid4())
107+
logger.debug("agent_id=<%s> | Creating agent_id for agent since none was set.", agent_id)
108+
agent.id = agent_id
92109
session = Session(session_id=self.session_id, session_type=SessionType.AGENT)
93110
session_agent = SessionAgent(
94111
agent_id=agent.id,
@@ -110,7 +127,7 @@ def session_callback(**kwargs: Any) -> None:
110127
# Handle message persistence
111128
if "message" in kwargs:
112129
message = kwargs["message"]
113-
self.append_message_to_agent_session(kwargs["agent"], message)
130+
self.append_message(kwargs["agent"], message)
114131
except Exception as e:
115132
logger.error("Persistence operation failed", e)
116133

src/strands/session/exceptions.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/strands/session/file_session_dao.py renamed to src/strands/session/file_session_manager.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""File-based session DAO for local filesystem storage."""
1+
"""File-based session manager for local filesystem storage."""
22

33
import json
44
import os
@@ -7,7 +7,8 @@
77
from datetime import datetime, timezone
88
from typing import Any, Dict, List, Optional, cast
99

10-
from .exceptions import SessionException
10+
from ..types.exceptions import SessionException
11+
from .agent_session_manager import AgentSessionManager
1112
from .session_dao import SessionDAO
1213
from .session_models import Session, SessionAgent, SessionMessage
1314

@@ -16,18 +17,21 @@
1617
MESSAGE_PREFIX = "message_"
1718

1819

19-
class FileSessionDAO(SessionDAO):
20-
"""File-based session DAO for local filesystem storage."""
20+
class FileSessionManager(AgentSessionManager, SessionDAO):
21+
"""File-based session manager for local filesystem storage."""
2122

22-
def __init__(self, storage_dir: Optional[str] = None):
23-
"""Initialize FileSessionDAO with filesystem storage.
23+
def __init__(self, session_id: str, storage_dir: Optional[str] = None):
24+
"""Initialize FileSession with filesystem storage.
2425
2526
Args:
27+
session_id: ID for the session
2628
storage_dir: Directory for local filesystem storage (defaults to temp dir)
2729
"""
2830
self.storage_dir = storage_dir or os.path.join(tempfile.gettempdir(), "strands/sessions")
2931
os.makedirs(self.storage_dir, exist_ok=True)
3032

33+
super().__init__(session_id=session_id, session_dao=self)
34+
3135
def _get_session_path(self, session_id: str) -> str:
3236
"""Get session directory path."""
3337
return os.path.join(self.storage_dir, f"{SESSION_PREFIX}{session_id}")
Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""S3-based session DAO for cloud storage."""
1+
"""S3-based session manager for AWS S3 cloud storage."""
22

33
import json
44
from typing import Any, Dict, List, Optional, cast
@@ -7,7 +7,8 @@
77
from botocore.config import Config as BotocoreConfig
88
from botocore.exceptions import ClientError
99

10-
from .exceptions import SessionException
10+
from ..types.exceptions import SessionException
11+
from .agent_session_manager import AgentSessionManager
1112
from .session_dao import SessionDAO
1213
from .session_models import Session, SessionAgent, SessionMessage
1314

@@ -16,20 +17,22 @@
1617
MESSAGE_PREFIX = "message_"
1718

1819

19-
class S3SessionDAO(SessionDAO):
20-
"""S3-based session DAO for cloud storage."""
20+
class S3SessionManager(AgentSessionManager, SessionDAO):
21+
"""S3-based session manager for cloud storage."""
2122

2223
def __init__(
2324
self,
25+
session_id: str,
2426
bucket: str,
2527
prefix: str = "",
2628
boto_session: Optional[boto3.Session] = None,
2729
boto_client_config: Optional[BotocoreConfig] = None,
2830
region_name: Optional[str] = None,
2931
):
30-
"""Initialize S3SessionDAO with S3 storage.
32+
"""Initialize S3SessionManager with S3 storage.
3133
3234
Args:
35+
session_id: ID for the session
3336
bucket: S3 bucket name (required)
3437
prefix: S3 key prefix for storage organization
3538
boto_session: Optional boto3 session
@@ -55,6 +58,8 @@ def __init__(
5558

5659
self.client = session.client(service_name="s3", config=client_config)
5760

61+
super().__init__(session_id=session_id, session_dao=self)
62+
5863
def _get_session_path(self, session_id: str) -> str:
5964
"""Get session S3 prefix."""
6065
return f"{self.prefix}{SESSION_PREFIX}{session_id}/"

0 commit comments

Comments
 (0)