Skip to content

Commit 9301c84

Browse files
Wrap GoogleModel google.genai.errors.APIError in ModelHTTPError so it works with FallbackModel (#3139)
Co-authored-by: Douwe Maan <[email protected]>
1 parent e72170e commit 9301c84

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._output import OutputObjectDefinition
1515
from .._run_context import RunContext
1616
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool
17-
from ..exceptions import UserError
17+
from ..exceptions import ModelHTTPError, UserError
1818
from ..messages import (
1919
BinaryContent,
2020
BuiltinToolCallPart,
@@ -51,7 +51,7 @@
5151
)
5252

5353
try:
54-
from google.genai import Client
54+
from google.genai import Client, errors
5555
from google.genai.types import (
5656
BlobDict,
5757
CodeExecutionResult,
@@ -394,7 +394,16 @@ async def _generate_content(
394394
) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]:
395395
contents, config = await self._build_content_and_config(messages, model_settings, model_request_parameters)
396396
func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content
397-
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
397+
try:
398+
return await func(model=self._model_name, contents=contents, config=config) # type: ignore
399+
except errors.APIError as e:
400+
if (status_code := e.code) >= 400:
401+
raise ModelHTTPError(
402+
status_code=status_code,
403+
model_name=self._model_name,
404+
body=cast(Any, e.details), # pyright: ignore[reportUnknownMemberType]
405+
) from e
406+
raise # pragma: lax no cover
398407

399408
async def _build_content_and_config(
400409
self,

tests/models/test_google.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from httpx import Timeout
1111
from inline_snapshot import Is, snapshot
1212
from pydantic import BaseModel
13+
from pytest_mock import MockerFixture
1314
from typing_extensions import TypedDict
1415

1516
from pydantic_ai import (
@@ -43,7 +44,7 @@
4344
)
4445
from pydantic_ai.agent import Agent
4546
from pydantic_ai.builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool
46-
from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError
47+
from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
4748
from pydantic_ai.messages import (
4849
BuiltinToolCallEvent, # pyright: ignore[reportDeprecated]
4950
BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
@@ -57,6 +58,7 @@
5758
from ..parts_from_messages import part_types_from_messages
5859

5960
with try_import() as imports_successful:
61+
from google.genai import errors
6062
from google.genai.types import (
6163
FinishReason as GoogleFinishReason,
6264
GenerateContentResponse,
@@ -3514,3 +3516,44 @@ async def test_cache_point_filtering():
35143516
assert len(content) == 2
35153517
assert content[0] == {'text': 'text before'}
35163518
assert content[1] == {'text': 'text after'}
3519+
3520+
3521+
@pytest.mark.parametrize(
3522+
'error_class,error_response,expected_status',
3523+
[
3524+
(
3525+
errors.ServerError,
3526+
{'error': {'code': 503, 'message': 'The service is currently unavailable.', 'status': 'UNAVAILABLE'}},
3527+
503,
3528+
),
3529+
(
3530+
errors.ClientError,
3531+
{'error': {'code': 400, 'message': 'Invalid request parameters', 'status': 'INVALID_ARGUMENT'}},
3532+
400,
3533+
),
3534+
(
3535+
errors.ClientError,
3536+
{'error': {'code': 429, 'message': 'Rate limit exceeded', 'status': 'RESOURCE_EXHAUSTED'}},
3537+
429,
3538+
),
3539+
],
3540+
)
3541+
async def test_google_api_errors_are_handled(
3542+
allow_model_requests: None,
3543+
google_provider: GoogleProvider,
3544+
mocker: MockerFixture,
3545+
error_class: type[errors.APIError],
3546+
error_response: dict[str, Any],
3547+
expected_status: int,
3548+
):
3549+
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
3550+
mocked_error = error_class(expected_status, error_response)
3551+
mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error)
3552+
3553+
agent = Agent(model=model)
3554+
3555+
with pytest.raises(ModelHTTPError) as exc_info:
3556+
await agent.run('This prompt will trigger the mocked error.')
3557+
3558+
assert exc_info.value.status_code == expected_status
3559+
assert error_response['error']['message'] in str(exc_info.value.body)

0 commit comments

Comments
 (0)