Skip to content

Commit d33a1a8

Browse files
author
Murat Kaan Meral
committed
Merge branch 'main' into interface-kwargs-update
2 parents f00efcd + 19db55c commit d33a1a8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2301
-979
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ __pycache__*
88
.ruff_cache
99
*.bak
1010
.vscode
11-
dist
11+
dist
12+
repl_state

pyproject.toml

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ dev = [
5858
"pre-commit>=3.2.0,<4.2.0",
5959
"pytest>=8.0.0,<9.0.0",
6060
"pytest-asyncio>=0.26.0,<0.27.0",
61+
"pytest-cov>=4.1.0,<5.0.0",
62+
"pytest-xdist>=3.0.0,<4.0.0",
6163
"ruff>=0.4.4,<0.5.0",
6264
]
6365
docs = [
@@ -94,13 +96,59 @@ a2a = [
9496
"fastapi>=0.115.12",
9597
"starlette>=0.46.2",
9698
]
99+
all = [
100+
# anthropic
101+
"anthropic>=0.21.0,<1.0.0",
102+
103+
# dev
104+
"commitizen>=4.4.0,<5.0.0",
105+
"hatch>=1.0.0,<2.0.0",
106+
"moto>=5.1.0,<6.0.0",
107+
"mypy>=1.15.0,<2.0.0",
108+
"pre-commit>=3.2.0,<4.2.0",
109+
"pytest>=8.0.0,<9.0.0",
110+
"pytest-asyncio>=0.26.0,<0.27.0",
111+
"pytest-cov>=4.1.0,<5.0.0",
112+
"pytest-xdist>=3.0.0,<4.0.0",
113+
"ruff>=0.4.4,<0.5.0",
114+
115+
# docs
116+
"sphinx>=5.0.0,<6.0.0",
117+
"sphinx-rtd-theme>=1.0.0,<2.0.0",
118+
"sphinx-autodoc-typehints>=1.12.0,<2.0.0",
119+
120+
# litellm
121+
"litellm>=1.72.6,<1.73.0",
122+
123+
# llama
124+
"llama-api-client>=0.1.0,<1.0.0",
125+
126+
# mistral
127+
"mistralai>=1.8.2",
128+
129+
# ollama
130+
"ollama>=0.4.8,<1.0.0",
131+
132+
# openai
133+
"openai>=1.68.0,<2.0.0",
134+
135+
# otel
136+
"opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
137+
138+
# a2a
139+
"a2a-sdk[sql]>=0.2.11",
140+
"uvicorn>=0.34.2",
141+
"httpx>=0.28.1",
142+
"fastapi>=0.115.12",
143+
"starlette>=0.46.2",
144+
]
97145

98146
[tool.hatch.version]
99147
# Tells Hatch to use your version control system (git) to determine the version.
100148
source = "vcs"
101149

102150
[tool.hatch.envs.hatch-static-analysis]
103-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
151+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
104152
dependencies = [
105153
"mypy>=1.15.0,<2.0.0",
106154
"ruff>=0.11.6,<0.12.0",
@@ -116,15 +164,14 @@ format-fix = [
116164
]
117165
lint-check = [
118166
"ruff check",
119-
# excluding due to A2A and OTEL http exporter dependency conflict
120-
"mypy -p src --exclude src/strands/multiagent"
167+
"mypy -p src"
121168
]
122169
lint-fix = [
123170
"ruff check --fix"
124171
]
125172

126173
[tool.hatch.envs.hatch-test]
127-
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer"]
174+
features = ["anthropic", "litellm", "llamaapi", "ollama", "openai", "otel", "mistral", "writer", "a2a"]
128175
extra-dependencies = [
129176
"moto>=5.1.0,<6.0.0",
130177
"pytest>=8.0.0,<9.0.0",
@@ -140,35 +187,17 @@ extra-args = [
140187

141188
[tool.hatch.envs.dev]
142189
dev-mode = true
143-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer"]
144-
145-
[tool.hatch.envs.a2a]
146-
dev-mode = true
147-
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "a2a"]
148-
149-
[tool.hatch.envs.a2a.scripts]
150-
run = [
151-
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a {args}"
152-
]
153-
run-cov = [
154-
"pytest{env:HATCH_TEST_ARGS:} tests/strands/multiagent/a2a --cov --cov-config=pyproject.toml {args}"
155-
]
156-
lint-check = [
157-
"ruff check",
158-
"mypy -p src/strands/multiagent/a2a"
159-
]
190+
features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama", "otel", "mistral", "writer", "a2a"]
160191

161192
[[tool.hatch.envs.hatch-test.matrix]]
162193
python = ["3.13", "3.12", "3.11", "3.10"]
163194

164195
[tool.hatch.envs.hatch-test.scripts]
165196
run = [
166-
# excluding due to A2A and OTEL http exporter dependency conflict
167-
"pytest{env:HATCH_TEST_ARGS:} {args} --ignore=tests/strands/multiagent/a2a"
197+
"pytest{env:HATCH_TEST_ARGS:} {args}"
168198
]
169199
run-cov = [
170-
# excluding due to A2A and OTEL http exporter dependency conflict
171-
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args} --ignore=tests/strands/multiagent/a2a"
200+
"pytest{env:HATCH_TEST_ARGS:} --cov --cov-config=pyproject.toml {args}"
172201
]
173202

174203
cov-combine = []
@@ -203,10 +232,6 @@ prepare = [
203232
"hatch run test-lint",
204233
"hatch test --all"
205234
]
206-
test-a2a = [
207-
# required to run manually due to A2A and OTEL http exporter dependency conflict
208-
"hatch -e a2a run run {args}"
209-
]
210235

211236
[tool.mypy]
212237
python_version = "3.10"

src/strands/agent/agent.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,29 @@
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
1717
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
18+
from uuid import uuid4
1819

1920
from opentelemetry import trace
2021
from pydantic import BaseModel
2122

2223
from ..event_loop.event_loop import event_loop_cycle, run_tool
23-
from ..experimental.hooks import (
24+
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
25+
from ..hooks import (
2426
AfterInvocationEvent,
2527
AgentInitializedEvent,
2628
BeforeInvocationEvent,
29+
HookProvider,
2730
HookRegistry,
2831
MessageAddedEvent,
2932
)
30-
from ..handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
3133
from ..models.bedrock import BedrockModel
34+
from ..models.model import Model
3235
from ..telemetry.metrics import EventLoopMetrics
3336
from ..telemetry.tracer import get_tracer
3437
from ..tools.registry import ToolRegistry
3538
from ..tools.watcher import ToolWatcher
3639
from ..types.content import ContentBlock, Message, Messages
3740
from ..types.exceptions import ContextWindowOverflowException
38-
from ..types.models import Model
3941
from ..types.tools import ToolResult, ToolUse
4042
from ..types.traces import AttributeValue
4143
from .agent_result import AgentResult
@@ -200,9 +202,11 @@ def __init__(
200202
load_tools_from_directory: bool = True,
201203
trace_attributes: Optional[Mapping[str, AttributeValue]] = None,
202204
*,
205+
agent_id: Optional[str] = None,
203206
name: Optional[str] = None,
204207
description: Optional[str] = None,
205208
state: Optional[Union[AgentState, dict]] = None,
209+
hooks: Optional[list[HookProvider]] = None,
206210
):
207211
"""Initialize the Agent with the specified configuration.
208212
@@ -233,17 +237,24 @@ def __init__(
233237
load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory.
234238
Defaults to True.
235239
trace_attributes: Custom trace attributes to apply to the agent's trace span.
240+
agent_id: Optional ID for the agent, useful for multi-agent scenarios.
241+
If None, a UUID is generated.
236242
name: name of the Agent
237243
Defaults to None.
238244
description: description of what the Agent does
239245
Defaults to None.
240246
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
241247
Defaults to an empty AgentState object.
248+
hooks: hooks to be added to the agent hook registry
249+
Defaults to None.
242250
"""
243251
self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model
244252
self.messages = messages if messages is not None else []
245253

246254
self.system_prompt = system_prompt
255+
self.agent_id = agent_id or str(uuid4())
256+
self.name = name or _DEFAULT_AGENT_NAME
257+
self.description = description
247258

248259
# If not provided, create a new PrintingCallbackHandler instance
249260
# If explicitly set to None, use null_callback_handler
@@ -299,12 +310,12 @@ def __init__(
299310
self.state = AgentState()
300311

301312
self.tool_caller = Agent.ToolCaller(self)
302-
self.name = name or _DEFAULT_AGENT_NAME
303-
self.description = description
304313

305-
self._hooks = HookRegistry()
306-
# Register built-in hook providers (like ConversationManager) here
307-
self._hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
314+
self.hooks = HookRegistry()
315+
if hooks:
316+
for hook in hooks:
317+
self.hooks.add_hook(hook)
318+
self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self))
308319

309320
@property
310321
def tool(self) -> ToolCaller:
@@ -425,7 +436,7 @@ async def structured_output_async(
425436
Raises:
426437
ValueError: If no conversation history or prompt is provided.
427438
"""
428-
self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
439+
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
429440

430441
try:
431442
if not self.messages and not prompt:
@@ -444,7 +455,7 @@ async def structured_output_async(
444455
return event["output"]
445456

446457
finally:
447-
self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
458+
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
448459

449460
async def stream_async(self, prompt: Union[str, list[ContentBlock]], **kwargs: Any) -> AsyncIterator[Any]:
450461
"""Process a natural language prompt and yield events as an async iterator.
@@ -512,7 +523,7 @@ async def _run_loop(
512523
Yields:
513524
Events from the event loop cycle.
514525
"""
515-
self._hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
526+
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
516527

517528
try:
518529
yield {"callback": {"init_event_loop": True, **invocation_state}}
@@ -526,7 +537,7 @@ async def _run_loop(
526537

527538
finally:
528539
self.conversation_manager.apply_management(self)
529-
self._hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
540+
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
530541

531542
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
532543
"""Execute the event loop cycle with retry logic for context window limits.
@@ -656,4 +667,4 @@ def _end_agent_trace_span(
656667
def _append_message(self, message: Message) -> None:
657668
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
658669
self.messages.append(message)
659-
self._hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))
670+
self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message))

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
7676

7777
if len(messages) <= self.window_size:
7878
logger.debug(
79-
"window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size
79+
"message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
8080
)
8181
return
8282
self.reduce_context(agent)

src/strands/event_loop/event_loop.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
AfterToolInvocationEvent,
1919
BeforeModelInvocationEvent,
2020
BeforeToolInvocationEvent,
21+
)
22+
from ..hooks import (
2123
MessageAddedEvent,
22-
get_registry,
2324
)
2425
from ..telemetry.metrics import Trace
2526
from ..telemetry.tracer import get_tracer
@@ -120,7 +121,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
120121

