From 4bbfab503ed9609f25c6a97779acfd4c70de2f41 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Tue, 22 Aug 2023 16:31:25 -0700 Subject: [PATCH 01/16] Pyright fixes --- sdk/core/azure-core/azure/core/exceptions.py | 8 ++++---- sdk/core/azure-core/azure/core/pipeline/_tools.py | 10 ++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index b944cfc044b1..c768bf26bbd6 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -113,18 +113,18 @@ class _HttpResponseCommonAPI(Protocol): @property def reason(self) -> Optional[str]: - pass + ... @property def status_code(self) -> Optional[int]: - pass + ... def text(self) -> str: - pass + ... @property def request(self) -> object: # object as type, since all we need is str() on it - pass + ... class ErrorMap(Generic[KeyType, ValueType]): diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index 8b375eaa7c64..de90f06f6784 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -23,11 +23,13 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from __future__ import annotations +from typing import TYPE_CHECKING, Union +from typing_extensions import TypeGuard if TYPE_CHECKING: from typing import Any - from azure.core.rest import HttpResponse as RestHttpResponse + from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse def await_result(func, *args, **kwargs): @@ -47,7 +49,7 @@ def await_result(func, *args, **kwargs): return result -def is_rest(obj) -> bool: +def is_rest(obj: object) -> TypeGuard[Union[HttpRequest, HttpResponse, AsyncHttpResponse]]: """Return whether a request or a response is a rest request / response. Checking whether the response has the object content can sometimes result @@ -63,7 +65,7 @@ def is_rest(obj) -> bool: return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content") -def handle_non_stream_rest_response(response: "RestHttpResponse") -> None: +def handle_non_stream_rest_response(response: HttpResponse) -> None: """Handle reading and closing of non stream rest responses. For our new rest responses, we have to call .read() and .close() for our non-stream responses. This way, we load in the body for users to access. From fc300beeb0593ec7ccd0a7332aea236d28e1ad00 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Mon, 28 Aug 2023 09:59:42 -0700 Subject: [PATCH 02/16] More pyright fixes --- sdk/core/azure-core/azure/core/_pipeline_client.py | 3 ++- sdk/core/azure-core/azure/core/_pipeline_client_async.py | 3 ++- sdk/core/azure-core/azure/core/credentials_async.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 7ede5404e283..e1f1d3c0faa0 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -172,7 +172,8 @@ def _build_pipeline( policies = policies_1 if transport is None: - from .pipeline.transport import RequestsTransport # pylint: disable=no-name-in-module + # Use private import for better typing, mypy and pyright don't like PEP562 + from .pipeline.transport._requests_basic import RequestsTransport transport = RequestsTransport(**kwargs) diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 3bda4a7c5b7d..e614f7900f95 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -254,7 +254,8 @@ def _build_pipeline( policies = policies_1 if not transport: - from .pipeline.transport import AioHttpTransport # pylint: disable=no-name-in-module + # Use private import for better typing, mypy and pyright don't like PEP562 + from .pipeline.transport._aiohttp import AioHttpTransport transport = AioHttpTransport(**kwargs) diff --git a/sdk/core/azure-core/azure/core/credentials_async.py b/sdk/core/azure-core/azure/core/credentials_async.py index 39ec2bc09569..79217c96c9e7 100644 --- a/sdk/core/azure-core/azure/core/credentials_async.py +++ b/sdk/core/azure-core/azure/core/credentials_async.py @@ -29,6 +29,7 @@ async def get_token( :rtype: AccessToken :return: An AccessToken instance containing the token string and its expiration time in Unix time. """ + ... async def close(self) -> None: pass From 1c53a760a537e1742f0a4e3717f743b0e9af0a03 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Mon, 28 Aug 2023 13:54:28 -0700 Subject: [PATCH 03/16] Fix unbound variable in CloudEvent (thanks pyright) --- sdk/core/azure-core/azure/core/messaging.py | 36 ++++++++++----------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/sdk/core/azure-core/azure/core/messaging.py b/sdk/core/azure-core/azure/core/messaging.py index 426f45b733a3..6306ceb34182 100644 --- a/sdk/core/azure-core/azure/core/messaging.py +++ b/sdk/core/azure-core/azure/core/messaging.py @@ -192,26 +192,26 @@ def from_dict(cls, event: Dict[str, Any]) -> CloudEvent[DataType]: except KeyError as err: # https://github.com/cloudevents/spec Cloud event spec requires source, type, # specversion. We autopopulate everything other than source, type. - if not all(_ in event for _ in ("source", "type")): - if all( - _ in event - for _ in ( - "subject", - "eventType", - "data", - "dataVersion", - "id", - "eventTime", - ) - ): - raise ValueError( - "The event you are trying to parse follows the Eventgrid Schema. You can parse" - + " EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library." - ) from err + # So we will assume the KeyError is coming from source/type access. + if all( + _ in event + for _ in ( + "subject", + "eventType", + "data", + "dataVersion", + "id", + "eventTime", + ) + ): raise ValueError( - "The event does not conform to the cloud event spec https://github.com/cloudevents/spec." - + " The `source` and `type` params are required." + "The event you are trying to parse follows the Eventgrid Schema. You can parse" + + " EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library." ) from err + raise ValueError( + "The event does not conform to the cloud event spec https://github.com/cloudevents/spec." + + " The `source` and `type` params are required." + ) from err return event_obj @classmethod From f529968584576c7d86806b14def2a05974affe12 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Mon, 28 Aug 2023 14:32:50 -0700 Subject: [PATCH 04/16] Type correctly await_result --- .../azure/core/pipeline/_tools_async.py | 21 ++++++++-- .../policies/_authentication_async.py | 40 +++++++++---------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index 841de5b84a48..0ee766f5ff69 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,13 +23,27 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload +from typing_extensions import ParamSpec if TYPE_CHECKING: from ..rest import AsyncHttpResponse as RestAsyncHttpResponse +P = ParamSpec("P") +T = TypeVar("T") -async def await_result(func, *args, **kwargs): + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: + ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: """If func returns an awaitable, await it. :param func: The function to run. @@ -41,8 +55,7 @@ async def await_result(func, *args, **kwargs): """ result = func(*args, **kwargs) if hasattr(result, "__await__"): - # type ignore on await: https://github.com/python/mypy/issues/7587 - return await result # type: ignore + return await result return result diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py index d75e3cbf96d4..156543720485 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_authentication_async.py @@ -92,28 +92,28 @@ async def send( await await_result(self.on_request, request) try: response = await self.next.send(request) - await await_result(self.on_response, request, response) except Exception: # pylint:disable=broad-except - handled = await await_result(self.on_exception, request) - if not handled: - raise + await await_result(self.on_exception, request) + raise else: - if response.http_response.status_code == 401: - self._token = None # any cached token is invalid - if "WWW-Authenticate" in response.http_response.headers: - request_authorized = await self.on_challenge(request, response) - if request_authorized: - # if we receive a challenge response, we retrieve a new token - # which matches the new target. In this case, we don't want to remove - # token from the request so clear the 'insecure_domain_change' tag - request.context.options.pop("insecure_domain_change", False) - try: - response = await self.next.send(request) - await await_result(self.on_response, request, response) - except Exception: # pylint:disable=broad-except - handled = await await_result(self.on_exception, request) - if not handled: - raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + else: + await await_result(self.on_response, request, response) return response From b9778bef020245dd1df29755e3ba8cd0aa05e808 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 30 Aug 2023 14:53:15 -0700 Subject: [PATCH 05/16] Exception handler is NoReturn --- sdk/core/azure-core/azure/core/exceptions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/exceptions.py b/sdk/core/azure-core/azure/core/exceptions.py index c768bf26bbd6..b87ee2a85869 100644 --- a/sdk/core/azure-core/azure/core/exceptions.py +++ b/sdk/core/azure-core/azure/core/exceptions.py @@ -40,6 +40,7 @@ TypeVar, Generic, Dict, + NoReturn, TYPE_CHECKING, ) from typing_extensions import Protocol, runtime_checkable @@ -79,7 +80,7 @@ ] -def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> None: +def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> NoReturn: """Raise exception with a specified traceback. This MUST be called inside a "except" clause. From f4e49110c78d35d883446213101f46f958c9257b Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 30 Aug 2023 14:53:41 -0700 Subject: [PATCH 06/16] Variable name clarity --- sdk/core/azure-core/azure/core/messaging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/messaging.py b/sdk/core/azure-core/azure/core/messaging.py index 6306ceb34182..1efd544c2b5f 100644 --- a/sdk/core/azure-core/azure/core/messaging.py +++ b/sdk/core/azure-core/azure/core/messaging.py @@ -194,8 +194,8 @@ def from_dict(cls, event: Dict[str, Any]) -> CloudEvent[DataType]: # specversion. We autopopulate everything other than source, type. # So we will assume the KeyError is coming from source/type access. if all( - _ in event - for _ in ( + key in event + for key in ( "subject", "eventType", "data", From 43699798e2cf0a535a5141e3175d7fd8ae65190d Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 30 Aug 2023 16:00:42 -0700 Subject: [PATCH 07/16] tools_async clean-up --- sdk/core/azure-core/azure/core/pipeline/_tools_async.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index 0ee766f5ff69..25c0a9df75c8 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,7 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload +from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload, cast from typing_extensions import ParamSpec if TYPE_CHECKING: @@ -53,9 +53,13 @@ async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, :rtype: any :return: The result of the function """ - result = func(*args, **kwargs) + result: Union[T, Awaitable[T]] = func(*args, **kwargs) + # pyright has issue with narrowing types here + # https://github.com/microsoft/pyright/issues/5860 if hasattr(result, "__await__"): + result = cast(Awaitable[T], result) return await result + result = cast(T, result) return result From 0437c4f8def71705ab9af12e78489db732818b95 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 17:34:00 +0000 Subject: [PATCH 08/16] Settings as clean as possible --- sdk/core/azure-core/azure/core/settings.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sdk/core/azure-core/azure/core/settings.py b/sdk/core/azure-core/azure/core/settings.py index d41ffbc10964..ea889f42b04f 100644 --- a/sdk/core/azure-core/azure/core/settings.py +++ b/sdk/core/azure-core/azure/core/settings.py @@ -65,9 +65,9 @@ def convert_bool(value: Union[str, bool]) -> bool: :raises ValueError: If conversion to bool fails """ - if value in (True, False): - return cast(bool, value) - val = cast(str, value).lower() + if isinstance(value, bool): + return value + val = value.lower() if val in ["yes", "1", "on", "true", "True"]: return True if val in ["no", "0", "off", "false", "False"]: @@ -103,9 +103,11 @@ def convert_logging(value: Union[str, int]) -> int: :raises ValueError: If conversion to log level fails """ - if value in set(_levels.values()): - return cast(int, value) - val = cast(str, value).upper() + if isinstance(value, int): + # If it's an int, return it. We don't need to check if it's in _levels, as custom int levels are allowed. + # https://docs.python.org/3/library/logging.html#levels + return value + val = value.upper() level = _levels.get(val) if not level: raise ValueError("Cannot convert {} to log level, valid values are: {}".format(value, ", ".join(_levels))) @@ -183,7 +185,6 @@ def convert_tracing_impl(value: Optional[Union[str, Type[AbstractSpan]]]) -> Opt ) if not isinstance(value, str): - value = cast(Type[AbstractSpan], value) return value value = value.lower() @@ -271,7 +272,7 @@ def __call__(self, value: Optional[ValidInputType] = None) -> ValueType: return self._convert(value) # 3. previously user-set value - if self._user_value is not _unset: + if not isinstance(self._user_value, _Unset): return self._convert(self._user_value) # 2. environment variable @@ -283,7 +284,7 @@ def __call__(self, value: Optional[ValidInputType] = None) -> ValueType: return self._convert(self._system_hook()) # 0. implicit default - if self._default is not _unset: + if not isinstance(self._default, _Unset): return self._convert(self._default) raise RuntimeError("No configured value found for setting %r" % self._name) From d8b2f4d39c0d564ea1bacef5d68ef2dad55fd802 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 18:16:50 +0000 Subject: [PATCH 09/16] Fix a few pyright warnings --- .../azure/core/pipeline/policies/_distributed_tracing.py | 2 +- sdk/core/azure-core/azure/core/rest/_http_response_impl.py | 4 ++-- .../azure-core/azure/core/utils/_connection_string_parser.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py index 7c2a29656459..6a7619eb1cbf 100644 --- a/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py +++ b/sdk/core/azure-core/azure/core/pipeline/policies/_distributed_tracing.py @@ -26,7 +26,7 @@ """Traces network calls using the implementation library from the settings.""" import logging import sys -import urllib +import urllib.parse from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type from types import TracebackType diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py index 72f59b721857..53be80045848 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py @@ -344,7 +344,7 @@ def raise_for_status(self) -> None: If response is good, does nothing. """ - if cast(int, self.status_code) >= 400: + if self.status_code >= 400: raise HttpResponseError(response=self) @property @@ -415,7 +415,7 @@ def iter_bytes(self, **kwargs) -> Iterator[bytes]: :rtype: Iterator[str] """ if self._content is not None: - chunk_size = cast(int, self._block_size) + chunk_size = self._block_size for i in range(0, len(self.content), chunk_size): yield self.content[i : i + chunk_size] else: diff --git a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py index c088711cbcb4..61494b487181 100644 --- a/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py +++ b/sdk/core/azure-core/azure/core/utils/_connection_string_parser.py @@ -26,7 +26,7 @@ def parse_connection_string(conn_str: str, case_sensitive_keys: bool = False) -> cs_args = [s.split("=", 1) for s in conn_str.strip().rstrip(";").split(";")] if any(len(tup) != 2 or not all(tup) for tup in cs_args): raise ValueError("Connection string is either blank or malformed.") - args_dict = dict(cs_args) # type: ignore + args_dict = dict(cs_args) if len(cs_args) != len(args_dict): raise ValueError("Connection string is either blank or malformed.") From 7544bab10dad77f885955bc0d434577336f6c5ad Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 18:17:48 +0000 Subject: [PATCH 10/16] Fix streaming typing --- .../azure/core/pipeline/transport/_aiohttp.py | 46 +++++++++++++++---- .../azure/core/pipeline/transport/_base.py | 4 +- .../core/pipeline/transport/_base_async.py | 6 +-- 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index ad52a867f42b..1b56533e7d0e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -23,6 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from __future__ import annotations import sys from typing import Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload, cast, Union, Type from types import TracebackType @@ -56,6 +57,7 @@ HttpRequest as RestHttpRequest, AsyncHttpResponse as RestAsyncHttpResponse, ) + from ...rest._aiohttp import RestAioHttpTransportResponse # Matching requests, because why not? CONTENT_CHUNK_SIZE = 10 * 1024 @@ -188,7 +190,7 @@ async def send(self, request: HttpRequest, **config: Any) -> AsyncHttpResponse: """ @overload - async def send(self, request: "RestHttpRequest", **config: Any) -> "RestAsyncHttpResponse": + async def send(self, request: RestHttpRequest, **config: Any) -> RestAsyncHttpResponse: """Send the `azure.core.rest` request using this HTTP sender. Will pre-load the body into memory to be available with a sync method. @@ -206,8 +208,8 @@ async def send(self, request: "RestHttpRequest", **config: Any) -> "RestAsyncHtt """ async def send( - self, request: Union[HttpRequest, "RestHttpRequest"], **config - ) -> Union[AsyncHttpResponse, "RestAsyncHttpResponse"]: + self, request: Union[HttpRequest, RestHttpRequest], **config + ) -> Union[AsyncHttpResponse, RestAsyncHttpResponse]: """Send the request using this HTTP sender. Will pre-load the body into memory to be available with a sync method. @@ -240,7 +242,7 @@ async def send( config["proxy"] = proxies[protocol] break - response: Optional[Union[AsyncHttpResponse, "RestAsyncHttpResponse"]] = None + response: Optional[Union[AsyncHttpResponse, RestAsyncHttpResponse]] = None config["ssl"] = self._build_ssl_config( cert=config.pop("connection_cert", self.connection_config.cert), verify=config.pop("connection_verify", self.connection_config.verify), @@ -262,7 +264,7 @@ async def send( data=self._get_request_data(request), timeout=socket_timeout, allow_redirects=False, - **config + **config, ) if _is_rest(request): from azure.core.rest._aiohttp import RestAioHttpTransportResponse @@ -307,7 +309,33 @@ class AioHttpStreamDownloadGenerator(AsyncIterator): on the *content-encoding* header. """ - def __init__(self, pipeline: AsyncPipeline, response: AsyncHttpResponse, *, decompress: bool = True) -> None: + @overload + def __init__( + self, + pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], + response: AioHttpTransportResponse, + *, + decompress: bool = True, + ) -> None: + ... + + @overload + def __init__( + self, + pipeline: AsyncPipeline[RestHttpRequest, RestAsyncHttpResponse], + response: RestAioHttpTransportResponse, + *, + decompress: bool = True, + ) -> None: + ... + + def __init__( + self, + pipeline: AsyncPipeline, + response: Union[AioHttpTransportResponse, RestAioHttpTransportResponse], + *, + decompress: bool = True, + ) -> None: self.pipeline = pipeline self.request = response.request self.response = response @@ -380,7 +408,7 @@ def __init__( aiohttp_response: aiohttp.ClientResponse, block_size: Optional[int] = None, *, - decompress: bool = True + decompress: bool = True, ) -> None: super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size) # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse @@ -462,7 +490,9 @@ async def load_body(self) -> None: except aiohttp.client_exceptions.ClientError as err: raise ServiceRequestError(err, error=err) from err - def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: + def stream_download( + self, pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], **kwargs + ) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. :param pipeline: The pipeline object diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 08146b9b4a29..247433104151 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -52,6 +52,7 @@ from http.client import HTTPResponse as _HTTPResponse +from azure.core.pipeline import Pipeline from azure.core.exceptions import HttpResponseError from ...utils._utils import case_insensitive_dict from ...utils._pipeline_transport_rest_shared import ( @@ -68,7 +69,6 @@ HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") -PipelineType = TypeVar("PipelineType") DataType = Union[bytes, str, Dict[str, Union[str, int]]] _LOGGER = logging.getLogger(__name__) @@ -488,7 +488,7 @@ def __repr__(self) -> str: class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method - def stream_download(self, pipeline: PipelineType, **kwargs: Any) -> Iterator[bytes]: + def stream_download(self, pipeline: Pipeline[HttpRequest, "HttpResponse"], **kwargs: Any) -> Iterator[bytes]: """Generator for streaming request body data. Should be implemented by sub-classes if streaming download diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index 0ed17916d4de..55d6060c645b 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -32,7 +32,6 @@ TypeVar, Generic, Any, - TYPE_CHECKING, AsyncContextManager, Optional, Type, @@ -40,10 +39,9 @@ from types import TracebackType from ._base import _HttpResponseBase, _HttpClientTransportResponse, HttpRequest +from .._base_async import AsyncPipeline from ...utils._pipeline_transport_rest_shared_async import _PartGenerator -if TYPE_CHECKING: - from ..._pipeline_client_async import AsyncPipelineClient AsyncHTTPResponseType = TypeVar("AsyncHTTPResponseType") HTTPResponseType = TypeVar("HTTPResponseType") @@ -76,7 +74,7 @@ class AsyncHttpResponse(_HttpResponseBase, AsyncContextManager["AsyncHttpRespons """ def stream_download( - self, pipeline: AsyncPipelineClient[HttpRequest, "AsyncHttpResponse"], **kwargs: Any + self, pipeline: AsyncPipeline[HttpRequest, "AsyncHttpResponse"], **kwargs: Any ) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. From 884b62e222d4894a96ca4fc8d8b8fc0bd5c1d9e2 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 18:18:39 +0000 Subject: [PATCH 11/16] Session fix in requests transport --- .../azure/core/pipeline/transport/_requests_basic.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index b5f8c47711e0..36d2746be3bb 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -24,7 +24,7 @@ # # -------------------------------------------------------------------------- import logging -from typing import Iterator, Optional, Union, TypeVar, overload, TYPE_CHECKING +from typing import Iterator, Optional, Union, TypeVar, overload, cast, TYPE_CHECKING from urllib3.util.retry import Retry from urllib3.exceptions import ( DecodeError as CoreDecodeError, @@ -247,6 +247,8 @@ class RequestsTransport(HttpTransport): def __init__(self, **kwargs) -> None: self.session = kwargs.get("session", None) self._session_owner = kwargs.get("session_owner", True) + if not self._session_owner and not self.session: + raise ValueError("session_owner cannot be False if no session is provided") self.connection_config = ConnectionConfiguration(**kwargs) self._use_env_settings = kwargs.pop("use_env_settings", True) @@ -274,6 +276,8 @@ def open(self): if not self.session and self._session_owner: self.session = requests.Session() self._init_session(self.session) + # pyright has trouble to understand that self.session is not None, since we raised at worst in the init + self.session = cast(requests.Session, self.session) def close(self): if self._session_owner and self.session: From 85b583dab2e573610e40d37e03bd48435e1f6328 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 18:57:24 +0000 Subject: [PATCH 12/16] Pylint --- sdk/core/azure-core/azure/core/rest/_http_response_impl.py | 2 +- sdk/core/azure-core/azure/core/settings.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py index 53be80045848..8754985820ba 100644 --- a/sdk/core/azure-core/azure/core/rest/_http_response_impl.py +++ b/sdk/core/azure-core/azure/core/rest/_http_response_impl.py @@ -24,7 +24,7 @@ # # -------------------------------------------------------------------------- from json import loads -from typing import cast, Any, Optional, Iterator, MutableMapping, Callable +from typing import Any, Optional, Iterator, MutableMapping, Callable from http.client import HTTPResponse as _HTTPResponse from ._helpers import ( get_charset_encoding, diff --git a/sdk/core/azure-core/azure/core/settings.py b/sdk/core/azure-core/azure/core/settings.py index ea889f42b04f..921e76f37bd6 100644 --- a/sdk/core/azure-core/azure/core/settings.py +++ b/sdk/core/azure-core/azure/core/settings.py @@ -31,7 +31,7 @@ import logging import os import sys -from typing import Type, Optional, Callable, cast, Union, Dict, Any, TypeVar, Tuple, Generic, Mapping, List +from typing import Type, Optional, Callable, Union, Dict, Any, TypeVar, Tuple, Generic, Mapping, List from azure.core.tracing import AbstractSpan ValidInputType = TypeVar("ValidInputType") From a3222d8d9a4b0f972ea160c4358359c3e5ead274 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Fri, 1 Sep 2023 20:08:40 +0000 Subject: [PATCH 13/16] Circular dependencies --- sdk/core/azure-core/azure/core/pipeline/transport/_base.py | 7 ++++++- .../azure/core/pipeline/transport/_base_async.py | 6 +++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 247433104151..6b0839888a39 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -23,6 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- +from __future__ import annotations import abc from email.message import Message import json @@ -48,11 +49,11 @@ Sequence, MutableMapping, ContextManager, + TYPE_CHECKING, ) from http.client import HTTPResponse as _HTTPResponse -from azure.core.pipeline import Pipeline from azure.core.exceptions import HttpResponseError from ...utils._utils import case_insensitive_dict from ...utils._pipeline_transport_rest_shared import ( @@ -71,6 +72,10 @@ HTTPRequestType = TypeVar("HTTPRequestType") DataType = Union[bytes, str, Dict[str, Union[str, int]]] +if TYPE_CHECKING: + # We need a transport to define a pipeline, this "if" avoid a circular import + from azure.core.pipeline import Pipeline + _LOGGER = logging.getLogger(__name__) binary_type = str diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index 55d6060c645b..7579d71cc69e 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -35,11 +35,11 @@ AsyncContextManager, Optional, Type, + TYPE_CHECKING, ) from types import TracebackType from ._base import _HttpResponseBase, _HttpClientTransportResponse, HttpRequest -from .._base_async import AsyncPipeline from ...utils._pipeline_transport_rest_shared_async import _PartGenerator @@ -47,6 +47,10 @@ HTTPResponseType = TypeVar("HTTPResponseType") HTTPRequestType = TypeVar("HTTPRequestType") +if TYPE_CHECKING: + # We need a transport to define a pipeline, this "if" avoid a circular import + from .._base_async import AsyncPipeline + class _ResponseStopIteration(Exception): pass From b7fe57f33905a9708eb7bc98a01e80ae165c5827 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 6 Sep 2023 21:33:15 +0000 Subject: [PATCH 14/16] Make aiohttp check session as well --- .../azure-core/azure/core/pipeline/transport/_aiohttp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index 1b56533e7d0e..9c525533211a 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -93,6 +93,8 @@ def __init__( self._loop = loop self._session_owner = session_owner self.session = session + if not self._session_owner and not self.session: + raise ValueError("session_owner cannot be False if no session is provided") self.connection_config = ConnectionConfiguration(**kwargs) self._use_env_settings = kwargs.pop("use_env_settings", True) @@ -120,8 +122,9 @@ async def open(self): if self._loop is not None: clientsession_kwargs["loop"] = self._loop self.session = aiohttp.ClientSession(**clientsession_kwargs) - if self.session is not None: - await self.session.__aenter__() + # pyright has trouble to understand that self.session is not None, since we raised at worst in the init + self.session = cast(aiohttp.ClientSession, self.session) + await self.session.__aenter__() async def close(self): """Closes the connection.""" From c7fe560fe6f9f329fab107982d5cc2dde7320ded Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 6 Sep 2023 21:39:55 +0000 Subject: [PATCH 15/16] Improve await_result --- sdk/core/azure-core/azure/core/pipeline/_tools.py | 11 +++++++---- .../azure-core/azure/core/pipeline/_tools_async.py | 8 ++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools.py b/sdk/core/azure-core/azure/core/pipeline/_tools.py index de90f06f6784..e6b8545f0d8f 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools.py @@ -24,15 +24,18 @@ # # -------------------------------------------------------------------------- from __future__ import annotations -from typing import TYPE_CHECKING, Union -from typing_extensions import TypeGuard +from typing import TYPE_CHECKING, Union, Callable, TypeVar +from typing_extensions import TypeGuard, ParamSpec if TYPE_CHECKING: - from typing import Any from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse -def await_result(func, *args, **kwargs): +P = ParamSpec("P") +T = TypeVar("T") + + +def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: """If func returns an awaitable, raise that this runner can't handle it. :param func: The function to run. diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index 25c0a9df75c8..f7e3f74a62f9 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -53,13 +53,9 @@ async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, :rtype: any :return: The result of the function """ - result: Union[T, Awaitable[T]] = func(*args, **kwargs) - # pyright has issue with narrowing types here - # https://github.com/microsoft/pyright/issues/5860 - if hasattr(result, "__await__"): - result = cast(Awaitable[T], result) + result = func(*args, **kwargs) + if isinstance(result, Awaitable): return await result - result = cast(T, result) return result From c7d47b382600aa66bf90b759fed617aa7a6aff42 Mon Sep 17 00:00:00 2001 From: Laurent Mazuel Date: Wed, 6 Sep 2023 23:08:42 +0000 Subject: [PATCH 16/16] pylint --- sdk/core/azure-core/azure/core/pipeline/_tools_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py index f7e3f74a62f9..3395b5cbbbe7 100644 --- a/sdk/core/azure-core/azure/core/pipeline/_tools_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/_tools_async.py @@ -23,7 +23,7 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload, cast +from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload from typing_extensions import ParamSpec if TYPE_CHECKING: