Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/experimental/steering/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Steering handler implementations."""

__all__ = []
from typing import Sequence

__all__: Sequence[str] = []
43 changes: 42 additions & 1 deletion src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,16 @@ class GeminiConfig(TypedDict, total=False):
params: Additional model parameters (e.g., temperature).
For a complete list of supported parameters, see
https://ai.google.dev/api/generate-content#generationconfig.
gemini_tools: Gemini-specific tools that are not FunctionDeclarations
(e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch).
Use the standard tools interface for function calling tools.
For a complete list of supported tools, see
https://ai.google.dev/api/caching#Tool
"""

model_id: Required[str]
params: dict[str, Any]
gemini_tools: list[genai.types.Tool]

def __init__(
self,
Expand All @@ -61,6 +67,10 @@ def __init__(
validate_config_keys(model_config, GeminiModel.GeminiConfig)
self.config = GeminiModel.GeminiConfig(**model_config)

# Validate gemini_tools if provided
if "gemini_tools" in self.config:
self._validate_gemini_tools(self.config["gemini_tools"])

logger.debug("config=<%s> | initializing", self.config)

self.client_args = client_args or {}
Expand All @@ -72,6 +82,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
# Validate gemini_tools if provided
if "gemini_tools" in model_config:
self._validate_gemini_tools(model_config["gemini_tools"])

self.config.update(model_config)

@override
Expand Down Expand Up @@ -181,7 +195,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
Return:
Gemini tool list.
"""
return [
tools = [
genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
Expand All @@ -193,6 +207,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
],
),
]
if self.config.get("gemini_tools"):
tools.extend(self.config["gemini_tools"])
return tools

def _format_request_config(
self,
Expand Down Expand Up @@ -451,3 +468,27 @@ async def structured_output(
client = genai.Client(**self.client_args).aio
response = await client.models.generate_content(**request)
yield {"output": output_model.model_validate(response.parsed)}

@staticmethod
def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None:
"""Validate that gemini_tools does not contain FunctionDeclarations.

Gemini-specific tools should only include tools that cannot be represented
as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse).
Standard function calling tools should use the tools interface instead.

Args:
gemini_tools: List of Gemini tools to validate

Raises:
ValueError: If any tool contains function_declarations
"""
for tool in gemini_tools:
# Check if the tool has function_declarations attribute and it's not empty
if hasattr(tool, "function_declarations") and tool.function_declarations:
raise ValueError(
"gemini_tools should not contain FunctionDeclarations. "
"Use the standard tools interface for function calling tools. "
"gemini_tools is reserved for Gemini-specific tools like "
"GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch."
)
83 changes: 83 additions & 0 deletions tests/strands/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,89 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath
gemini_client.aio.models.generate_content.assert_called_with(**exp_request)


def test_gemini_tools_validation_rejects_function_declarations(model_id):
tool_with_function_declarations = genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
name="test_function",
description="A test function",
)
]
)

with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations])


def test_gemini_tools_validation_allows_non_function_tools(model_id):
tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch())

model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search])
assert "gemini_tools" in model.config


def test_gemini_tools_validation_on_update_config(model):
tool_with_function_declarations = genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
name="test_function",
description="A test function",
)
]
)

with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
model.update_config(gemini_tools=[tool_with_function_declarations])


@pytest.mark.asyncio
async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id):
google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch())
model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool])

await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [
{"function_declarations": []},
{"google_search": {}},
]
},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)


@pytest.mark.asyncio
async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id):
code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution())
model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool])

await anext(model.stream(messages, tool_specs=[tool_spec]))

exp_request = {
"config": {
"tools": [
{
"function_declarations": [
{
"description": tool_spec["description"],
"name": tool_spec["name"],
"parameters_json_schema": tool_spec["inputSchema"]["json"],
}
]
},
{"code_execution": {}},
]
},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)


@pytest.mark.asyncio
async def test_stream_handles_non_json_error(gemini_client, model, messages, caplog, alist):
error_message = "Invalid API key"
Expand Down
27 changes: 27 additions & 0 deletions tests_integ/models/test_model_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pydantic
import pytest
from google import genai

import strands
from strands import Agent
Expand All @@ -21,6 +22,16 @@ def model():
)


@pytest.fixture
def gemini_tool_model():
return GeminiModel(
client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
model_id="gemini-2.5-flash",
params={"temperature": 0.15}, # Lower temperature for consistent test behavior
gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())],
)


@pytest.fixture
def tools():
@strands.tool
Expand Down Expand Up @@ -175,3 +186,19 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow
tru_color = assistant_agent.structured_output(type(yellow_color), content)
exp_color = yellow_color
assert tru_color == exp_color


def test_agent_with_gemini_code_execution_tool(gemini_tool_model):
system_prompt = "Generate and run code for all calculations"
agent = Agent(model=gemini_tool_model, system_prompt=system_prompt)
# sample prompt taken from https://ai.google.dev/gemini-api/docs/code-execution
result_turn1 = agent(
"What is the sum of the first 50 prime numbers? Generate and run code for the calculation, "
"and make sure you get all 50."
)

# NOTE: We don't verify tool history because built-in tools are currently represented in message history
assert "5117" in str(result_turn1)

result_turn2 = agent("Summarize that into a single number")
assert "5117" in str(result_turn2)
Loading