diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index 530d008aa209..789d21b4ce79 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -80,7 +80,5 @@ def handle_non_stream_rest_response(response: HttpResponse) -> None: """ try: response.read() + finally: response.close() - except Exception as exc: - response.close() - raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index bc23c202eaea..7c0d201fd61d 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -67,7 +67,5 @@ async def handle_no_stream_rest_response(response: "RestAsyncHttpResponse") -> N """ try: await response.read() + finally: await response.close() - except Exception as exc: - await response.close() - raise exc diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 64ce963a1229..9c653b34d55b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -474,7 +474,7 @@ async def __anext__(self): except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err except asyncio.TimeoutError as err: - raise ServiceResponseError(err, error=err) from err + raise ServiceResponseTimeoutError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: raise ServiceRequestError(err, error=err) from err except Exception as err: @@ -571,7 +571,7 @@ async def load_body(self) -> None: except aiohttp.client_exceptions.ClientResponseError as err: raise ServiceResponseError(err, error=err) from err except asyncio.TimeoutError as err: - raise ServiceResponseError(err, error=err) from err + raise ServiceResponseTimeoutError(err, error=err) from err except aiohttp.client_exceptions.ClientError as err: raise ServiceRequestError(err, error=err) from err diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index 00637cfd59db..9f102f4b0b20 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -46,7 +46,9 @@ from azure.core.configuration import ConnectionConfiguration from azure.core.exceptions import ( ServiceRequestError, + ServiceRequestTimeoutError, ServiceResponseError, + ServiceResponseTimeoutError, IncompleteReadError, HttpResponseError, DecodeError, @@ -85,7 +87,7 @@ def _read_raw_stream(response, chunk_size=1): except CoreDecodeError as e: raise DecodeError(e, error=e) from e except ReadTimeoutError as e: - raise ServiceRequestError(e, error=e) from e + raise ServiceResponseTimeoutError(e, error=e) from e else: # Standard file-like object. while True: @@ -202,6 +204,14 @@ def __next__(self): _LOGGER.warning("Unable to stream download.") internal_response.close() raise HttpResponseError(err, error=err) from err + except requests.ConnectionError as err: + internal_response.close() + if err.args and isinstance(err.args[0], ReadTimeoutError): + raise ServiceResponseTimeoutError(err, error=err) from err + raise ServiceResponseError(err, error=err) from err + except requests.RequestException as err: + internal_response.close() + raise ServiceResponseError(err, error=err) from err except Exception as err: _LOGGER.warning("Unable to stream download.") internal_response.close() @@ -384,13 +394,14 @@ def send( # pylint: disable=too-many-statements "Please report this issue to https://github.com/Azure/azure-sdk-for-python/issues." ) from err raise - except ( - NewConnectionError, - ConnectTimeoutError, - ) as err: + except NewConnectionError as err: error = ServiceRequestError(err, error=err) + except ConnectTimeoutError as err: + error = ServiceRequestTimeoutError(err, error=err) + except requests.exceptions.ConnectTimeout as err: + error = ServiceRequestTimeoutError(err, error=err) except requests.exceptions.ReadTimeout as err: - error = ServiceResponseError(err, error=err) + error = ServiceResponseTimeoutError(err, error=err) except requests.exceptions.ConnectionError as err: if err.args and isinstance(err.args[0], ProtocolError): error = ServiceResponseError(err, error=err) @@ -405,7 +416,7 @@ def send( # pylint: disable=too-many-statements _LOGGER.warning("Unable to stream download.") error = HttpResponseError(err, error=err) except requests.RequestException as err: - error = ServiceRequestError(err, error=err) + error = ServiceResponseError(err, error=err) if error: raise error diff --git a/sdk/core/azure-core/azure/core/rest/_aiohttp.py b/sdk/core/azure-core/azure/core/rest/_aiohttp.py index f173a8957456..ebcae4844872 100644 --- a/sdk/core/azure-core/azure/core/rest/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/rest/_aiohttp.py @@ -27,14 +27,23 @@ import asyncio # pylint: disable=do-not-import-asyncio from itertools import groupby from typing import Iterator, cast + +import aiohttp from multidict import CIMultiDict + from ._http_response_impl_async import ( AsyncHttpResponseImpl, AsyncHttpResponseBackcompatMixin, ) from ..pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator from ..utils._pipeline_transport_rest_shared import _pad_attr_name, _aiohttp_body_helper -from ..exceptions import ResponseNotReadError +from ..exceptions import ( + ResponseNotReadError, + IncompleteReadError, + ServiceResponseError, + ServiceResponseTimeoutError, + ServiceRequestError, +) class _ItemsView(collections.abc.ItemsView): @@ -212,7 +221,18 @@ async def read(self) -> bytes: """ if not self._content: self._stream_download_check() - self._content = await self._internal_response.read() + try: + self._content = await self._internal_response.read() + except aiohttp.client_exceptions.ClientPayloadError as err: + # This is the case that server closes connection before we finish the reading. aiohttp library + # raises ClientPayloadError. + raise IncompleteReadError(err, error=err) from err + except aiohttp.client_exceptions.ClientResponseError as err: + raise ServiceResponseError(err, error=err) from err + except asyncio.TimeoutError as err: + raise ServiceResponseTimeoutError(err, error=err) from err + except aiohttp.client_exceptions.ClientError as err: + raise ServiceRequestError(err, error=err) from err await self._set_read_checks() return _aiohttp_body_helper(self) diff --git a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py index f0e11b1d213c..63e20a2b0562 100644 --- a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py @@ -3,6 +3,15 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- + +import pytest +import sys +import asyncio +from packaging.version import Version +from unittest import mock + +import aiohttp + from azure.core.pipeline.transport import ( AsyncHttpResponse as PipelineTransportAsyncHttpResponse, AsyncHttpTransport, @@ -21,13 +30,8 @@ ServiceRequestTimeoutError, ServiceResponseTimeoutError, ) + from utils import HTTP_REQUESTS, request_and_responses_product -import pytest -import sys -import asyncio -from unittest.mock import Mock -from packaging.version import Version -import aiohttp # transport = mock.MagicMock(spec=AsyncHttpTransport) @@ -1049,47 +1053,66 @@ async def test_close_too_soon_works_fine(caplog, port, http_request): assert result # No exception is good enough here -@pytest.mark.skipif( - Version(aiohttp.__version__) >= Version("3.10"), - reason="aiohttp 3.10 introduced separate connection timeout", -) @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @pytest.mark.asyncio -async def test_aiohttp_timeout_response(http_request): +async def test_aiohttp_timeout_response(port, http_request): async with AioHttpTransport() as transport: - transport.session._connector.connect = Mock(side_effect=asyncio.TimeoutError("Too slow!")) - request = http_request("GET", f"http://localhost:12345/basic/string") + request = http_request("GET", f"http://localhost:{port}/basic/string") - with pytest.raises(ServiceResponseTimeoutError) as err: - await transport.send(request) + with mock.patch.object( + aiohttp.ClientResponse, "start", side_effect=asyncio.TimeoutError("Too slow!") + ) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(request) - with pytest.raises(ServiceResponseError) as err: - await transport.send(request) + with pytest.raises(ServiceResponseError) as err: + await transport.send(request) - stream_request = http_request("GET", f"http://localhost:12345/streams/basic") - with pytest.raises(ServiceResponseTimeoutError) as err: - await transport.send(stream_request, stream=True) + stream_resp = http_request("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(stream_resp, stream=True) + + stream_resp = await transport.send(stream_resp, stream=True) + with mock.patch.object( + aiohttp.streams.StreamReader, "read", side_effect=asyncio.TimeoutError("Too slow!") + ) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + try: + # current HttpResponse + await stream_resp.read() + except AttributeError: + # legacy HttpResponse + b"".join([b async for b in stream_resp.stream_download(None)]) -@pytest.mark.skipif( - Version(aiohttp.__version__) < Version("3.10"), - reason="aiohttp 3.10 introduced separate connection timeout", -) @pytest.mark.parametrize("http_request", HTTP_REQUESTS) @pytest.mark.asyncio async def test_aiohttp_timeout_request(http_request): async with AioHttpTransport() as transport: - transport.session._connector.connect = Mock(side_effect=asyncio.TimeoutError("Too slow!")) + transport.session._connector.connect = mock.Mock(side_effect=asyncio.TimeoutError("Too slow!")) request = http_request("GET", f"http://localhost:12345/basic/string") - with pytest.raises(ServiceRequestTimeoutError) as err: - await transport.send(request) + # aiohttp 3.10 introduced separate connection timeout + if Version(aiohttp.__version__) >= Version("3.10"): + with pytest.raises(ServiceRequestTimeoutError) as err: + await transport.send(request) + + with pytest.raises(ServiceRequestError) as err: + await transport.send(request) + + stream_request = http_request("GET", f"http://localhost:12345/streams/basic") + with pytest.raises(ServiceRequestTimeoutError) as err: + await transport.send(stream_request, stream=True) + + else: + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(request) - with pytest.raises(ServiceRequestError) as err: - await transport.send(request) + with pytest.raises(ServiceResponseError) as err: + await transport.send(request) - stream_request = http_request("GET", f"http://localhost:12345/streams/basic") - with pytest.raises(ServiceRequestTimeoutError) as err: - await transport.send(stream_request, stream=True) + stream_request = http_request("GET", f"http://localhost:12345/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + await transport.send(stream_request, stream=True) diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py index 2e75b93d2a96..24d3790f0590 100644 --- a/sdk/core/azure-core/tests/test_basic_transport.py +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -5,28 +5,35 @@ # ------------------------------------------------------------------------- from http.client import HTTPConnection from collections import OrderedDict -import sys +import logging +import pytest +from unittest import mock +from socket import timeout as SocketTimeout -try: - from unittest import mock -except ImportError: - import mock +from urllib3.util import connection as urllib_connection +from urllib3.response import HTTPResponse as UrllibResponse +from urllib3.connection import HTTPConnection as UrllibConnection +from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponseImpl +from azure.core.pipeline._tools import is_rest from azure.core.pipeline.transport import HttpResponse as PipelineTransportHttpResponse, RequestsTransport from azure.core.pipeline.transport._base import HttpTransport, _deserialize_response, _urljoin from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import Pipeline -from azure.core.exceptions import HttpResponseError -import logging -import pytest +from azure.core.exceptions import ( + HttpResponseError, + ServiceRequestError, + ServiceResponseError, + ServiceRequestTimeoutError, + ServiceResponseTimeoutError, +) + from utils import ( HTTP_REQUESTS, request_and_responses_product, HTTP_CLIENT_TRANSPORT_RESPONSES, create_transport_response, ) -from azure.core.rest._http_response_impl import HttpResponseImpl as RestHttpResponseImpl -from azure.core.pipeline._tools import is_rest class PipelineTransportMockResponse(PipelineTransportHttpResponse): @@ -1322,3 +1329,49 @@ def test_close_too_soon_works_fine(caplog, port, http_request): result = transport.send(request) assert result # No exception is good enough here + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_timeout_response(caplog, port, http_request): + transport = RequestsTransport() + + request = http_request("GET", f"http://localhost:{port}/basic/string") + + with mock.patch.object(UrllibConnection, "getresponse", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + transport.send(request, read_timeout=0.0001) + + with pytest.raises(ServiceResponseError) as err: + transport.send(request, read_timeout=0.0001) + + stream_request = http_request("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceResponseTimeoutError) as err: + transport.send(stream_request, stream=True, read_timeout=0.0001) + + stream_resp = transport.send(stream_request, stream=True) + with mock.patch.object(UrllibResponse, "_handle_chunk", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceResponseTimeoutError) as err: + try: + # current HttpResponse + stream_resp.read() + except AttributeError: + # legacy HttpResponse + b"".join(stream_resp.stream_download(None)) + + +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_timeout_request(caplog, port, http_request): + transport = RequestsTransport() + + request = http_request("GET", f"http://localhost:{port}/basic/string") + + with mock.patch.object(urllib_connection, "create_connection", side_effect=SocketTimeout) as mock_method: + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(request, connection_timeout=0.0001) + + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(request, connection_timeout=0.0001) + + stream_request = http_request("GET", f"http://localhost:{port}/streams/basic") + with pytest.raises(ServiceRequestTimeoutError) as err: + transport.send(stream_request, stream=True, connection_timeout=0.0001) diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index eeacb47f077c..e9ea512d0942 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -9,6 +9,7 @@ ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator +from azure.core.exceptions import ServiceResponseError try: from unittest import mock @@ -73,7 +74,7 @@ def close(self): http_response.internal_response = MockInternalResponse() stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch("time.sleep", return_value=None): - with pytest.raises(requests.exceptions.ConnectionError): + with pytest.raises(ServiceResponseError): stream.__next__() @@ -133,5 +134,5 @@ def mock_run(self, *args, **kwargs): pipeline = Pipeline(transport) pipeline.run = mock_run downloader = response.stream_download(pipeline, decompress=False) - with pytest.raises(requests.exceptions.ConnectionError): + with pytest.raises(ServiceResponseError): full_response = b"".join(downloader)