diff --git a/src/strands/experimental/steering/handlers/__init__.py b/src/strands/experimental/steering/handlers/__init__.py index ca529530f..542126ab5 100644 --- a/src/strands/experimental/steering/handlers/__init__.py +++ b/src/strands/experimental/steering/handlers/__init__.py @@ -1,3 +1,5 @@ """Steering handler implementations.""" -__all__ = [] +from typing import Sequence + +__all__: Sequence[str] = [] diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index c24d91a0d..22feecf32 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -40,10 +40,16 @@ class GeminiConfig(TypedDict, total=False): params: Additional model parameters (e.g., temperature). For a complete list of supported parameters, see https://ai.google.dev/api/generate-content#generationconfig. + gemini_tools: Gemini-specific tools that are not FunctionDeclarations + (e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch). + Use the standard tools interface for function calling tools. + For a complete list of supported tools, see + https://ai.google.dev/api/caching#Tool """ model_id: Required[str] params: dict[str, Any] + gemini_tools: list[genai.types.Tool] def __init__( self, @@ -61,6 +67,10 @@ def __init__( validate_config_keys(model_config, GeminiModel.GeminiConfig) self.config = GeminiModel.GeminiConfig(**model_config) + # Validate gemini_tools if provided + if "gemini_tools" in self.config: + self._validate_gemini_tools(self.config["gemini_tools"]) + logger.debug("config=<%s> | initializing", self.config) self.client_args = client_args or {} @@ -72,6 +82,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type: Args: **model_config: Configuration overrides. """ + # Validate gemini_tools if provided + if "gemini_tools" in model_config: + self._validate_gemini_tools(model_config["gemini_tools"]) + self.config.update(model_config) @override @@ -181,7 +195,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge Return: Gemini tool list. """ - return [ + tools = [ genai.types.Tool( function_declarations=[ genai.types.FunctionDeclaration( @@ -193,6 +207,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge ], ), ] + if self.config.get("gemini_tools"): + tools.extend(self.config["gemini_tools"]) + return tools def _format_request_config( self, @@ -451,3 +468,27 @@ async def structured_output( client = genai.Client(**self.client_args).aio response = await client.models.generate_content(**request) yield {"output": output_model.model_validate(response.parsed)} + + @staticmethod + def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None: + """Validate that gemini_tools does not contain FunctionDeclarations. + + Gemini-specific tools should only include tools that cannot be represented + as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse). + Standard function calling tools should use the tools interface instead. + + Args: + gemini_tools: List of Gemini tools to validate + + Raises: + ValueError: If any tool contains function_declarations + """ + for tool in gemini_tools: + # Check if the tool has function_declarations attribute and it's not empty + if hasattr(tool, "function_declarations") and tool.function_declarations: + raise ValueError( + "gemini_tools should not contain FunctionDeclarations. " + "Use the standard tools interface for function calling tools. " + "gemini_tools is reserved for Gemini-specific tools like " + "GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch." + ) diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index a8f5351cc..8e8742f94 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -624,6 +624,89 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath gemini_client.aio.models.generate_content.assert_called_with(**exp_request) +def test_gemini_tools_validation_rejects_function_declarations(model_id): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"): + GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations]) + + +def test_gemini_tools_validation_allows_non_function_tools(model_id): + tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch()) + + model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search]) + assert "gemini_tools" in model.config + + +def test_gemini_tools_validation_on_update_config(model): + tool_with_function_declarations = genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + name="test_function", + description="A test function", + ) + ] + ) + + with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"): + model.update_config(gemini_tools=[tool_with_function_declarations]) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id): + google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch()) + model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool]) + + await anext(model.stream(messages)) + + exp_request = { + "config": { + "tools": [ + {"function_declarations": []}, + {"google_search": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id): + code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution()) + model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool]) + + await anext(model.stream(messages, tool_specs=[tool_spec])) + + exp_request = { + "config": { + "tools": [ + { + "function_declarations": [ + { + "description": tool_spec["description"], + "name": tool_spec["name"], + "parameters_json_schema": tool_spec["inputSchema"]["json"], + } + ] + }, + {"code_execution": {}}, + ] + }, + "contents": [{"parts": [{"text": "test"}], "role": "user"}], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + @pytest.mark.asyncio async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog, alist): error_message = "Invalid API key" diff --git a/tests_integ/models/test_model_gemini.py b/tests_integ/models/test_model_gemini.py index f9da8490c..5643d159e 100644 --- a/tests_integ/models/test_model_gemini.py +++ b/tests_integ/models/test_model_gemini.py @@ -2,6 +2,7 @@ import pydantic import pytest +from google import genai import strands from strands import Agent @@ -21,6 +22,16 @@ def model(): ) +@pytest.fixture +def gemini_tool_model(): + return GeminiModel( + client_args={"api_key": os.getenv("GOOGLE_API_KEY")}, + model_id="gemini-2.5-flash", + params={"temperature": 0.15}, # Lower temperature for consistent test behavior + gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())], + ) + + @pytest.fixture def tools(): @strands.tool @@ -175,3 +186,19 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow tru_color = assistant_agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +def test_agent_with_gemini_code_execution_tool(gemini_tool_model): + system_prompt = "Generate and run code for all calculations" + agent = Agent(model=gemini_tool_model, system_prompt=system_prompt) + # sample prompt taken from https://ai.google.dev/gemini-api/docs/code-execution + result_turn1 = agent( + "What is the sum of the first 50 prime numbers? Generate and run code for the calculation, " + "and make sure you get all 50." + ) + + # NOTE: We don't verify tool history because built-in tools are currently represented in message history + assert "5117" in str(result_turn1) + + result_turn2 = agent("Summarize that into a single number") + assert "5117" in str(result_turn2)