|
10 | 10 | from httpx import Timeout |
11 | 11 | from inline_snapshot import Is, snapshot |
12 | 12 | from pydantic import BaseModel |
| 13 | +from pytest_mock import MockerFixture |
13 | 14 | from typing_extensions import TypedDict |
14 | 15 |
|
15 | 16 | from pydantic_ai import ( |
|
43 | 44 | ) |
44 | 45 | from pydantic_ai.agent import Agent |
45 | 46 | from pydantic_ai.builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool |
46 | | -from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError |
| 47 | +from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError |
47 | 48 | from pydantic_ai.messages import ( |
48 | 49 | BuiltinToolCallEvent, # pyright: ignore[reportDeprecated] |
49 | 50 | BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] |
|
57 | 58 | from ..parts_from_messages import part_types_from_messages |
58 | 59 |
|
59 | 60 | with try_import() as imports_successful: |
| 61 | + from google.genai import errors |
60 | 62 | from google.genai.types import ( |
61 | 63 | FinishReason as GoogleFinishReason, |
62 | 64 | GenerateContentResponse, |
@@ -3514,3 +3516,44 @@ async def test_cache_point_filtering(): |
3514 | 3516 | assert len(content) == 2 |
3515 | 3517 | assert content[0] == {'text': 'text before'} |
3516 | 3518 | assert content[1] == {'text': 'text after'} |
| 3519 | + |
| 3520 | + |
| 3521 | +@pytest.mark.parametrize( |
| 3522 | + 'error_class,error_response,expected_status', |
| 3523 | + [ |
| 3524 | + ( |
| 3525 | + errors.ServerError, |
| 3526 | + {'error': {'code': 503, 'message': 'The service is currently unavailable.', 'status': 'UNAVAILABLE'}}, |
| 3527 | + 503, |
| 3528 | + ), |
| 3529 | + ( |
| 3530 | + errors.ClientError, |
| 3531 | + {'error': {'code': 400, 'message': 'Invalid request parameters', 'status': 'INVALID_ARGUMENT'}}, |
| 3532 | + 400, |
| 3533 | + ), |
| 3534 | + ( |
| 3535 | + errors.ClientError, |
| 3536 | + {'error': {'code': 429, 'message': 'Rate limit exceeded', 'status': 'RESOURCE_EXHAUSTED'}}, |
| 3537 | + 429, |
| 3538 | + ), |
| 3539 | + ], |
| 3540 | +) |
| 3541 | +async def test_google_api_errors_are_handled( |
| 3542 | + allow_model_requests: None, |
| 3543 | + google_provider: GoogleProvider, |
| 3544 | + mocker: MockerFixture, |
| 3545 | + error_class: type[errors.APIError], |
| 3546 | + error_response: dict[str, Any], |
| 3547 | + expected_status: int, |
| 3548 | +): |
| 3549 | + model = GoogleModel('gemini-1.5-flash', provider=google_provider) |
| 3550 | + mocked_error = error_class(expected_status, error_response) |
| 3551 | + mocker.patch.object(model.client.aio.models, 'generate_content', side_effect=mocked_error) |
| 3552 | + |
| 3553 | + agent = Agent(model=model) |
| 3554 | + |
| 3555 | + with pytest.raises(ModelHTTPError) as exc_info: |
| 3556 | + await agent.run('This prompt will trigger the mocked error.') |
| 3557 | + |
| 3558 | + assert exc_info.value.status_code == expected_status |
| 3559 | + assert error_response['error']['message'] in str(exc_info.value.body) |
0 commit comments