Skip to content

Pyright fixes for azure-core #31766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Sep 7, 2023
Merged
3 changes: 2 additions & 1 deletion sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def _build_pipeline(
policies = policies_1

if transport is None:
from .pipeline.transport import RequestsTransport # pylint: disable=no-name-in-module
# Use private import for better typing, mypy and pyright don't like PEP562
from .pipeline.transport._requests_basic import RequestsTransport

transport = RequestsTransport(**kwargs)

Expand Down
3 changes: 2 additions & 1 deletion sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def _build_pipeline(
policies = policies_1

if not transport:
from .pipeline.transport import AioHttpTransport # pylint: disable=no-name-in-module
# Use private import for better typing, mypy and pyright don't like PEP562
from .pipeline.transport._aiohttp import AioHttpTransport

transport = AioHttpTransport(**kwargs)

Expand Down
1 change: 1 addition & 0 deletions sdk/core/azure-core/azure/core/credentials_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ async def get_token(
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
...

async def close(self) -> None:
pass
Expand Down
11 changes: 6 additions & 5 deletions sdk/core/azure-core/azure/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
TypeVar,
Generic,
Dict,
NoReturn,
TYPE_CHECKING,
)
from typing_extensions import Protocol, runtime_checkable
Expand Down Expand Up @@ -79,7 +80,7 @@
]


def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> None:
def raise_with_traceback(exception: Callable, *args: Any, **kwargs: Any) -> NoReturn:
"""Raise exception with a specified traceback.
This MUST be called inside a "except" clause.

Expand Down Expand Up @@ -113,18 +114,18 @@ class _HttpResponseCommonAPI(Protocol):

@property
def reason(self) -> Optional[str]:
pass
...

@property
def status_code(self) -> Optional[int]:
pass
...

def text(self) -> str:
pass
...

@property
def request(self) -> object: # object as type, since all we need is str() on it
pass
...


class ErrorMap(Generic[KeyType, ValueType]):
Expand Down
36 changes: 18 additions & 18 deletions sdk/core/azure-core/azure/core/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,26 +192,26 @@ def from_dict(cls, event: Dict[str, Any]) -> CloudEvent[DataType]:
except KeyError as err:
# https://github.com/cloudevents/spec Cloud event spec requires source, type,
# specversion. We autopopulate everything other than source, type.
if not all(_ in event for _ in ("source", "type")):
if all(
_ in event
for _ in (
"subject",
"eventType",
"data",
"dataVersion",
"id",
"eventTime",
)
):
raise ValueError(
"The event you are trying to parse follows the Eventgrid Schema. You can parse"
+ " EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library."
) from err
# So we will assume the KeyError is coming from source/type access.
if all(
key in event
for key in (
"subject",
"eventType",
"data",
"dataVersion",
"id",
"eventTime",
)
):
raise ValueError(
"The event does not conform to the cloud event spec https://github.com/cloudevents/spec."
+ " The `source` and `type` params are required."
"The event you are trying to parse follows the Eventgrid Schema. You can parse"
+ " EventGrid events using EventGridEvent.from_dict method in the azure-eventgrid library."
) from err
raise ValueError(
"The event does not conform to the cloud event spec https://github.com/cloudevents/spec."
+ " The `source` and `type` params are required."
) from err
return event_obj

@classmethod
Expand Down
17 changes: 11 additions & 6 deletions sdk/core/azure-core/azure/core/pipeline/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING, Union, Callable, TypeVar
from typing_extensions import TypeGuard, ParamSpec

if TYPE_CHECKING:
from typing import Any
from azure.core.rest import HttpResponse as RestHttpResponse
from azure.core.rest import HttpResponse, HttpRequest, AsyncHttpResponse


def await_result(func, *args, **kwargs):
P = ParamSpec("P")
T = TypeVar("T")


def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
"""If func returns an awaitable, raise that this runner can't handle it.

:param func: The function to run.
Expand All @@ -47,7 +52,7 @@ def await_result(func, *args, **kwargs):
return result


def is_rest(obj) -> bool:
def is_rest(obj: object) -> TypeGuard[Union[HttpRequest, HttpResponse, AsyncHttpResponse]]:
"""Return whether a request or a response is a rest request / response.

Checking whether the response has the object content can sometimes result
Expand All @@ -63,7 +68,7 @@ def is_rest(obj) -> bool:
return hasattr(obj, "is_stream_consumed") or hasattr(obj, "content")


def handle_non_stream_rest_response(response: "RestHttpResponse") -> None:
def handle_non_stream_rest_response(response: HttpResponse) -> None:
"""Handle reading and closing of non stream rest responses.
For our new rest responses, we have to call .read() and .close() for our non-stream
responses. This way, we load in the body for users to access.
Expand Down
23 changes: 18 additions & 5 deletions sdk/core/azure-core/azure/core/pipeline/_tools_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,27 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, TypeVar, Awaitable, Union, overload
from typing_extensions import ParamSpec

if TYPE_CHECKING:
from ..rest import AsyncHttpResponse as RestAsyncHttpResponse

P = ParamSpec("P")
T = TypeVar("T")

async def await_result(func, *args, **kwargs):

@overload
async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T:
...


@overload
async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
...


async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
"""If func returns an awaitable, await it.

:param func: The function to run.
Expand All @@ -40,9 +54,8 @@ async def await_result(func, *args, **kwargs):
:return: The result of the function
"""
result = func(*args, **kwargs)
if hasattr(result, "__await__"):
# type ignore on await: https://github.com/python/mypy/issues/7587
return await result # type: ignore
if isinstance(result, Awaitable):
return await result
return result


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,28 @@ async def send(
await await_result(self.on_request, request)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
await await_result(self.on_exception, request)
raise
else:
if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = await self.next.send(request)
await await_result(self.on_response, request, response)
except Exception: # pylint:disable=broad-except
handled = await await_result(self.on_exception, request)
if not handled:
raise
await await_result(self.on_response, request, response)

if response.http_response.status_code == 401:
self._token = None # any cached token is invalid
if "WWW-Authenticate" in response.http_response.headers:
request_authorized = await self.on_challenge(request, response)
if request_authorized:
# if we receive a challenge response, we retrieve a new token
# which matches the new target. In this case, we don't want to remove
# token from the request so clear the 'insecure_domain_change' tag
request.context.options.pop("insecure_domain_change", False)
try:
response = await self.next.send(request)
except Exception: # pylint:disable=broad-except
await await_result(self.on_exception, request)
raise
else:
await await_result(self.on_response, request, response)

return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""Traces network calls using the implementation library from the settings."""
import logging
import sys
import urllib
import urllib.parse
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union, Any, Type
from types import TracebackType

Expand Down
53 changes: 43 additions & 10 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from __future__ import annotations
import sys
from typing import Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload, cast, Union, Type
from types import TracebackType
Expand Down Expand Up @@ -56,6 +57,7 @@
HttpRequest as RestHttpRequest,
AsyncHttpResponse as RestAsyncHttpResponse,
)
from ...rest._aiohttp import RestAioHttpTransportResponse

# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
Expand Down Expand Up @@ -91,6 +93,8 @@ def __init__(
self._loop = loop
self._session_owner = session_owner
self.session = session
if not self._session_owner and not self.session:
raise ValueError("session_owner cannot be False if no session is provided")
self.connection_config = ConnectionConfiguration(**kwargs)
self._use_env_settings = kwargs.pop("use_env_settings", True)

Expand Down Expand Up @@ -118,8 +122,9 @@ async def open(self):
if self._loop is not None:
clientsession_kwargs["loop"] = self._loop
self.session = aiohttp.ClientSession(**clientsession_kwargs)
if self.session is not None:
await self.session.__aenter__()
# pyright has trouble to understand that self.session is not None, since we raised at worst in the init
self.session = cast(aiohttp.ClientSession, self.session)
await self.session.__aenter__()

async def close(self):
"""Closes the connection."""
Expand Down Expand Up @@ -188,7 +193,7 @@ async def send(self, request: HttpRequest, **config: Any) -> AsyncHttpResponse:
"""

@overload
async def send(self, request: "RestHttpRequest", **config: Any) -> "RestAsyncHttpResponse":
async def send(self, request: RestHttpRequest, **config: Any) -> RestAsyncHttpResponse:
"""Send the `azure.core.rest` request using this HTTP sender.

Will pre-load the body into memory to be available with a sync method.
Expand All @@ -206,8 +211,8 @@ async def send(self, request: "RestHttpRequest", **config: Any) -> "RestAsyncHtt
"""

async def send(
self, request: Union[HttpRequest, "RestHttpRequest"], **config
) -> Union[AsyncHttpResponse, "RestAsyncHttpResponse"]:
self, request: Union[HttpRequest, RestHttpRequest], **config
) -> Union[AsyncHttpResponse, RestAsyncHttpResponse]:
"""Send the request using this HTTP sender.

Will pre-load the body into memory to be available with a sync method.
Expand Down Expand Up @@ -240,7 +245,7 @@ async def send(
config["proxy"] = proxies[protocol]
break

response: Optional[Union[AsyncHttpResponse, "RestAsyncHttpResponse"]] = None
response: Optional[Union[AsyncHttpResponse, RestAsyncHttpResponse]] = None
config["ssl"] = self._build_ssl_config(
cert=config.pop("connection_cert", self.connection_config.cert),
verify=config.pop("connection_verify", self.connection_config.verify),
Expand All @@ -262,7 +267,7 @@ async def send(
data=self._get_request_data(request),
timeout=socket_timeout,
allow_redirects=False,
**config
**config,
)
if _is_rest(request):
from azure.core.rest._aiohttp import RestAioHttpTransportResponse
Expand Down Expand Up @@ -307,7 +312,33 @@ class AioHttpStreamDownloadGenerator(AsyncIterator):
on the *content-encoding* header.
"""

def __init__(self, pipeline: AsyncPipeline, response: AsyncHttpResponse, *, decompress: bool = True) -> None:
@overload
def __init__(
self,
pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse],
response: AioHttpTransportResponse,
*,
decompress: bool = True,
) -> None:
...

@overload
def __init__(
self,
pipeline: AsyncPipeline[RestHttpRequest, RestAsyncHttpResponse],
response: RestAioHttpTransportResponse,
*,
decompress: bool = True,
) -> None:
...

def __init__(
self,
pipeline: AsyncPipeline,
response: Union[AioHttpTransportResponse, RestAioHttpTransportResponse],
*,
decompress: bool = True,
) -> None:
self.pipeline = pipeline
self.request = response.request
self.response = response
Expand Down Expand Up @@ -380,7 +411,7 @@ def __init__(
aiohttp_response: aiohttp.ClientResponse,
block_size: Optional[int] = None,
*,
decompress: bool = True
decompress: bool = True,
) -> None:
super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size)
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
Expand Down Expand Up @@ -462,7 +493,9 @@ async def load_body(self) -> None:
except aiohttp.client_exceptions.ClientError as err:
raise ServiceRequestError(err, error=err) from err

def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
def stream_download(
self, pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], **kwargs
) -> AsyncIteratorType[bytes]:
"""Generator for streaming response body data.

:param pipeline: The pipeline object
Expand Down
Loading