Skip to content

Commit 04c8547

Browse files
committed
refactor: remove kwargs spread after agent call
1 parent 3ae8d77 commit 04c8547

File tree

8 files changed

+216
-257
lines changed

8 files changed

+216
-257
lines changed

src/strands/agent/agent.py

Lines changed: 60 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ..types.content import ContentBlock, Message, Messages
3535
from ..types.exceptions import ContextWindowOverflowException
3636
from ..types.models import Model
37-
from ..types.tools import ToolConfig
37+
from ..types.tools import ToolConfig, ToolResult, ToolUse
3838
from ..types.traces import AttributeValue
3939
from .agent_result import AgentResult
4040
from .conversation_manager import (
@@ -100,104 +100,45 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
100100
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
101101
"""
102102

103-
def find_normalized_tool_name() -> Optional[str]:
104-
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
105-
tool_registry = self._agent.tool_registry.registry
106-
107-
if tool_registry.get(name, None):
108-
return name
109-
110-
# If the desired name contains underscores, it might be a placeholder for characters that can't be
111-
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
112-
# all tools that can be represented with the normalized name
113-
if "_" in name:
114-
filtered_tools = [
115-
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
116-
]
117-
118-
# The registry itself defends against similar names, so we can just take the first match
119-
if filtered_tools:
120-
return filtered_tools[0]
121-
122-
raise AttributeError(f"Tool '{name}' not found")
123-
124-
def caller(**kwargs: Any) -> Any:
103+
def caller(user_message_override: Optional[str] = None, **kwargs: Any) -> Any:
125104
"""Call a tool directly by name.
126105
127106
Args:
107+
user_message_override: Optional custom message to record instead of default
128108
**kwargs: Keyword arguments to pass to the tool.
129109
130-
- user_message_override: Custom message to record instead of default
131-
- tool_execution_handler: Custom handler for tool execution
132-
- event_loop_metrics: Custom metrics collector
133-
- messages: Custom message history to use
134-
- tool_config: Custom tool configuration
135-
- callback_handler: Custom callback handler
136-
- record_direct_tool_call: Whether to record this call in history
137-
138110
Returns:
139111
The result returned by the tool.
140112
141113
Raises:
142114
AttributeError: If the tool doesn't exist.
143115
"""
144-
normalized_name = find_normalized_tool_name()
116+
normalized_name = self._find_normalized_tool_name(name)
145117

146118
# Create unique tool ID and set up the tool request
147119
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
148-
tool_use = {
120+
tool_use: ToolUse = {
149121
"toolUseId": tool_id,
150122
"name": normalized_name,
151123
"input": kwargs.copy(),
152124
}
153125

154-
# Extract tool execution parameters
155-
user_message_override = kwargs.get("user_message_override", None)
156-
tool_execution_handler = kwargs.get("tool_execution_handler", self._agent.thread_pool_wrapper)
157-
event_loop_metrics = kwargs.get("event_loop_metrics", self._agent.event_loop_metrics)
158-
messages = kwargs.get("messages", self._agent.messages)
159-
tool_config = kwargs.get("tool_config", self._agent.tool_config)
160-
callback_handler = kwargs.get("callback_handler", self._agent.callback_handler)
161-
record_direct_tool_call = kwargs.get("record_direct_tool_call", self._agent.record_direct_tool_call)
162-
163-
# Process tool call
164-
handler_kwargs = {
165-
k: v
166-
for k, v in kwargs.items()
167-
if k
168-
not in [
169-
"tool_execution_handler",
170-
"event_loop_metrics",
171-
"messages",
172-
"tool_config",
173-
"callback_handler",
174-
"tool_handler",
175-
"system_prompt",
176-
"model",
177-
"model_id",
178-
"user_message_override",
179-
"agent",
180-
"record_direct_tool_call",
181-
]
182-
}
183-
184126
# Execute the tool
185127
tool_result = self._agent.tool_handler.process(
186128
tool=tool_use,
187129
model=self._agent.model,
188130
system_prompt=self._agent.system_prompt,
189-
messages=messages,
190-
tool_config=tool_config,
191-
callback_handler=callback_handler,
192-
tool_execution_handler=tool_execution_handler,
193-
event_loop_metrics=event_loop_metrics,
194-
agent=self._agent,
195-
**handler_kwargs,
131+
messages=self._agent.messages,
132+
tool_config=self._agent.tool_config,
133+
callback_handler=self._agent.callback_handler,
134+
kwargs=kwargs,
196135
)
197136

198-
if record_direct_tool_call:
137+
if self._agent.record_direct_tool_call:
199138
# Create a record of this tool execution in the message history
200-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)
139+
self._agent._record_tool_execution(
140+
tool_use, tool_result, user_message_override, self._agent.messages
141+
)
201142

202143
# Apply window management
203144
self._agent.conversation_manager.apply_management(self._agent)
@@ -206,6 +147,27 @@ def caller(**kwargs: Any) -> Any:
206147

207148
return caller
208149

150+
def _find_normalized_tool_name(self, name: str) -> str:
151+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
152+
tool_registry = self._agent.tool_registry.registry
153+
154+
if tool_registry.get(name, None):
155+
return name
156+
157+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
158+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
159+
# all tools that can be represented with the normalized name
160+
if "_" in name:
161+
filtered_tools = [
162+
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
163+
]
164+
165+
# The registry itself defends against similar names, so we can just take the first match
166+
if filtered_tools:
167+
return filtered_tools[0]
168+
169+
raise AttributeError(f"Tool '{name}' not found")
170+
209171
def __init__(
210172
self,
211173
model: Union[Model, str, None] = None,
@@ -374,7 +336,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
374336
375337
Args:
376338
prompt: The natural language prompt from the user.
377-
**kwargs: Additional parameters to pass to the event loop.
339+
**kwargs: Additional parameters to pass through the event loop.
378340
379341
Returns:
380342
Result object containing:
@@ -523,41 +485,36 @@ def _run_loop(
523485
finally:
524486
self.conversation_manager.apply_management(self)
525487

526-
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
488+
def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: dict[str, Any]) -> AgentResult:
527489
"""Execute the event loop cycle with retry logic for context window limits.
528490
529491
This internal method handles the execution of the event loop cycle and implements
530492
retry logic for handling context window overflow exceptions by reducing the
531493
conversation context and retrying.
532494
495+
Args:
496+
callback_handler: The callback handler to use for events.
497+
kwargs: Additional parameters to pass through event loop.
498+
533499
Returns:
534500
The result of the event loop cycle.
535501
"""
536-
# Extract parameters with fallbacks to instance values
537-
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
538-
model = kwargs.pop("model", self.model)
539-
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
540-
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
541-
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
542-
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
543-
messages = kwargs.pop("messages", self.messages)
544-
tool_config = kwargs.pop("tool_config", self.tool_config)
545-
kwargs.pop("agent", None) # Remove agent to avoid conflicts
502+
# Add `Agent` to kwargs to keep backwards-compatibility
503+
kwargs["agent"] = self
546504

547505
try:
548506
# Execute the main event loop cycle
549507
events = event_loop_cycle(
550-
model=model,
551-
system_prompt=system_prompt,
552-
messages=messages, # will be modified by event_loop_cycle
553-
tool_config=tool_config,
554-
callback_handler=callback_handler_override,
555-
tool_handler=tool_handler,
556-
tool_execution_handler=tool_execution_handler,
557-
event_loop_metrics=event_loop_metrics,
558-
agent=self,
508+
model=self.model,
509+
system_prompt=self.system_prompt,
510+
messages=self.messages, # will be modified by event_loop_cycle
511+
tool_config=self.tool_config,
512+
callback_handler=callback_handler,
513+
tool_handler=self.tool_handler,
514+
tool_execution_handler=self.thread_pool_wrapper,
515+
event_loop_metrics=self.event_loop_metrics,
559516
event_loop_parent_span=self.trace_span,
560-
**kwargs,
517+
kwargs=kwargs,
561518
)
562519
for event in events:
563520
if "callback" in event:
@@ -571,14 +528,14 @@ def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs
571528
# Try reducing the context size and retrying
572529

573530
self.conversation_manager.reduce_context(self, e=e)
574-
return self._execute_event_loop_cycle(callback_handler_override, kwargs)
531+
return self._execute_event_loop_cycle(callback_handler, kwargs)
575532

576533
def _record_tool_execution(
577534
self,
578-
tool: Dict[str, Any],
579-
tool_result: Dict[str, Any],
535+
tool: ToolUse,
536+
tool_result: ToolResult,
580537
user_message_override: Optional[str],
581-
messages: List[Dict[str, Any]],
538+
messages: Messages,
582539
) -> None:
583540
"""Record a tool execution in the message history.
584541
@@ -596,7 +553,7 @@ def _record_tool_execution(
596553
messages: The message history to append to.
597554
"""
598555
# Create user message describing the tool call
599-
user_msg_content = [
556+
user_msg_content: List[ContentBlock] = [
600557
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
601558
]
602559

@@ -605,19 +562,19 @@ def _record_tool_execution(
605562
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
606563

607564
# Create the message sequence
608-
user_msg = {
565+
user_msg: Message = {
609566
"role": "user",
610567
"content": user_msg_content,
611568
}
612-
tool_use_msg = {
569+
tool_use_msg: Message = {
613570
"role": "assistant",
614571
"content": [{"toolUse": tool}],
615572
}
616-
tool_result_msg = {
573+
tool_result_msg: Message = {
617574
"role": "user",
618575
"content": [{"toolResult": tool_result}],
619576
}
620-
assistant_msg = {
577+
assistant_msg: Message = {
621578
"role": "assistant",
622579
"content": [{"text": f"agent.{tool['name']} was called"}],
623580
}

0 commit comments

Comments
 (0)