Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,9 @@ def infer_model(model: Model | KnownModelName | str) -> Model:

return OpenAIModel(model_name, provider=provider)
elif provider in ('google-gla', 'google-vertex'):
from .gemini import GeminiModel
from .google import GoogleModel

return GeminiModel(model_name, provider=provider)
return GoogleModel(model_name, provider=provider)
elif provider == 'groq':
from .groq import GroqModel

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
"""
if client is None:
# NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
api_key = api_key or os.environ.get('GOOGLE_API_KEY')
api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')

if vertexai is None: # pragma: lax no cover
vertexai = bool(location or project or credentials)
Expand Down
24 changes: 12 additions & 12 deletions tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,18 +307,6 @@ async def test_google_model_thinking_config(allow_model_requests: None, google_p
assert result.output == snapshot('The capital of France is **Paris**.')


@pytest.mark.skipif(
    not os.getenv('CI', False), reason='Requires properly configured local google vertex config to pass'
)
async def test_google_model_vertex_labels(allow_model_requests: None):  # pragma: lax no cover
    """Vertex AI requests should accept and forward the ``google_labels`` model setting."""
    vertex_provider = GoogleProvider(location='global', project='pydantic-ai')
    label_settings = GoogleModelSettings(google_labels={'environment': 'test', 'team': 'analytics'})
    agent = Agent(
        model=GoogleModel('gemini-2.0-flash', provider=vertex_provider),
        system_prompt='You are a helpful chatbot.',
        model_settings=label_settings,
    )
    result = await agent.run('What is the capital of France?')
    assert result.output == snapshot('The capital of France is Paris.\n')


async def test_google_model_gla_labels_raises_value_error(allow_model_requests: None, google_provider: GoogleProvider):
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
settings = GoogleModelSettings(google_labels={'environment': 'test', 'team': 'analytics'})
Expand Down Expand Up @@ -360,6 +348,18 @@ async def test_google_model_vertex_provider(allow_model_requests: None):
assert result.output == snapshot('The capital of France is Paris.\n')


@pytest.mark.skipif(
    not os.getenv('CI', False), reason='Requires properly configured local google vertex config to pass'
)
async def test_google_model_vertex_labels(allow_model_requests: None):  # pragma: lax no cover
    """Vertex AI requests should accept and forward the ``google_labels`` model setting."""
    vertex_provider = GoogleProvider(location='global', project='pydantic-ai')
    label_settings = GoogleModelSettings(google_labels={'environment': 'test', 'team': 'analytics'})
    agent = Agent(
        model=GoogleModel('gemini-2.0-flash', provider=vertex_provider),
        system_prompt='You are a helpful chatbot.',
        model_settings=label_settings,
    )
    result = await agent.run('What is the capital of France?')
    assert result.output == snapshot('The capital of France is Paris.\n')


async def test_google_model_iter_stream(allow_model_requests: None, google_provider: GoogleProvider):
model = GoogleModel('gemini-2.0-flash', provider=google_provider)
agent = Agent(model=model, system_prompt='You are a helpful chatbot.')
Expand Down
22 changes: 4 additions & 18 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,15 @@

from ..conftest import TestEnv

# TODO(Marcelo): We need to add Vertex AI to the test cases.

TEST_CASES = [
('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'gpt-3.5-turbo', 'gpt-3.5-turbo', 'openai', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'o1', 'o1', 'openai', 'openai', 'OpenAIModel'),
('AZURE_OPENAI_API_KEY', 'azure:gpt-3.5-turbo', 'gpt-3.5-turbo', 'azure', 'azure', 'OpenAIModel'),
('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
('GEMINI_API_KEY', 'gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'gemini', 'GeminiModel'),
(
'GEMINI_API_KEY',
'google-vertex:gemini-1.5-flash',
'gemini-1.5-flash',
'google-vertex',
'vertexai',
'GeminiModel',
),
(
'GEMINI_API_KEY',
'vertexai:gemini-1.5-flash',
'gemini-1.5-flash',
'google-vertex',
'vertexai',
'GeminiModel',
),
('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'google', 'GoogleModel'),
('GEMINI_API_KEY', 'gemini-1.5-flash', 'gemini-1.5-flash', 'google-gla', 'google', 'GoogleModel'),
(
'ANTHROPIC_API_KEY',
'anthropic:claude-3-5-haiku-latest',
Expand Down
11 changes: 7 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,11 +1536,14 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:


def test_model_requests_blocked(env: TestEnv):
    """A deferred-check agent must still be stopped by the ALLOW_MODEL_REQUESTS guard."""
    try:
        env.set('GEMINI_API_KEY', 'foobar')
        blocked_agent = Agent('google-gla:gemini-1.5-flash', output_type=tuple[str, str], defer_model_check=True)

        # Running should fail fast on the guard rather than reach the network.
        with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
            blocked_agent.run_sync('Hello')
    except ImportError:  # pragma: lax no cover
        pytest.skip('google-genai not installed')


def test_override_model(env: TestEnv):
Expand Down