Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def _initialize_fresh_execution(
requires_async: bool,
*,
system_prompt: Optional[str] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
**kwargs,
) -> _ExecutionContext:
Expand All @@ -278,6 +279,8 @@ def _initialize_fresh_execution(
:param streaming_callback: Optional callback for streaming responses.
:param requires_async: Whether the agent run requires asynchronous execution.
:param system_prompt: System prompt for the agent. If provided, it overrides the default system prompt.
:param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
override the parameters passed during component initialization.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
When passing tool names, tools are selected from the Agent's originally configured tools.
:param kwargs: Additional data to pass to the State used by the Agent.
Expand All @@ -302,6 +305,8 @@ def _initialize_fresh_execution(
if streaming_callback is not None:
tool_invoker_inputs["streaming_callback"] = streaming_callback
generator_inputs["streaming_callback"] = streaming_callback
if generation_kwargs is not None:
generator_inputs["generation_kwargs"] = generation_kwargs

return _ExecutionContext(
state=state,
Expand Down Expand Up @@ -354,6 +359,7 @@ def _initialize_from_snapshot(
streaming_callback: Optional[StreamingCallbackT],
requires_async: bool,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[ToolsType, list[str]]] = None,
) -> _ExecutionContext:
"""
Expand All @@ -362,6 +368,8 @@ def _initialize_from_snapshot(
:param snapshot: An AgentSnapshot containing the state of a previously saved agent execution.
:param streaming_callback: Optional callback for streaming responses.
:param requires_async: Whether the agent run requires asynchronous execution.
:param generation_kwargs: Additional keyword arguments for chat generator. These parameters will
override the parameters passed during component initialization.
:param tools: Optional list of Tool objects, a Toolset, or list of tool names to use for this run.
When passing tool names, tools are selected from the Agent's originally configured tools.
"""
Expand All @@ -386,6 +394,8 @@ def _initialize_from_snapshot(
if streaming_callback is not None:
tool_invoker_inputs["streaming_callback"] = streaming_callback
generator_inputs["streaming_callback"] = streaming_callback
if generation_kwargs is not None:
generator_inputs["generation_kwargs"] = generation_kwargs

return _ExecutionContext(
state=state,
Expand Down Expand Up @@ -471,6 +481,7 @@ def run( # noqa: PLR0915
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None,
system_prompt: Optional[str] = None,
Expand All @@ -483,6 +494,8 @@ def run( # noqa: PLR0915
:param messages: List of Haystack ChatMessage objects to process.
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
for "tool_invoker".
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
Expand Down Expand Up @@ -513,7 +526,11 @@ def run( # noqa: PLR0915

if snapshot:
exe_context = self._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools
snapshot=snapshot,
streaming_callback=streaming_callback,
requires_async=False,
tools=tools,
generation_kwargs=generation_kwargs,
)
else:
exe_context = self._initialize_fresh_execution(
Expand All @@ -522,6 +539,7 @@ def run( # noqa: PLR0915
requires_async=False,
system_prompt=system_prompt,
tools=tools,
generation_kwargs=generation_kwargs,
**kwargs,
)

Expand Down Expand Up @@ -628,6 +646,7 @@ async def run_async(
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
*,
generation_kwargs: Optional[dict[str, Any]] = None,
break_point: Optional[AgentBreakpoint] = None,
snapshot: Optional[AgentSnapshot] = None,
system_prompt: Optional[str] = None,
Expand All @@ -644,6 +663,8 @@ async def run_async(
:param messages: List of Haystack ChatMessage objects to process.
:param streaming_callback: An asynchronous callback that will be invoked when a response is streamed from the
LLM. The same callback can be configured to emit tool results when a tool is called.
:param generation_kwargs: Additional keyword arguments for LLM. These parameters will
override the parameters passed during component initialization.
:param break_point: An AgentBreakpoint, can be a Breakpoint for the "chat_generator" or a ToolBreakpoint
for "tool_invoker".
:param snapshot: A dictionary containing a snapshot of a previously saved agent execution. The snapshot contains
Expand Down Expand Up @@ -673,7 +694,11 @@ async def run_async(

if snapshot:
exe_context = self._initialize_from_snapshot(
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools
snapshot=snapshot,
streaming_callback=streaming_callback,
requires_async=True,
tools=tools,
generation_kwargs=generation_kwargs,
)
else:
exe_context = self._initialize_fresh_execution(
Expand All @@ -682,6 +707,7 @@ async def run_async(
requires_async=True,
system_prompt=system_prompt,
tools=tools,
generation_kwargs=generation_kwargs,
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
    Adds a `generation_kwargs` run-time parameter to the `Agent` component, allowing more fine-grained control over chat generation on a per-run basis.
18 changes: 18 additions & 0 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,24 @@ async def test_run_async_falls_back_to_run_when_chat_generator_has_no_run_async(
assert isinstance(result["last_message"], ChatMessage)
assert result["messages"][-1] == result["last_message"]

@pytest.mark.asyncio
async def test_generation_kwargs(self):
chat_generator = MockChatGeneratorWithoutRunAsync()

agent = Agent(chat_generator=chat_generator)
agent.warm_up()

chat_generator.run = MagicMock(return_value={"replies": [ChatMessage.from_assistant("Hello")]})

await agent.run_async([ChatMessage.from_user("Hello")], generation_kwargs={"temperature": 0.0})

expected_messages = [
ChatMessage(_role=ChatRole.USER, _content=[TextContent(text="Hello")], _name=None, _meta={})
]
chat_generator.run.assert_called_once_with(
messages=expected_messages, generation_kwargs={"temperature": 0.0}, tools=[]
)

@pytest.mark.asyncio
async def test_run_async_uses_chat_generator_run_async_when_available(self, weather_tool):
chat_generator = MockChatGenerator()
Expand Down
Loading