diff --git a/google/auth/transport/aiohttp_requests.py b/google/auth/transport/aiohttp_requests.py index 46816ea5e..29ecb3ad5 100644 --- a/google/auth/transport/aiohttp_requests.py +++ b/google/auth/transport/aiohttp_requests.py @@ -18,6 +18,7 @@ import asyncio import functools +import zlib import aiohttp import six @@ -31,6 +32,57 @@ _DEFAULT_TIMEOUT = 180 # in seconds +class _CombinedResponse(transport.Response): + """ + In order to more closely resemble the `requests` interface, where a raw + and deflated content could be accessed at once, this class lazily reads the + stream in `transport.Response` so both return forms can be used. + + The gzip and deflate transfer-encodings are automatically decoded for you + because the default parameter for autodecompress into the ClientSession is set + to False, and therefore we add this class to act as a wrapper for a user to be + able to access both the raw and decoded response bodies - mirroring the sync + implementation. + """ + + def __init__(self, response): + self._response = response + self._raw_content = None + + def _is_compressed(self): + headers = self._client_response.headers + return "Content-Encoding" in headers and ( + headers["Content-Encoding"] == "gzip" + or headers["Content-Encoding"] == "deflate" + ) + + @property + def status(self): + return self._response.status + + @property + def headers(self): + return self._response.headers + + @property + def data(self): + return self._response.content + + async def raw_content(self): + if self._raw_content is None: + self._raw_content = await self._response.content.read() + return self._raw_content + + async def content(self): + if self._raw_content is None: + self._raw_content = await self._response.content.read() + if self._is_compressed: + d = zlib.decompressobj(zlib.MAX_WBITS | 32) + decompressed = d.decompress(self._raw_content) + return decompressed + return self._raw_content + + class _Response(transport.Response): """ Requests transport response adapter. @@ -79,7 +131,6 @@ class Request(transport.Request): """ def __init__(self, session=None): - self.session = None async def __call__( @@ -89,7 +140,7 @@ async def __call__( body=None, headers=None, timeout=_DEFAULT_TIMEOUT, - **kwargs + **kwargs, ): """ Make an HTTP request using aiohttp. @@ -115,12 +166,14 @@ async def __call__( try: if self.session is None: # pragma: NO COVER - self.session = aiohttp.ClientSession() # pragma: NO COVER + self.session = aiohttp.ClientSession( + auto_decompress=False + ) # pragma: NO COVER requests._LOGGER.debug("Making request: %s %s", method, url) response = await self.session.request( method, url, data=body, headers=headers, timeout=timeout, **kwargs ) - return _Response(response) + return _CombinedResponse(response) except aiohttp.ClientError as caught_exc: new_exc = exceptions.TransportError(caught_exc) @@ -175,6 +228,7 @@ def __init__( max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, refresh_timeout=None, auth_request=None, + auto_decompress=False, ): super(AuthorizedSession, self).__init__() self.credentials = credentials @@ -186,6 +240,7 @@ def __init__( self._auth_request_session = None self._loop = asyncio.get_event_loop() self._refresh_lock = asyncio.Lock() + self._auto_decompress = auto_decompress async def request( self, @@ -195,7 +250,8 @@ async def request( headers=None, max_allowed_time=None, timeout=_DEFAULT_TIMEOUT, - **kwargs + auto_decompress=False, + **kwargs, ): """Implementation of Authorized Session aiohttp request. @@ -230,8 +286,17 @@ async def request( transmitted. The timout error will be raised after such request completes. """ - - async with aiohttp.ClientSession() as self._auth_request_session: + # Headers come in as bytes which isn't expected behavior, the resumable + # media libraries in some cases expect a str type for the header values, + # but sometimes the operations return these in bytes types. + if headers: + for key in headers.keys(): + if type(headers[key]) is bytes: + headers[key] = headers[key].decode("utf-8") + + async with aiohttp.ClientSession( + auto_decompress=self._auto_decompress + ) as self._auth_request_session: auth_request = Request(self._auth_request_session) self._auth_request = auth_request @@ -264,7 +329,7 @@ async def request( data=data, headers=request_headers, timeout=timeout, - **kwargs + **kwargs, ) remaining_time = guard.remaining_timeout @@ -307,7 +372,7 @@ async def request( max_allowed_time=remaining_time, timeout=timeout, _credential_refresh_attempt=_credential_refresh_attempt + 1, - **kwargs + **kwargs, ) return response diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py index 5b742306b..b40bfbedf 100644 --- a/google/auth/transport/mtls.py +++ b/google/auth/transport/mtls.py @@ -86,9 +86,12 @@ def default_client_encrypted_cert_source(cert_path, key_path): def callback(): try: - _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials( - generate_encrypted_key=True - ) + ( + _, + cert_bytes, + key_bytes, + passphrase_bytes, + ) = _mtls_helper.get_client_ssl_credentials(generate_encrypted_key=True) with open(cert_path, "wb") as cert_file: cert_file.write(cert_bytes) with open(key_path, "wb") as key_file: diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index a6cc3b292..4817ea40e 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -104,7 +104,8 @@ async def _token_endpoint_request(request, token_uri, body): method="POST", url=token_uri, headers=headers, body=body ) - response_body1 = await response.data.read() + # Using data.read() resulted in zlib decompression errors. This may require future investigation. + response_body1 = await response.content() response_body = ( response_body1.decode("utf-8") diff --git a/google/oauth2/credentials_async.py b/google/oauth2/credentials_async.py index 2081a0be2..199b7eb56 100644 --- a/google/oauth2/credentials_async.py +++ b/google/oauth2/credentials_async.py @@ -61,7 +61,12 @@ async def refresh(self, request): "token_uri, client_id, and client_secret." ) - access_token, refresh_token, expiry, grant_response = await _client.refresh_grant( + ( + access_token, + refresh_token, + expiry, + grant_response, + ) = await _client.refresh_grant( request, self._token_uri, self._refresh_token, diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py index c32a183a6..87982807f 100644 --- a/tests_async/oauth2/test__client_async.py +++ b/tests_async/oauth2/test__client_async.py @@ -63,6 +63,7 @@ def make_request(response_data, status=http_client.OK): data = json.dumps(response_data).encode("utf-8") response.data = mock.AsyncMock(spec=["__call__", "read"]) response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + response.content = mock.AsyncMock(spec=["__call__"], return_value=data) request = mock.AsyncMock(spec=["transport.Request"]) request.return_value = response return request