Skip to content

Commit a13c6c2

Browse files
committed
refactor: remove kwargs spread after agent call
1 parent 91fd7f1 commit a13c6c2

File tree

7 files changed

+196
-242
lines changed

7 files changed

+196
-242
lines changed

src/strands/agent/agent.py

Lines changed: 52 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..types.content import ContentBlock, Message, Messages
3232
from ..types.exceptions import ContextWindowOverflowException
3333
from ..types.models import Model
34-
from ..types.tools import ToolConfig
34+
from ..types.tools import ToolConfig, ToolResult, ToolUse
3535
from ..types.traces import AttributeValue
3636
from .agent_result import AgentResult
3737
from .conversation_manager import (
@@ -97,104 +97,45 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
9797
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
9898
"""
9999

100-
def find_normalized_tool_name() -> Optional[str]:
101-
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
102-
tool_registry = self._agent.tool_registry.registry
103-
104-
if tool_registry.get(name, None):
105-
return name
106-
107-
# If the desired name contains underscores, it might be a placeholder for characters that can't be
108-
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
109-
# all tools that can be represented with the normalized name
110-
if "_" in name:
111-
filtered_tools = [
112-
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
113-
]
114-
115-
# The registry itself defends against similar names, so we can just take the first match
116-
if filtered_tools:
117-
return filtered_tools[0]
118-
119-
raise AttributeError(f"Tool '{name}' not found")
120-
121-
def caller(**kwargs: Any) -> Any:
100+
def caller(user_message_override: Optional[str] = None, **kwargs: Any) -> Any:
122101
"""Call a tool directly by name.
123102
124103
Args:
104+
user_message_override: Optional custom message to record instead of default
125105
**kwargs: Keyword arguments to pass to the tool.
126106
127-
- user_message_override: Custom message to record instead of default
128-
- tool_execution_handler: Custom handler for tool execution
129-
- event_loop_metrics: Custom metrics collector
130-
- messages: Custom message history to use
131-
- tool_config: Custom tool configuration
132-
- callback_handler: Custom callback handler
133-
- record_direct_tool_call: Whether to record this call in history
134-
135107
Returns:
136108
The result returned by the tool.
137109
138110
Raises:
139111
AttributeError: If the tool doesn't exist.
140112
"""
141-
normalized_name = find_normalized_tool_name()
113+
normalized_name = self._find_normalized_tool_name(name)
142114

143115
# Create unique tool ID and set up the tool request
144116
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
145-
tool_use = {
117+
tool_use: ToolUse = {
146118
"toolUseId": tool_id,
147119
"name": normalized_name,
148120
"input": kwargs.copy(),
149121
}
150122

151-
# Extract tool execution parameters
152-
user_message_override = kwargs.get("user_message_override", None)
153-
tool_execution_handler = kwargs.get("tool_execution_handler", self._agent.thread_pool_wrapper)
154-
event_loop_metrics = kwargs.get("event_loop_metrics", self._agent.event_loop_metrics)
155-
messages = kwargs.get("messages", self._agent.messages)
156-
tool_config = kwargs.get("tool_config", self._agent.tool_config)
157-
callback_handler = kwargs.get("callback_handler", self._agent.callback_handler)
158-
record_direct_tool_call = kwargs.get("record_direct_tool_call", self._agent.record_direct_tool_call)
159-
160-
# Process tool call
161-
handler_kwargs = {
162-
k: v
163-
for k, v in kwargs.items()
164-
if k
165-
not in [
166-
"tool_execution_handler",
167-
"event_loop_metrics",
168-
"messages",
169-
"tool_config",
170-
"callback_handler",
171-
"tool_handler",
172-
"system_prompt",
173-
"model",
174-
"model_id",
175-
"user_message_override",
176-
"agent",
177-
"record_direct_tool_call",
178-
]
179-
}
180-
181123
# Execute the tool
182124
tool_result = self._agent.tool_handler.process(
183125
tool=tool_use,
184126
model=self._agent.model,
185127
system_prompt=self._agent.system_prompt,
186-
messages=messages,
187-
tool_config=tool_config,
188-
callback_handler=callback_handler,
189-
tool_execution_handler=tool_execution_handler,
190-
event_loop_metrics=event_loop_metrics,
191-
agent=self._agent,
192-
**handler_kwargs,
128+
messages=self._agent.messages,
129+
tool_config=self._agent.tool_config,
130+
callback_handler=self._agent.callback_handler,
131+
kwargs=kwargs,
193132
)
194133

195-
if record_direct_tool_call:
134+
if self._agent.record_direct_tool_call:
196135
# Create a record of this tool execution in the message history
197-
self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)
136+
self._agent._record_tool_execution(
137+
tool_use, tool_result, user_message_override, self._agent.messages
138+
)
198139

199140
# Apply window management
200141
self._agent.conversation_manager.apply_management(self._agent)
@@ -203,6 +144,27 @@ def caller(**kwargs: Any) -> Any:
203144

204145
return caller
205146

147+
def _find_normalized_tool_name(self, name: str) -> str:
148+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
149+
tool_registry = self._agent.tool_registry.registry
150+
151+
if tool_registry.get(name, None):
152+
return name
153+
154+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
155+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
156+
# all tools that can be represented with the normalized name
157+
if "_" in name:
158+
filtered_tools = [
159+
tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name
160+
]
161+
162+
# The registry itself defends against similar names, so we can just take the first match
163+
if filtered_tools:
164+
return filtered_tools[0]
165+
166+
raise AttributeError(f"Tool '{name}' not found")
167+
206168
def __init__(
207169
self,
208170
model: Union[Model, str, None] = None,
@@ -371,7 +333,7 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
371333
372334
Args:
373335
prompt: The natural language prompt from the user.
374-
**kwargs: Additional parameters to pass to the event loop.
336+
**kwargs: Additional parameters to pass through the event loop.
375337
376338
Returns:
377339
Result object containing:
@@ -514,37 +476,28 @@ def _execute_event_loop_cycle(
514476
Yields:
515477
Events of the loop cycle.
516478
"""
517-
# Extract parameters with fallbacks to instance values
518-
system_prompt = kwargs.pop("system_prompt", self.system_prompt)
519-
model = kwargs.pop("model", self.model)
520-
tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
521-
event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
522-
callback_handler_override = kwargs.pop("callback_handler", callback_handler)
523-
tool_handler = kwargs.pop("tool_handler", self.tool_handler)
524-
messages = kwargs.pop("messages", self.messages)
525-
tool_config = kwargs.pop("tool_config", self.tool_config)
526-
kwargs.pop("agent", None) # Remove agent to avoid conflicts
479+
# Add `Agent` to kwargs to keep backwards-compatibility
480+
kwargs["agent"] = self
527481

528482
try:
529483
# Execute the main event loop cycle
530484
yield from event_loop_cycle(
531-
model=model,
532-
system_prompt=system_prompt,
533-
messages=messages, # will be modified by event_loop_cycle
534-
tool_config=tool_config,
535-
callback_handler=callback_handler_override,
536-
tool_handler=tool_handler,
537-
tool_execution_handler=tool_execution_handler,
538-
event_loop_metrics=event_loop_metrics,
539-
agent=self,
485+
model=self.model,
486+
system_prompt=self.system_prompt,
487+
messages=self.messages, # will be modified by event_loop_cycle
488+
tool_config=self.tool_config,
489+
callback_handler=callback_handler,
490+
tool_handler=self.tool_handler,
491+
tool_execution_handler=self.thread_pool_wrapper,
492+
event_loop_metrics=self.event_loop_metrics,
540493
event_loop_parent_span=self.trace_span,
541-
**kwargs,
494+
kwargs=kwargs,
542495
)
543496

544497
except ContextWindowOverflowException as e:
545498
# Try reducing the context size and retrying
546499
self.conversation_manager.reduce_context(self, e=e)
547-
yield from self._execute_event_loop_cycle(callback_handler_override, kwargs)
500+
yield from self._execute_event_loop_cycle(callback_handler, kwargs)
548501

549502
def _record_tool_execution(
550503
self,
@@ -569,7 +522,7 @@ def _record_tool_execution(
569522
messages: The message history to append to.
570523
"""
571524
# Create user message describing the tool call
572-
user_msg_content = [
525+
user_msg_content: List[ContentBlock] = [
573526
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
574527
]
575528

@@ -578,19 +531,19 @@ def _record_tool_execution(
578531
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
579532

580533
# Create the message sequence
581-
user_msg = {
534+
user_msg: Message = {
582535
"role": "user",
583536
"content": user_msg_content,
584537
}
585-
tool_use_msg = {
538+
tool_use_msg: Message = {
586539
"role": "assistant",
587540
"content": [{"toolUse": tool}],
588541
}
589-
tool_result_msg = {
542+
tool_result_msg: Message = {
590543
"role": "user",
591544
"content": [{"toolResult": tool_result}],
592545
}
593-
assistant_msg = {
546+
assistant_msg: Message = {
594547
"role": "assistant",
595548
"content": [{"text": f"agent.{tool['name']} was called"}],
596549
}

0 commit comments

Comments
 (0)