Skip to content

Commit bb31688

Browse files
authored
Merge branch 'main' into conversation-manager-session
2 parents 4784196 + 4167c5c commit bb31688

File tree

20 files changed

+441
-169
lines changed

20 files changed

+441
-169
lines changed

src/strands/agent/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,10 @@ async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: A
479479
prompt: User input as text or list of ContentBlock objects for multi-modal content.
480480
**kwargs: Additional parameters to pass to the event loop.
481481
482-
Returns:
482+
Yields:
483483
An async iterator that yields events. Each event is a dictionary containing
484484
information about the current state of processing, such as:
485+
485486
- data: Text content being generated
486487
- complete: Whether this is the final chunk
487488
- current_tool_use: Information about tools being executed

src/strands/event_loop/streaming.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ async def process_stream(chunks: AsyncIterable[StreamEvent]) -> AsyncGenerator[d
253253
Args:
254254
chunks: The chunks of the response stream from the model.
255255
256-
Returns:
256+
Yields:
257257
The reason for stopping, the constructed message, and the usage metrics.
258258
"""
259259
stop_reason: StopReason = "end_turn"
@@ -306,7 +306,7 @@ async def stream_messages(
306306
messages: List of messages to send.
307307
tool_specs: The list of tool specs.
308308
309-
Returns:
309+
Yields:
310310
The reason for stopping, the final message, and the usage metrics
311311
"""
312312
logger.debug("model=<%s> | streaming messages", model)

src/strands/models/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def stream(
7474
"""Stream conversation with the model.
7575
7676
This method handles the full lifecycle of conversing with the model:
77+
7778
1. Format the messages, tool specs, and configuration into a streaming request
7879
2. Send the request to the model
7980
3. Yield the formatted message chunks

src/strands/multiagent/graph.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from dataclasses import dataclass, field
2222
from typing import Any, Callable, Tuple, cast
2323

24+
from opentelemetry import trace as trace_api
25+
2426
from ..agent import Agent, AgentResult
27+
from ..telemetry import get_tracer
2528
from ..types.content import ContentBlock
2629
from ..types.event_loop import Metrics, Usage
2730
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
@@ -249,6 +252,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
249252
self.edges = edges
250253
self.entry_points = entry_points
251254
self.state = GraphState()
255+
self.tracer = get_tracer()
252256

253257
def execute(self, task: str | list[ContentBlock]) -> GraphResult:
254258
"""Execute task synchronously."""
@@ -274,19 +278,20 @@ async def execute_async(self, task: str | list[ContentBlock]) -> GraphResult:
274278
)
275279

276280
start_time = time.time()
277-
try:
278-
await self._execute_graph()
279-
self.state.status = Status.COMPLETED
280-
logger.debug("status=<%s> | graph execution completed", self.state.status)
281-
282-
except Exception:
283-
logger.exception("graph execution failed")
284-
self.state.status = Status.FAILED
285-
raise
286-
finally:
287-
self.state.execution_time = round((time.time() - start_time) * 1000)
288-
289-
return self._build_result()
281+
span = self.tracer.start_multiagent_span(task, "graph")
282+
with trace_api.use_span(span, end_on_exit=True):
283+
try:
284+
await self._execute_graph()
285+
self.state.status = Status.COMPLETED
286+
logger.debug("status=<%s> | graph execution completed", self.state.status)
287+
288+
except Exception:
289+
logger.exception("graph execution failed")
290+
self.state.status = Status.FAILED
291+
raise
292+
finally:
293+
self.state.execution_time = round((time.time() - start_time) * 1000)
294+
return self._build_result()
290295

291296
async def _execute_graph(self) -> None:
292297
"""Unified execution flow with conditional routing."""

src/strands/session/file_session_manager.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import shutil
77
import tempfile
8-
from dataclasses import asdict
98
from typing import Any, Optional, cast
109

1110
from ..types.exceptions import SessionException
@@ -57,22 +56,18 @@ def _get_agent_path(self, session_id: str, agent_id: str) -> str:
5756
session_path = self._get_session_path(session_id)
5857
return os.path.join(session_path, "agents", f"{AGENT_PREFIX}{agent_id}")
5958

60-
def _get_message_path(self, session_id: str, agent_id: str, message_id: str, timestamp: str) -> str:
59+
def _get_message_path(self, session_id: str, agent_id: str, message_id: int) -> str:
6160
"""Get message file path.
6261
6362
Args:
6463
session_id: ID of the session
6564
agent_id: ID of the agent
66-
message_id: ID of the message
67-
timestamp: ISO format timestamp to include in filename for sorting
65+
message_id: Index of the message
6866
Returns:
6967
The filename for the message
7068
"""
7169
agent_path = self._get_agent_path(session_id, agent_id)
72-
# Use timestamp for sortable filenames
73-
# Replace colons and periods in ISO format with underscores for filesystem compatibility
74-
filename_timestamp = timestamp.replace(":", "_").replace(".", "_")
75-
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{filename_timestamp}_{message_id}.json")
70+
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
7671

7772
def _read_file(self, path: str) -> dict[str, Any]:
7873
"""Read JSON file."""
@@ -100,7 +95,7 @@ def create_session(self, session: Session) -> Session:
10095

10196
# Write session file
10297
session_file = os.path.join(session_dir, "session.json")
103-
session_dict = asdict(session)
98+
session_dict = session.to_dict()
10499
self._write_file(session_file, session_dict)
105100

106101
return session
@@ -123,7 +118,7 @@ def create_agent(self, session_id: str, session_agent: SessionAgent) -> None:
123118
os.makedirs(os.path.join(agent_dir, "messages"), exist_ok=True)
124119

125120
agent_file = os.path.join(agent_dir, "agent.json")
126-
session_data = asdict(session_agent)
121+
session_data = session_agent.to_dict()
127122
self._write_file(agent_file, session_data)
128123

129124
def delete_session(self, session_id: str) -> None:
@@ -152,34 +147,25 @@ def update_agent(self, session_id: str, session_agent: SessionAgent) -> None:
152147

153148
session_agent.created_at = previous_agent.created_at
154149
agent_file = os.path.join(self._get_agent_path(session_id, agent_id), "agent.json")
155-
self._write_file(agent_file, asdict(session_agent))
150+
self._write_file(agent_file, session_agent.to_dict())
156151

157152
def create_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
158153
"""Create a new message for the agent."""
159154
message_file = self._get_message_path(
160155
session_id,
161156
agent_id,
162157
session_message.message_id,
163-
session_message.created_at,
164158
)
165-
session_dict = asdict(session_message)
159+
session_dict = session_message.to_dict()
166160
self._write_file(message_file, session_dict)
167161

168-
def read_message(self, session_id: str, agent_id: str, message_id: str) -> Optional[SessionMessage]:
162+
def read_message(self, session_id: str, agent_id: str, message_id: int) -> Optional[SessionMessage]:
169163
"""Read message data."""
170-
# Get the messages directory
171-
messages_dir = os.path.join(self._get_agent_path(session_id, agent_id), "messages")
172-
if not os.path.exists(messages_dir):
164+
message_path = self._get_message_path(session_id, agent_id, message_id)
165+
if not os.path.exists(message_path):
173166
return None
174-
175-
# List files in messages directory, and check if the filename ends with the message id
176-
for filename in os.listdir(messages_dir):
177-
if filename.endswith(f"{message_id}.json"):
178-
file_path = os.path.join(messages_dir, filename)
179-
message_data = self._read_file(file_path)
180-
return SessionMessage.from_dict(message_data)
181-
182-
return None
167+
message_data = self._read_file(message_path)
168+
return SessionMessage.from_dict(message_data)
183169

184170
def update_message(self, session_id: str, agent_id: str, session_message: SessionMessage) -> None:
185171
"""Update message data."""
@@ -190,8 +176,8 @@ def update_message(self, session_id: str, agent_id: str, session_message: Sessio
190176

191177
# Preserve the original created_at timestamp
192178
session_message.created_at = previous_message.created_at
193-
message_file = self._get_message_path(session_id, agent_id, message_id, session_message.created_at)
194-
self._write_file(message_file, asdict(session_message))
179+
message_file = self._get_message_path(session_id, agent_id, message_id)
180+
self._write_file(message_file, session_message.to_dict())
195181

196182
def list_messages(
197183
self, session_id: str, agent_id: str, limit: Optional[int] = None, offset: int = 0
@@ -201,14 +187,16 @@ def list_messages(
201187
if not os.path.exists(messages_dir):
202188
raise SessionException(f"Messages directory missing from agent: {agent_id} in session {session_id}")
203189

204-
# Read all message files
205-
message_files: list[str] = []
190+
# Read all message files, and record the index
191+
message_index_files: list[tuple[int, str]] = []
206192
for filename in os.listdir(messages_dir):
207193
if filename.startswith(MESSAGE_PREFIX) and filename.endswith(".json"):
208-
message_files.append(filename)
194+
# Extract index from message_<index>.json format
195+
index = int(filename[len(MESSAGE_PREFIX) : -5]) # Remove prefix and .json suffix
196+
message_index_files.append((index, filename))
209197

210-
# Sort filenames - the timestamp in the file's name will sort chronologically
211-
message_files.sort()
198+
# Sort by index and extract just the filenames
199+
message_files = [f for _, f in sorted(message_index_files)]
212200

213201
# Apply pagination to filenames
214202
if limit is not None:

src/strands/session/repository_session_manager.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,8 @@ def __init__(
4848

4949
self.session = session
5050

51-
# Keep track of the initialized agent id's so that two agents in a session cannot share an id
52-
self._initialized_agent_ids: set[str] = set()
53-
54-
# Keep track of the latest message stored in the session in case we need to redact its content.
55-
self._latest_message: Optional[SessionMessage] = None
51+
# Keep track of the latest message of each agent in case we need to redact it.
52+
self._latest_agent_message: dict[str, Optional[SessionMessage]] = {}
5653

5754
def append_message(self, message: Message, agent: Agent) -> None:
5855
"""Append a message to the agent's session.
@@ -61,8 +58,16 @@ def append_message(self, message: Message, agent: Agent) -> None:
6158
message: Message to add to the agent in the session
6259
agent: Agent to append the message to
6360
"""
64-
self._latest_message = SessionMessage.from_message(message)
65-
self.session_repository.create_message(self.session_id, agent.agent_id, self._latest_message)
61+
# Calculate the next index (0 if this is the first message, otherwise increment the previous index)
62+
latest_agent_message = self._latest_agent_message[agent.agent_id]
63+
if latest_agent_message:
64+
next_index = latest_agent_message.message_id + 1
65+
else:
66+
next_index = 0
67+
68+
session_message = SessionMessage.from_message(message, next_index)
69+
self._latest_agent_message[agent.agent_id] = session_message
70+
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
6671

6772
def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
6873
"""Redact the latest message appended to the session.
@@ -71,10 +76,11 @@ def redact_latest_message(self, redact_message: Message, agent: Agent) -> None:
7176
redact_message: New message to use that contains the redact content
7277
agent: Agent to apply the message redaction to
7378
"""
74-
if self._latest_message is None:
79+
latest_agent_message = self._latest_agent_message[agent.agent_id]
80+
if latest_agent_message is None:
7581
raise SessionException("No message to redact.")
76-
self._latest_message.redact_message = redact_message
77-
return self.session_repository.update_message(self.session_id, agent.agent_id, self._latest_message)
82+
latest_agent_message.redact_message = redact_message
83+
return self.session_repository.update_message(self.session_id, agent.agent_id, latest_agent_message)
7884

7985
def sync_agent(self, agent: Agent) -> None:
8086
"""Serialize and update the agent into the session repository.
@@ -93,9 +99,9 @@ def initialize(self, agent: Agent) -> None:
9399
Args:
94100
agent: Agent to initialize from the session
95101
"""
96-
if agent.agent_id in self._initialized_agent_ids:
102+
if agent.agent_id in self._latest_agent_message:
97103
raise SessionException("The `agent_id` of an agent must be unique in a session.")
98-
self._initialized_agent_ids.add(agent.agent_id)
104+
self._latest_agent_message[agent.agent_id] = None
99105

100106
session_agent = self.session_repository.read_agent(self.session_id, agent.agent_id)
101107

@@ -108,8 +114,9 @@ def initialize(self, agent: Agent) -> None:
108114

109115
session_agent = SessionAgent.from_agent(agent)
110116
self.session_repository.create_agent(self.session_id, session_agent)
111-
for message in agent.messages:
112-
session_message = SessionMessage.from_message(message)
117+
# Initialize messages with sequential indices
118+
for i, message in enumerate(agent.messages):
119+
session_message = SessionMessage.from_message(message, i)
113120
self.session_repository.create_message(self.session_id, agent.agent_id, session_message)
114121
else:
115122
logger.debug(

0 commit comments

Comments
 (0)