121122
tool_specs = agent.tool_registry.get_all_tool_specs()
122123

123-
get_registry(agent).invoke_callbacks(
124+
agent.hooks.invoke_callbacks(
124125
BeforeModelInvocationEvent(
125126
agent=agent,
126127
)
@@ -139,7 +140,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
139140
stop_reason, message, usage, metrics = event["stop"]
140141
invocation_state.setdefault("request_state", {})
141142

142-
get_registry(agent).invoke_callbacks(
143+
agent.hooks.invoke_callbacks(
143144
AfterModelInvocationEvent(
144145
agent=agent,
145146
stop_response=AfterModelInvocationEvent.ModelStopResponse(
@@ -157,7 +158,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
157158
if model_invoke_span:
158159
tracer.end_span_with_error(model_invoke_span, str(e), e)
159160

160-
get_registry(agent).invoke_callbacks(
161+
agent.hooks.invoke_callbacks(
161162
AfterModelInvocationEvent(
162163
agent=agent,
163164
exception=e,
@@ -191,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
191192

192193
# Add the response message to the conversation
193194
agent.messages.append(message)
194-
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
195+
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
195196
yield {"callback": {"message": message}}
196197

197198
# Update metrics
@@ -311,7 +312,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str
311312
}
312313
)
313314

314-
before_event = get_registry(agent).invoke_callbacks(
315+
before_event = agent.hooks.invoke_callbacks(
315316
BeforeToolInvocationEvent(
316317
agent=agent,
317318
selected_tool=tool_func,
@@ -346,7 +347,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str
346347
"content": [{"text": f"Unknown tool: {tool_name}"}],
347348
}
348349
# for every Before event call, we need to have an AfterEvent call
349-
after_event = get_registry(agent).invoke_callbacks(
350+
after_event = agent.hooks.invoke_callbacks(
350351
AfterToolInvocationEvent(
351352
agent=agent,
352353
selected_tool=selected_tool,
@@ -363,7 +364,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str
363364

364365
result = event
365366

366-
after_event = get_registry(agent).invoke_callbacks(
367+
after_event = agent.hooks.invoke_callbacks(
367368
AfterToolInvocationEvent(
368369
agent=agent,
369370
selected_tool=selected_tool,
@@ -381,7 +382,7 @@ async def run_tool(agent: "Agent", tool_use: ToolUse, invocation_state: dict[str
381382
"status": "error",
382383
"content": [{"text": f"Error: {str(e)}"}],
383384
}
384-
after_event = get_registry(agent).invoke_callbacks(
385+
after_event = agent.hooks.invoke_callbacks(
385386
AfterToolInvocationEvent(
386387
agent=agent,
387388
selected_tool=selected_tool,
@@ -458,7 +459,7 @@ def tool_handler(tool_use: ToolUse) -> ToolGenerator:
458459
}
459460

460461
agent.messages.append(tool_result_message)
461-
get_registry(agent).invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
462+
agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message))
462463
yield {"callback": {"message": tool_result_message}}
463464

464465
if cycle_span:

src/strands/event_loop/streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import logging
55
from typing import Any, AsyncGenerator, AsyncIterable, Optional
66

7+
from ..models.model import Model
78
from ..types.content import ContentBlock, Message, Messages
8-
from ..types.models import Model
99
from ..types.streaming import (
1010
ContentBlockDeltaEvent,
1111
ContentBlockStart,

0 commit comments

Comments
 (0)